Skip to content

Commit dfa325b

Browse files
committed
Implement unpickling arrays from different architectures
1 parent 928eedb commit dfa325b

File tree

6 files changed

+114
-39
lines changed

6 files changed

+114
-39
lines changed

graalpython/com.oracle.graal.python.test/src/tests/unittest_tags/test_array.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
*graalpython.lib-python.3.test.test_array.ArrayReconstructorTest.test_error
2+
*graalpython.lib-python.3.test.test_array.ArrayReconstructorTest.test_numbers
3+
*graalpython.lib-python.3.test.test_array.ArrayReconstructorTest.test_unicode
14
*graalpython.lib-python.3.test.test_array.ByteTest.test_add
25
*graalpython.lib-python.3.test.test_array.ByteTest.test_assignment
36
*graalpython.lib-python.3.test.test_array.ByteTest.test_buffer_info

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ArrayModuleBuiltins.java

Lines changed: 64 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
3030
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
3131

32+
import java.nio.ByteOrder;
3233
import java.util.List;
3334

3435
import com.oracle.graal.python.annotations.ArgumentClinic;
@@ -40,6 +41,7 @@
4041
import com.oracle.graal.python.builtins.objects.array.ArrayBuiltins;
4142
import com.oracle.graal.python.builtins.objects.array.ArrayNodes;
4243
import com.oracle.graal.python.builtins.objects.array.PArray;
44+
import com.oracle.graal.python.builtins.objects.array.PArray.MachineFormat;
4345
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
4446
import com.oracle.graal.python.builtins.objects.bytes.PBytesLike;
4547
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
@@ -58,6 +60,7 @@
5860
import com.oracle.graal.python.nodes.function.builtins.PythonVarargsBuiltinNode;
5961
import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
6062
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
63+
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
6164
import com.oracle.graal.python.nodes.util.SplitArgsNode;
6265
import com.oracle.graal.python.runtime.PythonCore;
6366
import com.oracle.graal.python.runtime.exception.PException;
@@ -192,6 +195,19 @@ PArray arrayWithBytesInitializer(VirtualFrame frame, Object cls, String typeCode
192195
return array;
193196
}
194197

198+
/**
 * Specialization selected when the initializer satisfies the {@code isString(initializer)} guard.
 *
 * Resolves {@code typeCode} through {@code getFormatChecked}, rejects every format other than
 * {@code BufferFormat.UNICODE} with a {@code TypeError}, and otherwise creates an array of class
 * {@code cls} and passes it, together with the initializer cast to a Java string, to
 * {@code FromUnicodeNode}.
 *
 * @return the newly created array
 */
@Specialization(guards = "isString(initializer)")
199+
PArray arrayWithStringInitializer(VirtualFrame frame, Object cls, String typeCode, Object initializer,
200+
@Cached CastToJavaStringNode cast,
201+
@Cached ArrayBuiltins.FromUnicodeNode fromUnicodeNode) {
202+
BufferFormat format = getFormatChecked(typeCode);
203+
// A str initializer is only accepted for the unicode format.
if (format != BufferFormat.UNICODE) {
204+
throw raise(TypeError, "cannot use a str to initialize an array with typecode '%s'", typeCode);
205+
}
206+
PArray array = getFactory().createArray(cls, typeCode, format);
207+
fromUnicodeNode.execute(frame, array, cast.execute(initializer));
208+
return array;
209+
}
210+
195211
@Specialization
196212
PArray arrayArrayInitializer(VirtualFrame frame, Object cls, String typeCode, PArray initializer,
197213
@Cached ArrayNodes.PutValueNode putValueNode,
@@ -230,7 +246,7 @@ PArray arraySequenceInitializer(VirtualFrame frame, Object cls, String typeCode,
230246
}
231247
}
232248

233-
@Specialization(guards = "!isBytes(initializer)", limit = "3")
249+
@Specialization(guards = {"!isBytes(initializer)", "!isString(initializer)"}, limit = "3")
234250
PArray arrayIteratorInitializer(VirtualFrame frame, Object cls, String typeCode, Object initializer,
235251
@CachedLibrary("initializer") PythonObjectLibrary lib,
236252
@Cached ArrayNodes.PutValueNode putValueNode,
@@ -306,25 +322,64 @@ private PythonObjectFactory getFactory() {
306322
@ArgumentClinic(name = "mformatCode", conversion = ArgumentClinic.ClinicConversion.Index, defaultValue = "0")
307323
@GenerateNodeFactory
308324
abstract static class ArrayReconstructorNode extends PythonClinicBuiltinNode {
309-
@Specialization
325+
/**
 * Specialization guarded on {@code mformatCode == cachedCode}, where {@code cachedCode} is the
 * machine format code captured via {@code @Cached("mformatCode")}.
 *
 * Looks up the buffer format for {@code typeCode} (raising {@code ValueError} when there is
 * none), passes it through an identity value profile, and delegates to {@code doReconstruct}
 * using {@code cachedCode} in place of the dynamic {@code mformatCode} argument.
 */
@Specialization(guards = "mformatCode == cachedCode")
326+
Object reconstructCached(VirtualFrame frame, Object arrayType, String typeCode, @SuppressWarnings("unused") int mformatCode, PBytes bytes,
327+
@Cached("mformatCode") int cachedCode,
328+
@Cached("createIdentityProfile()") ValueProfile formatProfile,
329+
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
330+
@Cached ArrayBuiltins.FromBytesNode fromBytesNode,
331+
@Cached ArrayBuiltins.FromUnicodeNode fromUnicodeNode,
332+
@Cached IsSubtypeNode isSubtypeNode,
333+
@Cached ArrayBuiltins.ByteSwapNode byteSwapNode) {
334+
BufferFormat format = BufferFormat.forArray(typeCode);
335+
// forArray yields null for a type code it does not recognize.
if (format == null) {
336+
throw raise(ValueError, "bad typecode (must be b, B, u, h, H, i, I, l, L, q, Q, f or d)");
337+
}
338+
// The guard guarantees cachedCode equals mformatCode here, so the cached value is forwarded.
return doReconstruct(frame, arrayType, typeCode, cachedCode, bytes, lib, fromBytesNode, fromUnicodeNode, isSubtypeNode, byteSwapNode, formatProfile.profile(format));
339+
}
340+
341+
/**
 * Uncached specialization; it is declared with {@code replaces = "reconstructCached"}.
 *
 * Looks up the buffer format for {@code typeCode} (raising {@code ValueError} when there is
 * none) and delegates to {@code doReconstruct} with the dynamic {@code mformatCode} and the
 * unprofiled format.
 *
 * This span is a diff rendering: the line following the {@code 312-} marker is the removed
 * end of the old parameter list and is not part of the new signature.
 */
@Specialization(replaces = "reconstructCached")
310342
Object reconstruct(VirtualFrame frame, Object arrayType, String typeCode, int mformatCode, PBytes bytes,
343+
@CachedLibrary(limit = "2") PythonObjectLibrary lib,
311344
@Cached ArrayBuiltins.FromBytesNode fromBytesNode,
312-
@Cached IsSubtypeNode isSubtypeNode) {
345+
@Cached ArrayBuiltins.FromUnicodeNode fromUnicodeNode,
346+
@Cached IsSubtypeNode isSubtypeNode,
347+
@Cached ArrayBuiltins.ByteSwapNode byteSwapNode) {
313348
BufferFormat format = BufferFormat.forArray(typeCode);
314349
// forArray yields null for a type code it does not recognize.
if (format == null) {
315350
throw raise(ValueError, "bad typecode (must be b, B, u, h, H, i, I, l, L, q, Q, f or d)");
316351
}
352+
return doReconstruct(frame, arrayType, typeCode, mformatCode, bytes, lib, fromBytesNode, fromUnicodeNode, isSubtypeNode, byteSwapNode, format);
353+
}
354+
355+
/**
 * Shared implementation behind both reconstructor specializations.
 *
 * This span is a diff rendering: lines following a {@code NNN-} marker are the removed old
 * implementation, lines following a {@code NNN+} marker are the new one. The description below
 * covers the new ("+") side.
 *
 * Steps:
 * 1. Raise {@code TypeError} unless {@code arrayType} is a subtype of the builtin array class.
 * 2. Map {@code mformatCode} to a {@code MachineFormat}; raise {@code ValueError} if there is none.
 * 3. If that machine format is the one {@code MachineFormat.forFormat(format)} returns for the
 *    requested type code's format, create the array with the requested type code and load the
 *    bytes directly.
 * 4. Otherwise create the array with the machine format's buffer format and either
 *    (a) for formats carrying a unicode encoding, call {@code decode(encoding)} on the bytes
 *    object and append the result via {@code FromUnicodeNode}, or
 *    (b) load the raw bytes and byte-swap when the format's byte order is not the native order.
 *
 * NOTE(review): in step 4 the resulting array keeps {@code typeCode} only when both buffer
 * formats are equal; otherwise it gets the machine format's {@code baseTypeCode}, so the result
 * may have a different type code than the one requested. CPython's reconstructor appears to
 * convert the items to the requested type code instead -- confirm this divergence is intended.
 *
 * NOTE(review): step 3 does not consult {@code ByteOrder.nativeOrder()} itself, and the body of
 * {@code forFormat} is not fully visible here. If {@code forFormat} simply returns the first
 * constant with a matching buffer format, a host whose native order differs from that constant
 * would load the data without the swap -- verify.
 *
 * @return the reconstructed array
 */
private Object doReconstruct(VirtualFrame frame, Object arrayType, String typeCode, int mformatCode, PBytes bytes, PythonObjectLibrary lib,
356+
ArrayBuiltins.FromBytesNode fromBytesNode, ArrayBuiltins.FromUnicodeNode fromUnicodeNode, IsSubtypeNode isSubtypeNode,
357+
ArrayBuiltins.ByteSwapNode byteSwapNode, BufferFormat format) {
317358
if (!isSubtypeNode.execute(frame, arrayType, PythonBuiltinClassType.PArray)) {
318359
throw raise(TypeError, "%n is not a subtype of array", arrayType);
319360
}
320-
PArray.MachineFormat expectedFormat = PArray.MachineFormat.forFormat(format);
321-
if (expectedFormat != null && expectedFormat.code == mformatCode) {
322-
PArray array = factory().createArray(arrayType, typeCode, format);
323-
fromBytesNode.execute(frame, array, bytes);
361+
// fromCode returns null when no machine format carries this code.
MachineFormat machineFormat = MachineFormat.fromCode(mformatCode);
362+
if (machineFormat != null) {
363+
PArray array;
364+
if (machineFormat == MachineFormat.forFormat(format)) {
365+
array = factory().createArray(arrayType, typeCode, machineFormat.format);
366+
fromBytesNode.execute(frame, array, bytes);
367+
} else {
368+
// Keep the caller's type code when only the byte order / encoding differs; otherwise fall
// back to the base type code of the machine format's buffer format.
String newTypeCode = machineFormat.format == format ? typeCode : machineFormat.format.baseTypeCode;
369+
array = factory().createArray(arrayType, newTypeCode, machineFormat.format);
370+
if (machineFormat.unicodeEncoding != null) {
371+
// Text formats: invoke the "decode" method on the bytes object with the format's encoding name.
Object decoded = lib.lookupAndCallRegularMethod(bytes, frame, "decode", machineFormat.unicodeEncoding);
372+
fromUnicodeNode.execute(frame, array, decoded);
373+
} else {
374+
fromBytesNode.execute(frame, array, bytes);
375+
// The 8-bit machine formats are declared with a null order, so this comparison is also
// true for them and byteswap is invoked; presumably a no-op for 1-byte items -- confirm.
if (machineFormat.order != ByteOrder.nativeOrder()) {
376+
byteSwapNode.call(frame, array);
377+
}
378+
}
379+
}
324380
return array;
325381
} else {
326-
// TODO implement decoding for arrays pickled on a machine of different architecture
327-
throw raise(PythonBuiltinClassType.NotImplementedError, "Cannot decode array format");
382+
throw raise(ValueError, "third argument must be a valid machine format code.");
328383
}
329384
}
330385

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/array/ArrayBuiltins.java

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1010,19 +1010,19 @@ static Object fromother(VirtualFrame frame, PArray self, Object str,
10101010
@Builtin(name = "fromunicode", minNumOfPositionalArgs = 2, numOfPositionalOnlyArgs = 2, parameterNames = {"$self", "str"})
10111011
@ArgumentClinic(name = "str", conversion = ArgumentClinic.ClinicConversion.String)
10121012
@GenerateNodeFactory
1013-
abstract static class FromUnicodeNode extends PythonBinaryClinicBuiltinNode {
1013+
public abstract static class FromUnicodeNode extends PythonBinaryClinicBuiltinNode {
10141014
@Specialization
10151015
Object fromunicode(VirtualFrame frame, PArray self, String str,
10161016
@Cached ArrayNodes.PutValueNode putValueNode) {
10171017
try {
10181018
int length = PString.codePointCount(str, 0, str.length());
10191019
int newLength = PythonUtils.addExact(self.getLength(), length);
10201020
self.resizeStorage(newLength);
1021-
for (int i = 0, index = 0; i < length; index++) {
1022-
int cpCount = PString.charCount(PString.codePointAt(str, i));
1023-
String value = PString.substring(str, i, i + cpCount);
1024-
putValueNode.execute(frame, self, self.getLength() + index, value);
1025-
i += cpCount;
1021+
for (int codePointIndex = 0, charIndex = 0; codePointIndex < length; codePointIndex++) {
1022+
int charCount = PString.charCount(PString.codePointAt(str, charIndex));
1023+
String value = PString.substring(str, charIndex, charIndex + charCount);
1024+
putValueNode.execute(frame, self, self.getLength() + codePointIndex, value);
1025+
charIndex += charCount;
10261026
}
10271027
self.setLength(newLength);
10281028
return PNone.NONE;
@@ -1123,7 +1123,7 @@ Object tofile(VirtualFrame frame, PArray self, Object file,
11231123

11241124
@Builtin(name = "byteswap", minNumOfPositionalArgs = 1)
11251125
@GenerateNodeFactory
1126-
abstract static class ByteSwapNode extends PythonUnaryBuiltinNode {
1126+
public abstract static class ByteSwapNode extends PythonUnaryBuiltinNode {
11271127

11281128
@Specialization(guards = "self.getFormat().bytesize == 1")
11291129
static Object byteswap1(@SuppressWarnings("unused") PArray self) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/array/PArray.java

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,6 @@ int getBufferLength() {
170170
}
171171

172172
public enum MachineFormat {
173-
UNKNOWN_FORMAT(-1, null, null),
174173
UNSIGNED_INT8(0, BufferFormat.UINT_8, null),
175174
SIGNED_INT8(1, BufferFormat.INT_8, null),
176175
UNSIGNED_INT16_LE(2, BufferFormat.UINT_16, ByteOrder.LITTLE_ENDIAN),
@@ -189,20 +188,26 @@ public enum MachineFormat {
189188
IEEE_754_FLOAT_BE(15, BufferFormat.FLOAT, ByteOrder.BIG_ENDIAN),
190189
IEEE_754_DOUBLE_LE(16, BufferFormat.DOUBLE, ByteOrder.LITTLE_ENDIAN),
191190
IEEE_754_DOUBLE_BE(17, BufferFormat.DOUBLE, ByteOrder.BIG_ENDIAN),
192-
// TODO
193-
UTF16_LE(18, null, ByteOrder.LITTLE_ENDIAN),
194-
UTF16_BE(19, null, ByteOrder.BIG_ENDIAN),
195-
UTF32_LE(20, BufferFormat.UNICODE, ByteOrder.LITTLE_ENDIAN),
196-
UTF32_BE(21, BufferFormat.UNICODE, ByteOrder.BIG_ENDIAN);
191+
UTF32_LE(20, BufferFormat.UNICODE, ByteOrder.LITTLE_ENDIAN, "utf-32-le"),
192+
UTF32_BE(21, BufferFormat.UNICODE, ByteOrder.BIG_ENDIAN, "utf-32-be"),
193+
// These two need to come after UTF32, so that forFormat doesn't pick them for UNICODE
194+
UTF16_LE(18, BufferFormat.UNICODE, ByteOrder.LITTLE_ENDIAN, "utf-16-le"),
195+
UTF16_BE(19, BufferFormat.UNICODE, ByteOrder.BIG_ENDIAN, "utf-16-be");
197196

198197
public final int code;
199198
public final BufferFormat format;
200199
public final ByteOrder order;
200+
public final String unicodeEncoding;
201201

202202
/**
 * Constructor for machine formats without a text encoding; delegates to the four-argument
 * constructor with a {@code null} {@code unicodeEncoding}.
 */
MachineFormat(int code, BufferFormat format, ByteOrder order) {
203+
this(code, format, order, null);
204+
}
205+
206+
/**
 * Stores all attributes of a machine format.
 *
 * @param code numeric machine format code
 * @param format buffer format associated with this machine format
 * @param order byte order; the 8-bit formats pass {@code null}
 * @param unicodeEncoding encoding name such as {@code "utf-16-le"} for the UTF formats, or
 *            {@code null} for all other formats
 */
MachineFormat(int code, BufferFormat format, ByteOrder order, String unicodeEncoding) {
203207
this.code = code;
204208
this.format = format;
205209
this.order = order;
210+
this.unicodeEncoding = unicodeEncoding;
206211
}
207212

208213
@ExplodeLoop
@@ -214,5 +219,15 @@ public static MachineFormat forFormat(BufferFormat format) {
214219
}
215220
return null;
216221
}
222+
223+
/**
 * Finds the machine format with the given numeric code by scanning {@code values()} in
 * declaration order.
 *
 * @param code numeric machine format code to look up
 * @return the first constant whose {@code code} field equals {@code code}, or {@code null} when
 *         no constant matches
 */
@ExplodeLoop
224+
public static MachineFormat fromCode(int code) {
225+
for (MachineFormat machineFormat : MachineFormat.values()) {
226+
if (machineFormat.code == code) {
227+
return machineFormat;
228+
}
229+
}
230+
// No constant carries this code; callers must handle the null result.
return null;
231+
}
217232
}
218233
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/util/BufferFormat.java

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -48,27 +48,29 @@
4848
* string around for error messages.
4949
*/
5050
public enum BufferFormat {
51-
UINT_8(1),
52-
INT_8(1),
53-
UINT_16(2),
54-
INT_16(2),
55-
UINT_32(4),
56-
INT_32(4),
57-
UINT_64(8),
58-
INT_64(8),
59-
FLOAT(4),
60-
DOUBLE(8),
51+
UINT_8(1, "B"),
52+
INT_8(1, "b"),
53+
UINT_16(2, "H"),
54+
INT_16(2, "h"),
55+
UINT_32(4, "I"),
56+
INT_32(4, "i"),
57+
UINT_64(8, "L"),
58+
INT_64(8, "l"),
59+
FLOAT(4, "f"),
60+
DOUBLE(8, "d"),
6161
// Unicode is array-only and deprecated
62-
UNICODE(4),
62+
UNICODE(4, "u"),
6363
// The following are memoryview-only
64-
CHAR(1),
65-
BOOLEAN(1),
66-
OTHER(-1);
64+
CHAR(1, "c"),
65+
BOOLEAN(1, "?"),
66+
OTHER(-1, null);
6767

6868
public final int bytesize;
69+
public final String baseTypeCode;
6970

70-
BufferFormat(int bytesize) {
71+
/**
 * Stores the per-format attributes.
 *
 * @param bytesize value for the {@code bytesize} field; {@code OTHER} passes {@code -1}
 * @param baseTypeCode type code string associated with the format; {@code OTHER} passes
 *            {@code null}
 */
BufferFormat(int bytesize, String baseTypeCode) {
7172
this.bytesize = bytesize;
73+
this.baseTypeCode = baseTypeCode;
7274
}
7375

7476
public static BufferFormat forMemoryView(String formatString) {

graalpython/lib-python/3/test/test_array.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1069,10 +1069,10 @@ def test_initialize_with_unicode(self):
10691069
if self.typecode != 'u':
10701070
with self.assertRaises(TypeError) as cm:
10711071
a = array.array(self.typecode, 'foo')
1072-
# XXX Truffle change: don't dwell on exact error messages, this feature is deprecated anyway
1073-
# self.assertIn("cannot use a str", str(cm.exception))
1072+
self.assertIn("cannot use a str", str(cm.exception))
10741073
with self.assertRaises(TypeError) as cm:
10751074
a = array.array(self.typecode, array.array('u', 'foo'))
1075+
# XXX Truffle change: don't dwell on exact error messages, this feature is deprecated anyway
10761076
# self.assertIn("cannot use a unicode array", str(cm.exception))
10771077
else:
10781078
a = array.array(self.typecode, "foo")

0 commit comments

Comments
 (0)