Skip to content

Commit 72ebdcb

Browse files
committed
Implement unpacking of remaining formats
1 parent a2a38e3 commit 72ebdcb

File tree

4 files changed

+103
-32
lines changed

4 files changed

+103
-32
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_memoryview.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,24 @@ def test_slice():
8686
else:
8787
e2 = s2[l]
8888
assert e1 == e2
89+
90+
def test_unpack():
91+
assert memoryview(b'\xaa')[0] == 170
92+
assert memoryview(b'\xaa').cast('B')[0] == 170
93+
assert memoryview(b'\xaa').cast('b')[0] == -86
94+
assert memoryview(b'\xaa\xaa').cast('H')[0] == 43690
95+
assert memoryview(b'\xaa\xaa').cast('h')[0] == -21846
96+
assert memoryview(b'\xaa\xaa\xaa\xaa').cast('I')[0] == 2863311530
97+
assert memoryview(b'\xaa\xaa\xaa\xaa').cast('i')[0] == -1431655766
98+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('L')[0] == 12297829382473034410
99+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('l')[0] == -6148914691236517206
100+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('Q')[0] == 12297829382473034410
101+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('q')[0] == -6148914691236517206
102+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('N')[0] == 12297829382473034410
103+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('n')[0] == -6148914691236517206
104+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('P')[0] == 12297829382473034410
105+
assert memoryview(b'\x00').cast('?')[0] is False
106+
assert memoryview(b'\xaa').cast('?')[0] is True
107+
assert memoryview(b'\xaa\xaa\xaa\xaa').cast('f')[0] == -3.0316488252093987e-13
108+
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('d')[0] == -3.7206620809969885e-103
109+
assert memoryview(b'\xaa').cast('c')[0] == b'\xaa'

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewNodes.java

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import com.oracle.graal.python.builtins.objects.cext.NativeCAPISymbols;
5757
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
5858
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
59+
import com.oracle.graal.python.builtins.objects.ints.PInt;
5960
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
6061
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
6162
import com.oracle.graal.python.nodes.ErrorMessages;
@@ -64,6 +65,7 @@
6465
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
6566
import com.oracle.graal.python.runtime.PythonContext;
6667
import com.oracle.graal.python.runtime.exception.PException;
68+
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
6769
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
6870
import com.oracle.truffle.api.CompilerDirectives;
6971
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -81,6 +83,7 @@
8183
import com.oracle.truffle.api.interop.UnsupportedTypeException;
8284
import com.oracle.truffle.api.library.CachedLibrary;
8385
import com.oracle.truffle.api.nodes.Node;
86+
import com.oracle.truffle.api.profiles.ConditionProfile;
8487

8588
public class MemoryViewNodes {
8689
static int bytesize(PMemoryView.BufferFormat format) {
@@ -100,12 +103,7 @@ static int bytesize(PMemoryView.BufferFormat format) {
100103
return 4;
101104
case UNSIGNED_LONG:
102105
case SIGNED_LONG:
103-
case UNSIGNED_SIZE:
104-
case SIGNED_SIZE:
105-
case SIGNED_LONG_LONG:
106-
case UNSIGNED_LONG_LONG:
107106
case DOUBLE:
108-
case POINTER:
109107
return 8;
110108
}
111109
return -1;
@@ -151,35 +149,96 @@ static int compute(int ndim, int itemsize, int[] shape, int[] strides, int[] sub
151149

152150
@ImportStatic(PMemoryView.BufferFormat.class)
153151
abstract static class UnpackValueNode extends Node {
154-
public abstract Object execute(PMemoryView.BufferFormat format, byte[] bytes);
152+
// bytes are expected to already have the appropriate length
153+
public abstract Object execute(PMemoryView.BufferFormat format, String formatStr, byte[] bytes);
155154

156155
@Specialization(guards = "format == UNSIGNED_BYTE")
157-
static int unpackUnsignedByte(@SuppressWarnings("unused") PMemoryView.BufferFormat format, byte[] bytes) {
156+
static int unpackUnsignedByte(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
158157
return bytes[0] & 0xFF;
159158
}
160159

161160
@Specialization(guards = "format == SIGNED_BYTE")
162-
static int unpackSignedByte(@SuppressWarnings("unused") PMemoryView.BufferFormat format, byte[] bytes) {
161+
static int unpackSignedByte(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
163162
return bytes[0];
164163
}
165164

166165
@Specialization(guards = "format == SIGNED_SHORT")
167-
static int unpackShort(@SuppressWarnings("unused") PMemoryView.BufferFormat format, byte[] bytes) {
168-
return (bytes[0] & 0xFF) | (bytes[1] & 0xFF) << 8;
166+
static int unpackSignedShort(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
167+
return unpackInt16(bytes);
168+
}
169+
170+
@Specialization(guards = "format == UNSIGNED_SHORT")
171+
static int unpackUnsignedShort(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
172+
return unpackInt16(bytes) & 0xFFFF;
169173
}
170174

171175
@Specialization(guards = "format == SIGNED_INT")
172-
static int unpackInt(@SuppressWarnings("unused") PMemoryView.BufferFormat format, byte[] bytes) {
173-
return (bytes[0] & 0xFF) | (bytes[1] & 0xFF) << 8 | (bytes[2] & 0xFF) << 16 | (bytes[3] & 0xFF) << 24;
176+
static int unpackSignedInt(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
177+
return unpackInt32(bytes);
178+
}
179+
180+
@Specialization(guards = "format == UNSIGNED_INT")
181+
static long unpackUnsignedInt(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
182+
return unpackInt32(bytes) & 0xFFFFFFFFL;
174183
}
175184

176185
@Specialization(guards = "format == SIGNED_LONG")
177-
static long unpackLong(@SuppressWarnings("unused") PMemoryView.BufferFormat format, byte[] bytes) {
178-
return (bytes[0] & 0xFF) | (bytes[1] & 0xFF) << 8 | (bytes[2] & 0xFF) << 16 | (bytes[3] & 0xFF) << 24 |
179-
(bytes[4] & 0xFFL) << 32 | (bytes[5] & 0xFFL) << 40 | (bytes[6] & 0xFFL) << 48 | (bytes[7] & 0xFFL) << 56;
186+
static long unpackSignedLong(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
187+
return unpackInt64(bytes);
188+
}
189+
190+
@Specialization(guards = "format == UNSIGNED_LONG")
191+
static Object unpackUnsignedLong(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes,
192+
@Cached ConditionProfile needsPIntProfile,
193+
@Shared("factory") @Cached PythonObjectFactory factory) {
194+
long signedLong = unpackInt64(bytes);
195+
if (needsPIntProfile.profile(signedLong < 0)) {
196+
return factory.createInt(PInt.longToUnsignedBigInteger(signedLong));
197+
} else {
198+
return signedLong;
199+
}
200+
}
201+
202+
@Specialization(guards = "format == FLOAT")
203+
static double unpackFloat(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
204+
return Float.intBitsToFloat(unpackInt32(bytes));
205+
}
206+
207+
@Specialization(guards = "format == DOUBLE")
208+
static double unpackDouble(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
209+
return Double.longBitsToDouble(unpackInt64(bytes));
180210
}
181211

182-
// TODO rest of formats
212+
@Specialization(guards = "format == BOOLEAN")
213+
static boolean unpackBoolean(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes) {
214+
return bytes[0] != 0;
215+
}
216+
217+
@Specialization(guards = "format == CHAR")
218+
static Object unpackChar(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, byte[] bytes,
219+
@Shared("factory") @Cached PythonObjectFactory factory) {
220+
assert bytes.length == 1;
221+
return factory.createBytes(bytes);
222+
}
223+
224+
@Specialization(guards = "format == OTHER")
225+
static Object notImplemented(@SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, @SuppressWarnings("unused") byte[] bytes,
226+
@Cached PRaiseNode raiseNode) {
227+
throw raiseNode.raise(NotImplementedError, ErrorMessages.MEMORYVIEW_FORMAT_S_NOT_SUPPORTED, formatStr);
228+
}
229+
230+
private static short unpackInt16(byte[] bytes) {
231+
return (short) ((bytes[0] & 0xFF) | (bytes[1] & 0xFF) << 8);
232+
}
233+
234+
private static int unpackInt32(byte[] bytes) {
235+
return (bytes[0] & 0xFF) | (bytes[1] & 0xFF) << 8 | (bytes[2] & 0xFF) << 16 | (bytes[3] & 0xFF) << 24;
236+
}
237+
238+
private static long unpackInt64(byte[] bytes) {
239+
return (bytes[0] & 0xFFL) | (bytes[1] & 0xFFL) << 8 | (bytes[2] & 0xFFL) << 16 | (bytes[3] & 0xFFL) << 24 |
240+
(bytes[4] & 0xFFL) << 32 | (bytes[5] & 0xFFL) << 40 | (bytes[6] & 0xFFL) << 48 | (bytes[7] & 0xFFL) << 56;
241+
}
183242
}
184243

185244
@ImportStatic(PMemoryView.BufferFormat.class)
@@ -292,7 +351,7 @@ static Object doNative(PMemoryView self, Object ptr, int offset,
292351
} catch (UnsupportedMessageException | InvalidArrayIndexException e) {
293352
throw CompilerDirectives.shouldNotReachHere("native buffer read failed");
294353
}
295-
return unpackValueNode.execute(self.getFormat(), bytes);
354+
return unpackValueNode.execute(self.getFormat(), self.getFormatString(), bytes);
296355
}
297356

298357
@Specialization(guards = "ptr == null")
@@ -305,7 +364,7 @@ static Object doManaged(PMemoryView self, @SuppressWarnings("unused") Object ptr
305364
for (int i = 0; i < self.getItemSize(); i++) {
306365
bytes[i] = (byte) getItemNode.executeInt(getStorageNode.execute(self.getOwner()), offset + i);
307366
}
308-
return unpackValueNode.execute(self.getFormat(), bytes);
367+
return unpackValueNode.execute(self.getFormat(), self.getFormatString(), bytes);
309368
}
310369
}
311370

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/PMemoryView.java

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,8 @@ public enum BufferFormat {
120120
SIGNED_INT,
121121
UNSIGNED_LONG,
122122
SIGNED_LONG,
123-
UNSIGNED_LONG_LONG,
124-
SIGNED_LONG_LONG,
125-
UNSIGNED_SIZE,
126-
SIGNED_SIZE,
127123
BOOLEAN,
128124
CHAR,
129-
POINTER,
130125
FLOAT,
131126
DOUBLE,
132127
OTHER;
@@ -154,23 +149,18 @@ public static BufferFormat fromString(String format) {
154149
case 'i':
155150
return SIGNED_INT;
156151
case 'L':
152+
case 'Q':
153+
case 'N':
154+
case 'P':
157155
return UNSIGNED_LONG;
158156
case 'l':
159-
return SIGNED_LONG;
160-
case 'Q':
161-
return UNSIGNED_LONG_LONG;
162157
case 'q':
163-
return SIGNED_LONG_LONG;
164-
case 'N':
165-
return SIGNED_SIZE;
166158
case 'n':
167-
return UNSIGNED_SIZE;
159+
return SIGNED_LONG;
168160
case 'f':
169161
return FLOAT;
170162
case 'd':
171163
return DOUBLE;
172-
case 'P':
173-
return POINTER;
174164
case '?':
175165
return BOOLEAN;
176166
case 'c':

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ public abstract class ErrorMessages {
369369
public static final String MEMORYVIEW_CAST_WRONG_LENGTH = "memoryview: product(shape) * itemsize != buffer size";
370370
public static final String MEMORYVIEW_CAST_ELEMENTS_MUST_BE_POSITIVE_INTEGERS = "memoryview.cast(): elements of shape must be integers > 0";
371371
public static final String MEMORYVIEW_HAS_D_EXPORTED_BUFFERS = "memoryview has %d exported buffers";
372+
public static final String MEMORYVIEW_FORMAT_S_NOT_SUPPORTED = "memoryview: format %s not supported";
372373
public static final String METACLASS_CONFLICT = "metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases";
373374
public static final String METHOD_NAME_MUST_BE = "method name must be string, not %p";
374375
public static final String MISSING_D_REQUIRED_S_ARGUMENT_S_POS = "%s() missing required argument '%s' (pos %d)";

0 commit comments

Comments
 (0)