Skip to content

Commit 45dfc3d

Browse files
committed
[GR-47856] Fix using recv_into with buffers without internal byte array
PullRequest: graalpython/2909
2 parents e834279 + 73f6b0f commit 45dfc3d

File tree

5 files changed

+83
-9
lines changed

5 files changed

+83
-9
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_socket.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -36,10 +36,10 @@
3636
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3737
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3838
# SOFTWARE.
39-
40-
import unittest
41-
4239
import socket
40+
import sys
41+
import threading
42+
import unittest
4343

4444

4545
def test_inet_aton():
@@ -58,10 +58,43 @@ def test_inet_aton_errs(self):
5858
self.assertRaises(OSError, lambda : socket.inet_aton('255.255.256.1'))
5959
self.assertRaises(TypeError, lambda : socket.inet_aton(255))
6060

61-
def test_get_name_info():
61+
def test_get_name_info():
6262
import socket
6363
try :
6464
socket.getnameinfo((1, 0, 0, 0), 0)
6565
except TypeError:
6666
raised = True
6767
assert raised
68+
69+
70+
def test_recv_into():
71+
port = None
72+
event = threading.Event()
73+
def server():
74+
nonlocal port
75+
with socket.create_server(('localhost', 0)) as sock:
76+
port = sock.getsockname()[1]
77+
event.set()
78+
conn, addr = sock.accept()
79+
conn.send(b'123')
80+
conn.close()
81+
thread = threading.Thread(target=server)
82+
thread.start()
83+
event.wait()
84+
with socket.create_connection(('localhost', port)) as sock:
85+
# Byte buffer with direct access to internal array
86+
b = bytearray(b'aaa')
87+
sock.recv_into(b, 1)
88+
assert b == b'1aa'
89+
# Byte buffer with offset, this currently doesn't have internal array acces, but we might implement it later
90+
buffer = memoryview(b)[1:]
91+
sock.recv_into(buffer, 1)
92+
assert b == b'12a'
93+
if sys.implementation.name == 'graalpy':
94+
assert hasattr(__graalpython__, 'storage_to_native'), "Needs to be run with --python.EnableDebuggingBuiltins"
95+
__graalpython__.storage_to_native(b)
96+
# Native buffer, no internal array
97+
buffer = memoryview(buffer)[1:]
98+
sock.recv_into(buffer, 1)
99+
assert b == b'123'
100+
thread.join()

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/GraalPythonModuleBuiltins.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import static com.oracle.graal.python.nodes.SpecialMethodNames.T_INSERT;
5050
import static com.oracle.graal.python.nodes.StringLiterals.J_LLVM_LANGUAGE;
5151
import static com.oracle.graal.python.nodes.StringLiterals.T_COLON;
52+
import static com.oracle.graal.python.nodes.StringLiterals.T_EMPTY_STRING;
5253
import static com.oracle.graal.python.nodes.StringLiterals.T_PATH;
5354
import static com.oracle.graal.python.nodes.StringLiterals.T_STRICT;
5455
import static com.oracle.graal.python.nodes.StringLiterals.T_SURROGATEESCAPE;
@@ -90,7 +91,10 @@
9091
import com.oracle.graal.python.builtins.modules.GraalPythonModuleBuiltinsFactory.DebugNodeFactory;
9192
import com.oracle.graal.python.builtins.objects.PNone;
9293
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
94+
import com.oracle.graal.python.builtins.objects.bytes.PBytesLike;
9395
import com.oracle.graal.python.builtins.objects.cext.PythonAbstractNativeObject;
96+
import com.oracle.graal.python.builtins.objects.cext.capi.CApiContext;
97+
import com.oracle.graal.python.builtins.objects.cext.capi.PySequenceArrayWrapper.ToNativeStorageNode;
9498
import com.oracle.graal.python.builtins.objects.code.CodeNodes;
9599
import com.oracle.graal.python.builtins.objects.code.PCode;
96100
import com.oracle.graal.python.builtins.objects.common.DynamicObjectStorage;
@@ -140,6 +144,8 @@
140144
import com.oracle.graal.python.runtime.exception.PException;
141145
import com.oracle.graal.python.runtime.exception.PythonExitException;
142146
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
147+
import com.oracle.graal.python.runtime.sequence.PSequence;
148+
import com.oracle.graal.python.runtime.sequence.storage.NativeSequenceStorage;
143149
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
144150
import com.oracle.graal.python.util.PythonUtils;
145151
import com.oracle.truffle.api.CallTarget;
@@ -253,6 +259,7 @@ public void postInitialize(Python3Core core) {
253259
mod.setAttribute(tsLiteral("dump_truffle_ast"), PNone.NO_VALUE);
254260
mod.setAttribute(tsLiteral("tdebug"), PNone.NO_VALUE);
255261
mod.setAttribute(tsLiteral("set_storage_strategy"), PNone.NO_VALUE);
262+
mod.setAttribute(tsLiteral("storage_to_native"), PNone.NO_VALUE);
256263
mod.setAttribute(tsLiteral("dump_heap"), PNone.NO_VALUE);
257264
mod.setAttribute(tsLiteral("is_native_object"), PNone.NO_VALUE);
258265
}
@@ -779,6 +786,36 @@ private void validate(HashingStorage dictStorage) {
779786
}
780787
}
781788

789+
@Builtin(name = "storage_to_native", minNumOfPositionalArgs = 1)
790+
@GenerateNodeFactory
791+
abstract static class StorageToNative extends PythonUnaryBuiltinNode {
792+
@Specialization
793+
@TruffleBoundary
794+
Object toNative(PBytesLike bytes) {
795+
ensureCapi();
796+
NativeSequenceStorage newStorage = ToNativeStorageNode.getUncached().execute(bytes.getSequenceStorage(), true);
797+
bytes.setSequenceStorage(newStorage);
798+
return bytes;
799+
}
800+
801+
@Specialization
802+
@TruffleBoundary
803+
Object toNative(PSequence sequence) {
804+
ensureCapi();
805+
NativeSequenceStorage newStorage = ToNativeStorageNode.getUncached().execute(sequence.getSequenceStorage(), false);
806+
sequence.setSequenceStorage(newStorage);
807+
return sequence;
808+
}
809+
810+
private void ensureCapi() {
811+
try {
812+
CApiContext.ensureCapiWasLoaded(null, getContext(), T_EMPTY_STRING, T_EMPTY_STRING);
813+
} catch (Exception e) {
814+
throw CompilerDirectives.shouldNotReachHere(e);
815+
}
816+
}
817+
}
818+
782819
@Builtin(name = J_EXTEND, minNumOfPositionalArgs = 1, doc = "Extends Java class and return HostAdapterCLass")
783820
@GenerateNodeFactory
784821
public abstract static class JavaExtendNode extends PythonUnaryBuiltinNode {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/capi/PySequenceArrayWrapper.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
public final class PySequenceArrayWrapper {
6363

6464
@GenerateUncached
65-
abstract static class ToNativeStorageNode extends Node {
65+
public abstract static class ToNativeStorageNode extends Node {
6666

6767
public abstract NativeSequenceStorage execute(SequenceStorage object, boolean isBytesLike);
6868

@@ -104,6 +104,10 @@ static NativeSequenceStorage doEmptyStorage(@SuppressWarnings("unused") EmptySeq
104104
protected static boolean isNative(SequenceStorage s) {
105105
return s instanceof NativeSequenceStorage;
106106
}
107+
108+
public static ToNativeStorageNode getUncached() {
109+
return ToNativeStorageNodeGen.getUncached();
110+
}
107111
}
108112

109113
@TruffleBoundary

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/socket/SocketBuiltins.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ Object recvInto(VirtualFrame frame, PSocket socket, Object bufferObj, int recvle
680680
() -> posixLib.recv(getPosixSupport(), socket.getFd(), bytes, 0, len, flags),
681681
false, false);
682682
if (!directWrite) {
683-
bufferLib.readIntoByteArray(buffer, 0, bytes, 0, outlen);
683+
bufferLib.writeFromByteArray(buffer, 0, bytes, 0, outlen);
684684
}
685685
return outlen;
686686
} catch (PosixException e) {
@@ -743,7 +743,7 @@ Object recvFromInto(VirtualFrame frame, PSocket socket, Object bufferObj, int re
743743
() -> posixLib.recvfrom(getPosixSupport(), socket.getFd(), bytes, 0, bytes.length, flags),
744744
false, false);
745745
if (!directWrite) {
746-
bufferLib.readIntoByteArray(buffer, 0, bytes, 0, result.readBytes);
746+
bufferLib.writeFromByteArray(buffer, 0, bytes, 0, result.readBytes);
747747
}
748748
return factory().createTuple(new Object[]{result.readBytes, makeSockAddrNode.execute(frame, result.sockAddr)});
749749
} catch (PosixException e) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ssl/SSLSocketBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ Object readInto(VirtualFrame frame, PSSLSocket self, int len, Object bufferObj,
135135
PythonUtils.flipBuffer(output);
136136
int readBytes = PythonUtils.getBufferRemaining(output);
137137
if (!directWrite) {
138-
bufferLib.readIntoByteArray(buffer, 0, bytes, 0, readBytes);
138+
bufferLib.writeFromByteArray(buffer, 0, bytes, 0, readBytes);
139139
}
140140
return readBytes;
141141
} finally {

0 commit comments

Comments
 (0)