|
29 | 29 | import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
|
30 | 30 | import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
|
31 | 31 |
|
| 32 | +import java.nio.ByteOrder; |
32 | 33 | import java.util.List;
|
33 | 34 |
|
34 | 35 | import com.oracle.graal.python.annotations.ArgumentClinic;
|
|
40 | 41 | import com.oracle.graal.python.builtins.objects.array.ArrayBuiltins;
|
41 | 42 | import com.oracle.graal.python.builtins.objects.array.ArrayNodes;
|
42 | 43 | import com.oracle.graal.python.builtins.objects.array.PArray;
|
| 44 | +import com.oracle.graal.python.builtins.objects.array.PArray.MachineFormat; |
43 | 45 | import com.oracle.graal.python.builtins.objects.bytes.PBytes;
|
44 | 46 | import com.oracle.graal.python.builtins.objects.bytes.PBytesLike;
|
45 | 47 | import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
|
|
58 | 60 | import com.oracle.graal.python.nodes.function.builtins.PythonVarargsBuiltinNode;
|
59 | 61 | import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
|
60 | 62 | import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
|
| 63 | +import com.oracle.graal.python.nodes.util.CastToJavaStringNode; |
61 | 64 | import com.oracle.graal.python.nodes.util.SplitArgsNode;
|
62 | 65 | import com.oracle.graal.python.runtime.PythonCore;
|
63 | 66 | import com.oracle.graal.python.runtime.exception.PException;
|
@@ -192,6 +195,19 @@ PArray arrayWithBytesInitializer(VirtualFrame frame, Object cls, String typeCode
|
192 | 195 | return array;
|
193 | 196 | }
|
194 | 197 |
|
| 198 | + @Specialization(guards = "isString(initializer)") |
| 199 | + PArray arrayWithStringInitializer(VirtualFrame frame, Object cls, String typeCode, Object initializer, |
| 200 | + @Cached CastToJavaStringNode cast, |
| 201 | + @Cached ArrayBuiltins.FromUnicodeNode fromUnicodeNode) { |
| 202 | + BufferFormat format = getFormatChecked(typeCode); |
| 203 | + if (format != BufferFormat.UNICODE) { |
| 204 | + throw raise(TypeError, "cannot use a str to initialize an array with typecode '%s'", typeCode); |
| 205 | + } |
| 206 | + PArray array = getFactory().createArray(cls, typeCode, format); |
| 207 | + fromUnicodeNode.execute(frame, array, cast.execute(initializer)); |
| 208 | + return array; |
| 209 | + } |
| 210 | + |
195 | 211 | @Specialization
|
196 | 212 | PArray arrayArrayInitializer(VirtualFrame frame, Object cls, String typeCode, PArray initializer,
|
197 | 213 | @Cached ArrayNodes.PutValueNode putValueNode,
|
@@ -230,7 +246,7 @@ PArray arraySequenceInitializer(VirtualFrame frame, Object cls, String typeCode,
|
230 | 246 | }
|
231 | 247 | }
|
232 | 248 |
|
233 |
| - @Specialization(guards = "!isBytes(initializer)", limit = "3") |
| 249 | + @Specialization(guards = {"!isBytes(initializer)", "!isString(initializer)"}, limit = "3") |
234 | 250 | PArray arrayIteratorInitializer(VirtualFrame frame, Object cls, String typeCode, Object initializer,
|
235 | 251 | @CachedLibrary("initializer") PythonObjectLibrary lib,
|
236 | 252 | @Cached ArrayNodes.PutValueNode putValueNode,
|
@@ -306,25 +322,64 @@ private PythonObjectFactory getFactory() {
|
306 | 322 | @ArgumentClinic(name = "mformatCode", conversion = ArgumentClinic.ClinicConversion.Index, defaultValue = "0")
|
307 | 323 | @GenerateNodeFactory
|
308 | 324 | abstract static class ArrayReconstructorNode extends PythonClinicBuiltinNode {
|
309 |
| - @Specialization |
| 325 | + @Specialization(guards = "mformatCode == cachedCode") |
| 326 | + Object reconstructCached(VirtualFrame frame, Object arrayType, String typeCode, @SuppressWarnings("unused") int mformatCode, PBytes bytes, |
| 327 | + @Cached("mformatCode") int cachedCode, |
| 328 | + @Cached("createIdentityProfile()") ValueProfile formatProfile, |
| 329 | + @CachedLibrary(limit = "2") PythonObjectLibrary lib, |
| 330 | + @Cached ArrayBuiltins.FromBytesNode fromBytesNode, |
| 331 | + @Cached ArrayBuiltins.FromUnicodeNode fromUnicodeNode, |
| 332 | + @Cached IsSubtypeNode isSubtypeNode, |
| 333 | + @Cached ArrayBuiltins.ByteSwapNode byteSwapNode) { |
| 334 | + BufferFormat format = BufferFormat.forArray(typeCode); |
| 335 | + if (format == null) { |
| 336 | + throw raise(ValueError, "bad typecode (must be b, B, u, h, H, i, I, l, L, q, Q, f or d)"); |
| 337 | + } |
| 338 | + return doReconstruct(frame, arrayType, typeCode, cachedCode, bytes, lib, fromBytesNode, fromUnicodeNode, isSubtypeNode, byteSwapNode, formatProfile.profile(format)); |
| 339 | + } |
| 340 | + |
| 341 | + @Specialization(replaces = "reconstructCached") |
310 | 342 | Object reconstruct(VirtualFrame frame, Object arrayType, String typeCode, int mformatCode, PBytes bytes,
|
| 343 | + @CachedLibrary(limit = "2") PythonObjectLibrary lib, |
311 | 344 | @Cached ArrayBuiltins.FromBytesNode fromBytesNode,
|
312 |
| - @Cached IsSubtypeNode isSubtypeNode) { |
| 345 | + @Cached ArrayBuiltins.FromUnicodeNode fromUnicodeNode, |
| 346 | + @Cached IsSubtypeNode isSubtypeNode, |
| 347 | + @Cached ArrayBuiltins.ByteSwapNode byteSwapNode) { |
313 | 348 | BufferFormat format = BufferFormat.forArray(typeCode);
|
314 | 349 | if (format == null) {
|
315 | 350 | throw raise(ValueError, "bad typecode (must be b, B, u, h, H, i, I, l, L, q, Q, f or d)");
|
316 | 351 | }
|
| 352 | + return doReconstruct(frame, arrayType, typeCode, mformatCode, bytes, lib, fromBytesNode, fromUnicodeNode, isSubtypeNode, byteSwapNode, format); |
| 353 | + } |
| 354 | + |
| 355 | + private Object doReconstruct(VirtualFrame frame, Object arrayType, String typeCode, int mformatCode, PBytes bytes, PythonObjectLibrary lib, |
| 356 | + ArrayBuiltins.FromBytesNode fromBytesNode, ArrayBuiltins.FromUnicodeNode fromUnicodeNode, IsSubtypeNode isSubtypeNode, |
| 357 | + ArrayBuiltins.ByteSwapNode byteSwapNode, BufferFormat format) { |
317 | 358 | if (!isSubtypeNode.execute(frame, arrayType, PythonBuiltinClassType.PArray)) {
|
318 | 359 | throw raise(TypeError, "%n is not a subtype of array", arrayType);
|
319 | 360 | }
|
320 |
| - PArray.MachineFormat expectedFormat = PArray.MachineFormat.forFormat(format); |
321 |
| - if (expectedFormat != null && expectedFormat.code == mformatCode) { |
322 |
| - PArray array = factory().createArray(arrayType, typeCode, format); |
323 |
| - fromBytesNode.execute(frame, array, bytes); |
| 361 | + MachineFormat machineFormat = MachineFormat.fromCode(mformatCode); |
| 362 | + if (machineFormat != null) { |
| 363 | + PArray array; |
| 364 | + if (machineFormat == MachineFormat.forFormat(format)) { |
| 365 | + array = factory().createArray(arrayType, typeCode, machineFormat.format); |
| 366 | + fromBytesNode.execute(frame, array, bytes); |
| 367 | + } else { |
| 368 | + String newTypeCode = machineFormat.format == format ? typeCode : machineFormat.format.baseTypeCode; |
| 369 | + array = factory().createArray(arrayType, newTypeCode, machineFormat.format); |
| 370 | + if (machineFormat.unicodeEncoding != null) { |
| 371 | + Object decoded = lib.lookupAndCallRegularMethod(bytes, frame, "decode", machineFormat.unicodeEncoding); |
| 372 | + fromUnicodeNode.execute(frame, array, decoded); |
| 373 | + } else { |
| 374 | + fromBytesNode.execute(frame, array, bytes); |
| 375 | + if (machineFormat.order != ByteOrder.nativeOrder()) { |
| 376 | + byteSwapNode.call(frame, array); |
| 377 | + } |
| 378 | + } |
| 379 | + } |
324 | 380 | return array;
|
325 | 381 | } else {
|
326 |
| - // TODO implement decoding for arrays pickled on a machine of different architecture |
327 |
| - throw raise(PythonBuiltinClassType.NotImplementedError, "Cannot decode array format"); |
| 382 | + throw raise(ValueError, "third argument must be a valid machine format code."); |
328 | 383 | }
|
329 | 384 | }
|
330 | 385 |
|
|
0 commit comments