Skip to content

Commit 6231a21

Browse files
committed
Implement packing of remaining formats
1 parent 72ebdcb commit 6231a21

File tree

3 files changed

+213
-15
lines changed

3 files changed

+213
-15
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_memoryview.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,58 @@ def test_unpack():
107107
assert memoryview(b'\xaa\xaa\xaa\xaa').cast('f')[0] == -3.0316488252093987e-13
108108
assert memoryview(b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa').cast('d')[0] == -3.7206620809969885e-103
109109
assert memoryview(b'\xaa').cast('c')[0] == b'\xaa'
110+
111+
def test_pack():
112+
b = bytearray(1)
113+
memoryview(b).cast('B')[0] = 170
114+
assert b == b'\xaa'
115+
b = bytearray(1)
116+
memoryview(b).cast('b')[0] = -86
117+
assert b == b'\xaa'
118+
b = bytearray(2)
119+
memoryview(b).cast('H')[0] = 43690
120+
assert b == b'\xaa\xaa'
121+
b = bytearray(2)
122+
memoryview(b).cast('h')[0] = -21846
123+
assert b == b'\xaa\xaa'
124+
b = bytearray(4)
125+
memoryview(b).cast('I')[0] = 2863311530
126+
assert b == b'\xaa\xaa\xaa\xaa'
127+
b = bytearray(4)
128+
memoryview(b).cast('i')[0] = -1431655766
129+
assert b == b'\xaa\xaa\xaa\xaa'
130+
b = bytearray(8)
131+
memoryview(b).cast('L')[0] = 12297829382473034410
132+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
133+
b = bytearray(8)
134+
memoryview(b).cast('l')[0] = -6148914691236517206
135+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
136+
b = bytearray(8)
137+
memoryview(b).cast('Q')[0] = 12297829382473034410
138+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
139+
b = bytearray(8)
140+
memoryview(b).cast('q')[0] = -6148914691236517206
141+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
142+
b = bytearray(8)
143+
memoryview(b).cast('N')[0] = 12297829382473034410
144+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
145+
b = bytearray(8)
146+
memoryview(b).cast('n')[0] = -6148914691236517206
147+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
148+
b = bytearray(8)
149+
memoryview(b).cast('P')[0] = 12297829382473034410
150+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
151+
b = bytearray(4)
152+
memoryview(b).cast('f')[0] = -3.0316488252093987e-13
153+
assert b == b'\xaa\xaa\xaa\xaa'
154+
b = bytearray(8)
155+
memoryview(b).cast('d')[0] = -3.7206620809969885e-103
156+
assert b == b'\xaa\xaa\xaa\xaa\xaa\xaa\xaa\xaa'
157+
b = bytearray(1)
158+
memoryview(b).cast('c')[0] = b'\xaa'
159+
assert b == b'\xaa'
160+
b = bytearray(1)
161+
memoryview(b).cast('?')[0] = True
162+
assert b == b'\x01'
163+
memoryview(b).cast('?')[0] = False
164+
assert b == b'\x00'

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewNodes.java

Lines changed: 157 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import com.oracle.graal.python.PythonLanguage;
5353
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
5454
import com.oracle.graal.python.builtins.objects.bytes.PByteArray;
55+
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
5556
import com.oracle.graal.python.builtins.objects.cext.CExtNodes;
5657
import com.oracle.graal.python.builtins.objects.cext.NativeCAPISymbols;
5758
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
@@ -63,6 +64,7 @@
6364
import com.oracle.graal.python.nodes.PGuards;
6465
import com.oracle.graal.python.nodes.PRaiseNode;
6566
import com.oracle.graal.python.nodes.attributes.ReadAttributeFromObjectNode;
67+
import com.oracle.graal.python.nodes.util.CastToJavaUnsignedLongNode;
6668
import com.oracle.graal.python.runtime.PythonContext;
6769
import com.oracle.graal.python.runtime.exception.PException;
6870
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
@@ -241,35 +243,175 @@ private static long unpackInt64(byte[] bytes) {
241243
}
242244
}
243245

244-
@ImportStatic(PMemoryView.BufferFormat.class)
246+
@ImportStatic({PMemoryView.BufferFormat.class, PGuards.class})
245247
abstract static class PackValueNode extends Node {
246248
@Child private PRaiseNode raiseNode;
247249

248250
// Output goes to bytes, lenght not checked
249-
public abstract void execute(PMemoryView.BufferFormat format, Object object, byte[] bytes);
251+
public abstract void execute(VirtualFrame frame, PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes);
250252

251-
@Specialization(guards = "format == UNSIGNED_BYTE", limit = "2")
252-
void packUnsignedByte(@SuppressWarnings("unused") PMemoryView.BufferFormat format, Object object, byte[] bytes,
253+
@Specialization(guards = "format == UNSIGNED_BYTE")
254+
void packUnsignedByteInt(@SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, int value, byte[] bytes) {
255+
assert bytes.length == 1;
256+
if (value < 0 || value > 0xFF) {
257+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
258+
}
259+
bytes[0] = (byte) value;
260+
}
261+
262+
@Specialization(guards = "format == UNSIGNED_BYTE", replaces = "packUnsignedByteInt", limit = "2")
263+
void packUnsignedByteGeneric(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes,
253264
@CachedLibrary("object") PythonObjectLibrary lib) {
265+
long value = lib.asJavaLong(object, frame);
254266
assert bytes.length == 1;
255-
long value = lib.asJavaLong(object);
256267
if (value < 0 || value > 0xFF) {
257-
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, format);
268+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
258269
}
259270
bytes[0] = (byte) value;
260271
}
261272

273+
@Specialization(guards = "format == SIGNED_BYTE", replaces = "packUnsignedByteInt", limit = "2")
274+
void packSignedByteGeneric(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes,
275+
@CachedLibrary("object") PythonObjectLibrary lib) {
276+
long value = lib.asJavaLong(object, frame);
277+
assert bytes.length == 1;
278+
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
279+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
280+
}
281+
bytes[0] = (byte) value;
282+
}
283+
284+
@Specialization(guards = "format == SIGNED_SHORT", limit = "2")
285+
void packSignedShortGeneric(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes,
286+
@CachedLibrary("object") PythonObjectLibrary lib) {
287+
long value = lib.asJavaLong(object, frame);
288+
if (value < Short.MIN_VALUE || value > Short.MAX_VALUE) {
289+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
290+
}
291+
packInt16((int) value, bytes);
292+
}
293+
294+
@Specialization(guards = "format == UNSIGNED_SHORT", limit = "2")
295+
void packUnsignedShortGeneric(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes,
296+
@CachedLibrary("object") PythonObjectLibrary lib) {
297+
long value = lib.asJavaLong(object, frame);
298+
if (value < 0 || value > (Short.MAX_VALUE << 1) + 1) {
299+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
300+
}
301+
packInt16((int) value, bytes);
302+
}
303+
304+
@Specialization(guards = "format == SIGNED_INT")
305+
static void packSignedIntInt(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, int value, byte[] bytes) {
306+
packInt32(value, bytes);
307+
}
308+
309+
@Specialization(guards = "format == SIGNED_INT", replaces = "packSignedIntInt", limit = "2")
310+
void packSignedIntGeneric(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes,
311+
@CachedLibrary("object") PythonObjectLibrary lib) {
312+
long value = lib.asJavaLong(object, frame);
313+
if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
314+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
315+
}
316+
packInt32((int) value, bytes);
317+
}
318+
319+
@Specialization(guards = "format == UNSIGNED_INT", replaces = "packSignedIntInt", limit = "2")
320+
void packUnsignedIntGeneric(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, Object object, byte[] bytes,
321+
@CachedLibrary("object") PythonObjectLibrary lib) {
322+
long value = lib.asJavaLong(object, frame);
323+
if (value < 0 || value > ((long) (Integer.MAX_VALUE) << 1L) + 1L) {
324+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
325+
}
326+
packInt32((int) value, bytes);
327+
}
328+
262329
@Specialization(guards = "format == SIGNED_LONG", limit = "2")
263-
static void packLong(@SuppressWarnings("unused") PMemoryView.BufferFormat format, Object object, byte[] bytes,
330+
static void packSignedLong(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, Object object, byte[] bytes,
331+
@CachedLibrary("object") PythonObjectLibrary lib) {
332+
assert bytes.length == 8;
333+
packInt64(lib.asJavaLong(object, frame), bytes);
334+
}
335+
336+
@Specialization(guards = "format == UNSIGNED_LONG", limit = "2")
337+
static void packUnsignedLong(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, Object object, byte[] bytes,
338+
@Cached CastToJavaUnsignedLongNode cast,
264339
@CachedLibrary("object") PythonObjectLibrary lib) {
265340
assert bytes.length == 8;
266-
long value = lib.asJavaLong(object);
267-
for (int i = 7; i >= 0; i--) {
268-
bytes[i] = (byte) (value & 0xFFL);
269-
value >>= 8;
341+
packInt64(cast.execute(lib.asIndexWithFrame(object, frame)), bytes);
342+
}
343+
344+
@Specialization(guards = "format == FLOAT", limit = "2")
345+
static void packFloat(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, Object object, byte[] bytes,
346+
@CachedLibrary("object") PythonObjectLibrary lib) {
347+
assert bytes.length == 4;
348+
packInt32(Float.floatToRawIntBits((float) lib.asJavaDouble(object)), bytes);
349+
}
350+
351+
@Specialization(guards = "format == DOUBLE", limit = "2")
352+
static void packDouble(@SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, Object object, byte[] bytes,
353+
@CachedLibrary("object") PythonObjectLibrary lib) {
354+
assert bytes.length == 8;
355+
packInt64(Double.doubleToRawLongBits(lib.asJavaDouble(object)), bytes);
356+
}
357+
358+
@Specialization(guards = "format == BOOLEAN", limit = "2")
359+
static void packBoolean(VirtualFrame frame, @SuppressWarnings("unused") PMemoryView.BufferFormat format, @SuppressWarnings("unused") String formatStr, Object object, byte[] bytes,
360+
@CachedLibrary("object") PythonObjectLibrary lib) {
361+
assert bytes.length == 1;
362+
bytes[0] = lib.isTrue(object, frame) ? (byte) 1 : (byte) 0;
363+
}
364+
365+
@Specialization(guards = "format == CHAR", limit = "2")
366+
void packChar(@SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, PBytes object, byte[] bytes,
367+
@CachedLibrary("object") PythonObjectLibrary lib) {
368+
try {
369+
byte[] value = lib.getBufferBytes(object);
370+
if (value.length != 1) {
371+
throw raise(ValueError, ErrorMessages.MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S, formatStr);
372+
}
373+
bytes[0] = value[0];
374+
} catch (UnsupportedMessageException e) {
375+
throw CompilerDirectives.shouldNotReachHere();
270376
}
271377
}
272378

379+
@Specialization(guards = {"format == CHAR", "!isBytes(object)"})
380+
void packChar(@SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, @SuppressWarnings("unused") Object object, @SuppressWarnings("unused") byte[] bytes) {
381+
throw raise(TypeError, ErrorMessages.MEMORYVIEW_INVALID_TYPE_FOR_FORMAT_S, formatStr);
382+
}
383+
384+
@Specialization(guards = "format == OTHER")
385+
void notImplemented(@SuppressWarnings("unused") PMemoryView.BufferFormat format, String formatStr, @SuppressWarnings("unused") Object object, @SuppressWarnings("unused") byte[] bytes) {
386+
throw raise(NotImplementedError, ErrorMessages.MEMORYVIEW_FORMAT_S_NOT_SUPPORTED, formatStr);
387+
}
388+
389+
private static void packInt16(int value, byte[] bytes) {
390+
assert bytes.length == 2;
391+
bytes[0] = (byte) value;
392+
bytes[1] = (byte) (value >> 8);
393+
}
394+
395+
private static void packInt32(int value, byte[] bytes) {
396+
assert bytes.length == 4;
397+
bytes[0] = (byte) value;
398+
bytes[1] = (byte) (value >> 8);
399+
bytes[2] = (byte) (value >> 16);
400+
bytes[3] = (byte) (value >> 24);
401+
}
402+
403+
private static void packInt64(long value, byte[] bytes) {
404+
assert bytes.length == 8;
405+
bytes[0] = (byte) value;
406+
bytes[1] = (byte) (value >> 8);
407+
bytes[2] = (byte) (value >> 16);
408+
bytes[3] = (byte) (value >> 24);
409+
bytes[4] = (byte) (value >> 32);
410+
bytes[5] = (byte) (value >> 40);
411+
bytes[6] = (byte) (value >> 48);
412+
bytes[7] = (byte) (value >> 56);
413+
}
414+
273415
private PException raise(PythonBuiltinClassType type, String message, Object... args) {
274416
if (raiseNode == null) {
275417
CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -372,12 +514,12 @@ abstract static class WriteItemAtNode extends Node {
372514
public abstract void execute(VirtualFrame frame, PMemoryView self, Object ptr, int offset, Object object);
373515

374516
@Specialization(guards = "ptr != null")
375-
static void doNative(PMemoryView self, Object ptr, int offset, Object object,
517+
static void doNative(VirtualFrame frame, PMemoryView self, Object ptr, int offset, Object object,
376518
@CachedLibrary(limit = "1") InteropLibrary lib,
377519
@Cached PackValueNode packValueNode) {
378520
int itemsize = self.getItemSize();
379521
byte[] bytes = new byte[itemsize];
380-
packValueNode.execute(self.getFormat(), object, bytes);
522+
packValueNode.execute(frame, self.getFormat(), self.getFormatString(), object, bytes);
381523
try {
382524
for (int i = 0; i < itemsize; i++) {
383525
lib.writeArrayElement(ptr, offset + i, bytes[i]);
@@ -388,13 +530,13 @@ static void doNative(PMemoryView self, Object ptr, int offset, Object object,
388530
}
389531

390532
@Specialization(guards = "ptr == null")
391-
static void doManaged(PMemoryView self, @SuppressWarnings("unused") Object ptr, int offset, Object object,
533+
static void doManaged(VirtualFrame frame, PMemoryView self, @SuppressWarnings("unused") Object ptr, int offset, Object object,
392534
@Cached PackValueNode packValueNode,
393535
@Cached SequenceNodes.GetSequenceStorageNode getStorageNode,
394536
@Cached SequenceStorageNodes.SetItemScalarNode setItemNode) {
395537
// TODO assumes bytes storage
396538
byte[] bytes = new byte[self.getItemSize()];
397-
packValueNode.execute(self.getFormat(), object, bytes);
539+
packValueNode.execute(frame, self.getFormat(), self.getFormatString(), object, bytes);
398540
for (int i = 0; i < self.getItemSize(); i++) {
399541
setItemNode.execute(getStorageNode.execute(self.getOwner()), offset + i, bytes[i]);
400542
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/ErrorMessages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ public abstract class ErrorMessages {
356356
public static final String MEMORYVIEW_INVALID_SLICE_KEY = "memoryview: invalid slice key";
357357
public static final String MEMORYVIEW_A_BYTES_LIKE_OBJECT_REQUIRED_NOT_P = "memoryview: a bytes-like object is required, not '%p'";
358358
public static final String MEMORYVIEW_INVALID_VALUE_FOR_FORMAT_S = "memoryview: invalid value for format '%s'";
359+
public static final String MEMORYVIEW_INVALID_TYPE_FOR_FORMAT_S = "memoryview: invalid type for format '%s'";
359360
public static final String MEMORYVIEW_SLICE_ASSIGNMENT_RESTRICTED_TO_DIM_1 = "memoryview slice assignments are currently restricted to ndim = 1";
360361
public static final String MEMORYVIEW_DIFFERENT_STRUCTURES = "memoryview assignment: lvalue and rvalue have different structures";
361362
public static final String MEMORYVIEW_FORBIDDEN_RELEASED = "operation forbidden on released memoryview object";

0 commit comments

Comments
 (0)