56
56
import com .oracle .graal .python .builtins .objects .cext .NativeCAPISymbols ;
57
57
import com .oracle .graal .python .builtins .objects .common .SequenceNodes ;
58
58
import com .oracle .graal .python .builtins .objects .common .SequenceStorageNodes ;
59
+ import com .oracle .graal .python .builtins .objects .ints .PInt ;
59
60
import com .oracle .graal .python .builtins .objects .object .PythonObjectLibrary ;
60
61
import com .oracle .graal .python .builtins .objects .tuple .PTuple ;
61
62
import com .oracle .graal .python .nodes .ErrorMessages ;
64
65
import com .oracle .graal .python .nodes .attributes .ReadAttributeFromObjectNode ;
65
66
import com .oracle .graal .python .runtime .PythonContext ;
66
67
import com .oracle .graal .python .runtime .exception .PException ;
68
+ import com .oracle .graal .python .runtime .object .PythonObjectFactory ;
67
69
import com .oracle .graal .python .runtime .sequence .storage .SequenceStorage ;
68
70
import com .oracle .truffle .api .CompilerDirectives ;
69
71
import com .oracle .truffle .api .CompilerDirectives .TruffleBoundary ;
81
83
import com .oracle .truffle .api .interop .UnsupportedTypeException ;
82
84
import com .oracle .truffle .api .library .CachedLibrary ;
83
85
import com .oracle .truffle .api .nodes .Node ;
86
+ import com .oracle .truffle .api .profiles .ConditionProfile ;
84
87
85
88
public class MemoryViewNodes {
86
89
static int bytesize (PMemoryView .BufferFormat format ) {
@@ -100,12 +103,7 @@ static int bytesize(PMemoryView.BufferFormat format) {
100
103
return 4 ;
101
104
case UNSIGNED_LONG :
102
105
case SIGNED_LONG :
103
- case UNSIGNED_SIZE :
104
- case SIGNED_SIZE :
105
- case SIGNED_LONG_LONG :
106
- case UNSIGNED_LONG_LONG :
107
106
case DOUBLE :
108
- case POINTER :
109
107
return 8 ;
110
108
}
111
109
return -1 ;
@@ -151,35 +149,96 @@ static int compute(int ndim, int itemsize, int[] shape, int[] strides, int[] sub
151
149
152
150
@ ImportStatic (PMemoryView .BufferFormat .class )
153
151
abstract static class UnpackValueNode extends Node {
154
- public abstract Object execute (PMemoryView .BufferFormat format , byte [] bytes );
152
+ // bytes are expected to already have the appropriate length
153
+ public abstract Object execute (PMemoryView .BufferFormat format , String formatStr , byte [] bytes );
155
154
156
155
@ Specialization (guards = "format == UNSIGNED_BYTE" )
157
- static int unpackUnsignedByte (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , byte [] bytes ) {
156
+ static int unpackUnsignedByte (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ( "unused" ) String formatStr , byte [] bytes ) {
158
157
return bytes [0 ] & 0xFF ;
159
158
}
160
159
161
160
@ Specialization (guards = "format == SIGNED_BYTE" )
162
- static int unpackSignedByte (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , byte [] bytes ) {
161
+ static int unpackSignedByte (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ( "unused" ) String formatStr , byte [] bytes ) {
163
162
return bytes [0 ];
164
163
}
165
164
166
165
@ Specialization (guards = "format == SIGNED_SHORT" )
167
- static int unpackShort (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , byte [] bytes ) {
168
- return (bytes [0 ] & 0xFF ) | (bytes [1 ] & 0xFF ) << 8 ;
166
+ static int unpackSignedShort (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
167
+ return unpackInt16 (bytes );
168
+ }
169
+
170
+ @ Specialization (guards = "format == UNSIGNED_SHORT" )
171
+ static int unpackUnsignedShort (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
172
+ return unpackInt16 (bytes ) & 0xFFFF ;
169
173
}
170
174
171
175
@ Specialization (guards = "format == SIGNED_INT" )
172
- static int unpackInt (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , byte [] bytes ) {
173
- return (bytes [0 ] & 0xFF ) | (bytes [1 ] & 0xFF ) << 8 | (bytes [2 ] & 0xFF ) << 16 | (bytes [3 ] & 0xFF ) << 24 ;
176
+ static int unpackSignedInt (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
177
+ return unpackInt32 (bytes );
178
+ }
179
+
180
+ @ Specialization (guards = "format == UNSIGNED_INT" )
181
+ static long unpackUnsignedInt (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
182
+ return unpackInt32 (bytes ) & 0xFFFFFFFFL ;
174
183
}
175
184
176
185
@ Specialization (guards = "format == SIGNED_LONG" )
177
- static long unpackLong (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , byte [] bytes ) {
178
- return (bytes [0 ] & 0xFF ) | (bytes [1 ] & 0xFF ) << 8 | (bytes [2 ] & 0xFF ) << 16 | (bytes [3 ] & 0xFF ) << 24 |
179
- (bytes [4 ] & 0xFFL ) << 32 | (bytes [5 ] & 0xFFL ) << 40 | (bytes [6 ] & 0xFFL ) << 48 | (bytes [7 ] & 0xFFL ) << 56 ;
186
+ static long unpackSignedLong (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
187
+ return unpackInt64 (bytes );
188
+ }
189
+
190
+ @ Specialization (guards = "format == UNSIGNED_LONG" )
191
+ static Object unpackUnsignedLong (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ,
192
+ @ Cached ConditionProfile needsPIntProfile ,
193
+ @ Shared ("factory" ) @ Cached PythonObjectFactory factory ) {
194
+ long signedLong = unpackInt64 (bytes );
195
+ if (needsPIntProfile .profile (signedLong < 0 )) {
196
+ return factory .createInt (PInt .longToUnsignedBigInteger (signedLong ));
197
+ } else {
198
+ return signedLong ;
199
+ }
200
+ }
201
+
202
+ @ Specialization (guards = "format == FLOAT" )
203
+ static double unpackFloat (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
204
+ return Float .intBitsToFloat (unpackInt32 (bytes ));
205
+ }
206
+
207
+ @ Specialization (guards = "format == DOUBLE" )
208
+ static double unpackDouble (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
209
+ return Double .longBitsToDouble (unpackInt64 (bytes ));
180
210
}
181
211
182
- // TODO rest of formats
212
+ @ Specialization (guards = "format == BOOLEAN" )
213
+ static boolean unpackBoolean (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ) {
214
+ return bytes [0 ] != 0 ;
215
+ }
216
+
217
+ @ Specialization (guards = "format == CHAR" )
218
+ static Object unpackChar (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , @ SuppressWarnings ("unused" ) String formatStr , byte [] bytes ,
219
+ @ Shared ("factory" ) @ Cached PythonObjectFactory factory ) {
220
+ assert bytes .length == 1 ;
221
+ return factory .createBytes (bytes );
222
+ }
223
+
224
+ @ Specialization (guards = "format == OTHER" )
225
+ static Object notImplemented (@ SuppressWarnings ("unused" ) PMemoryView .BufferFormat format , String formatStr , @ SuppressWarnings ("unused" ) byte [] bytes ,
226
+ @ Cached PRaiseNode raiseNode ) {
227
+ throw raiseNode .raise (NotImplementedError , ErrorMessages .MEMORYVIEW_FORMAT_S_NOT_SUPPORTED , formatStr );
228
+ }
229
+
230
+ private static short unpackInt16 (byte [] bytes ) {
231
+ return (short ) ((bytes [0 ] & 0xFF ) | (bytes [1 ] & 0xFF ) << 8 );
232
+ }
233
+
234
+ private static int unpackInt32 (byte [] bytes ) {
235
+ return (bytes [0 ] & 0xFF ) | (bytes [1 ] & 0xFF ) << 8 | (bytes [2 ] & 0xFF ) << 16 | (bytes [3 ] & 0xFF ) << 24 ;
236
+ }
237
+
238
+ private static long unpackInt64 (byte [] bytes ) {
239
+ return (bytes [0 ] & 0xFFL ) | (bytes [1 ] & 0xFFL ) << 8 | (bytes [2 ] & 0xFFL ) << 16 | (bytes [3 ] & 0xFFL ) << 24 |
240
+ (bytes [4 ] & 0xFFL ) << 32 | (bytes [5 ] & 0xFFL ) << 40 | (bytes [6 ] & 0xFFL ) << 48 | (bytes [7 ] & 0xFFL ) << 56 ;
241
+ }
183
242
}
184
243
185
244
@ ImportStatic (PMemoryView .BufferFormat .class )
@@ -292,7 +351,7 @@ static Object doNative(PMemoryView self, Object ptr, int offset,
292
351
} catch (UnsupportedMessageException | InvalidArrayIndexException e ) {
293
352
throw CompilerDirectives .shouldNotReachHere ("native buffer read failed" );
294
353
}
295
- return unpackValueNode .execute (self .getFormat (), bytes );
354
+ return unpackValueNode .execute (self .getFormat (), self . getFormatString (), bytes );
296
355
}
297
356
298
357
@ Specialization (guards = "ptr == null" )
@@ -305,7 +364,7 @@ static Object doManaged(PMemoryView self, @SuppressWarnings("unused") Object ptr
305
364
for (int i = 0 ; i < self .getItemSize (); i ++) {
306
365
bytes [i ] = (byte ) getItemNode .executeInt (getStorageNode .execute (self .getOwner ()), offset + i );
307
366
}
308
- return unpackValueNode .execute (self .getFormat (), bytes );
367
+ return unpackValueNode .execute (self .getFormat (), self . getFormatString (), bytes );
309
368
}
310
369
}
311
370
0 commit comments