Skip to content

Commit db146f5

Browse files
committed
Implement memoryview equality
1 parent adbd7b3 commit db146f5

File tree

5 files changed

+121
-5
lines changed

5 files changed

+121
-5
lines changed

graalpython/com.oracle.graal.python.test/src/tests/unittest_tags/test_memoryview.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
*graalpython.lib-python.3.test.test_memoryview.ArrayMemoryviewTest.test_writable_readonly
2626
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_attributes_readonly
2727
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_attributes_writable
28+
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_compare
2829
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_contextmanager
2930
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_delitem
3031
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_gc
@@ -44,6 +45,7 @@
4445
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceSliceTest.test_writable_readonly
4546
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_attributes_readonly
4647
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_attributes_writable
48+
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_compare
4749
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_contextmanager
4850
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_delitem
4951
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_gc
@@ -64,6 +66,7 @@
6466
*graalpython.lib-python.3.test.test_memoryview.BytesMemorySliceTest.test_writable_readonly
6567
*graalpython.lib-python.3.test.test_memoryview.BytesMemoryviewTest.test_attributes_readonly
6668
*graalpython.lib-python.3.test.test_memoryview.BytesMemoryviewTest.test_attributes_writable
69+
*graalpython.lib-python.3.test.test_memoryview.BytesMemoryviewTest.test_compare
6770
*graalpython.lib-python.3.test.test_memoryview.BytesMemoryviewTest.test_constructor
6871
*graalpython.lib-python.3.test.test_memoryview.BytesMemoryviewTest.test_contextmanager
6972
*graalpython.lib-python.3.test.test_memoryview.BytesMemoryviewTest.test_delitem

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3323,7 +3323,7 @@ public boolean hasSetItem(Object object) {
33233323
public abstract static class MemoryViewNode extends PythonBuiltinNode {
33243324
public abstract IntrinsifiedPMemoryView execute(Object cls, Object object);
33253325

3326-
public final IntrinsifiedPMemoryView create(Object object) {
3326+
public final IntrinsifiedPMemoryView execute(Object object) {
33273327
return execute(PythonBuiltinClassType.PMemoryView, object);
33283328
}
33293329

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,7 @@ Object wrap(VirtualFrame frame, Object object,
15541554
@Cached BuiltinConstructors.MemoryViewNode memoryViewNode,
15551555
@Cached GetNativeNullNode getNativeNullNode) {
15561556
try {
1557-
return memoryViewNode.create(object);
1557+
return memoryViewNode.execute(object);
15581558
} catch (PException e) {
15591559
transformToNative(frame, e);
15601560
return getNativeNullNode.execute();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewBuiltins.java

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.ValueError;
77
import static com.oracle.graal.python.nodes.SpecialMethodNames.__DELITEM__;
88
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ENTER__;
9+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
910
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EXIT__;
1011
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GETITEM__;
1112
import static com.oracle.graal.python.nodes.SpecialMethodNames.__HASH__;
@@ -25,6 +26,7 @@
2526
import com.oracle.graal.python.builtins.modules.BuiltinConstructors;
2627
import com.oracle.graal.python.builtins.objects.PEllipsis;
2728
import com.oracle.graal.python.builtins.objects.PNone;
29+
import com.oracle.graal.python.builtins.objects.PNotImplemented;
2830
import com.oracle.graal.python.builtins.objects.bytes.BytesBuiltins;
2931
import com.oracle.graal.python.builtins.objects.bytes.BytesBuiltins.ExpectIntNode;
3032
import com.oracle.graal.python.builtins.objects.bytes.BytesBuiltins.SepExpectByteNode;
@@ -50,19 +52,21 @@
5052
import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
5153
import com.oracle.graal.python.nodes.subscript.SliceLiteralNode;
5254
import com.oracle.graal.python.runtime.AsyncHandler;
55+
import com.oracle.graal.python.runtime.ExecutionContext.IndirectCallContext;
5356
import com.oracle.graal.python.runtime.PythonContext;
5457
import com.oracle.graal.python.runtime.PythonCore;
55-
import com.oracle.graal.python.runtime.ExecutionContext.IndirectCallContext;
58+
import com.oracle.graal.python.runtime.exception.PException;
5659
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
5760
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
5861
import com.oracle.graal.python.util.PythonUtils;
5962
import com.oracle.truffle.api.CompilerDirectives;
6063
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6164
import com.oracle.truffle.api.dsl.Cached;
65+
import com.oracle.truffle.api.dsl.Cached.Shared;
66+
import com.oracle.truffle.api.dsl.Fallback;
6267
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6368
import com.oracle.truffle.api.dsl.NodeFactory;
6469
import com.oracle.truffle.api.dsl.Specialization;
65-
import com.oracle.truffle.api.dsl.Cached.Shared;
6670
import com.oracle.truffle.api.frame.VirtualFrame;
6771
import com.oracle.truffle.api.interop.ArityException;
6872
import com.oracle.truffle.api.interop.InteropLibrary;
@@ -218,7 +222,7 @@ Object setitem(VirtualFrame frame, IntrinsifiedPMemoryView self, PSlice slice, O
218222
if (self.getDimensions() != 1) {
219223
throw raise(NotImplementedError, ErrorMessages.MEMORYVIEW_SLICE_ASSIGNMENT_RESTRICTED_TO_DIM_1);
220224
}
221-
IntrinsifiedPMemoryView srcView = createMemoryView.create(object);
225+
IntrinsifiedPMemoryView srcView = createMemoryView.execute(object);
222226
IntrinsifiedPMemoryView destView = (IntrinsifiedPMemoryView) getItemNode.execute(frame, self, slice);
223227
// TODO format skip @
224228
if (srcView.getDimensions() != destView.getDimensions() || srcView.getBufferShape()[0] != destView.getBufferShape()[0] || !srcView.getFormatString().equals(destView.getFormatString())) {
@@ -257,6 +261,110 @@ private void checkReadonly(IntrinsifiedPMemoryView self) {
257261
}
258262
}
259263

264+
@Builtin(name = __EQ__, minNumOfPositionalArgs = 2)
265+
@GenerateNodeFactory
266+
public static abstract class EqNode extends PythonBinaryBuiltinNode {
267+
@Child private CExtNodes.PCallCapiFunction callCapiFunction;
268+
269+
@Specialization
270+
boolean eq(VirtualFrame frame, IntrinsifiedPMemoryView self, IntrinsifiedPMemoryView other,
271+
@CachedLibrary(limit = "3") PythonObjectLibrary lib,
272+
@Cached MemoryViewNodes.ReadItemAtNode readSelf,
273+
@Cached MemoryViewNodes.ReadItemAtNode readOther) {
274+
if (self.isReleased() || other.isReleased()) {
275+
return self == other;
276+
}
277+
278+
int ndim = self.getDimensions();
279+
if (ndim != other.getDimensions()) {
280+
return false;
281+
}
282+
283+
for (int i = 0; i < ndim; i++) {
284+
if (self.getBufferShape()[i] != other.getBufferShape()[i]) {
285+
return false;
286+
}
287+
if (self.getBufferShape()[i] == 0) {
288+
break;
289+
}
290+
}
291+
292+
// TODO CPython supports only limited set of typed for reading and writing, but
293+
// for equality comparisons, it supports all the struct module formats. Implement that
294+
295+
if (ndim == 0) {
296+
Object selfItem = readSelf.execute(self, self.getBufferPointer(), 0);
297+
Object otherItem = readOther.execute(other, other.getBufferPointer(), 0);
298+
return lib.equalsWithFrame(selfItem, otherItem, lib, frame);
299+
}
300+
301+
return recursive(lib, self, other, readSelf, readOther, 0, ndim,
302+
self.getBufferPointer(), self.getOffset(), other.getBufferPointer(), other.getOffset());
303+
}
304+
305+
@Specialization(guards = "!isMemoryView(other)")
306+
Object eq(VirtualFrame frame, IntrinsifiedPMemoryView self, Object other,
307+
@Cached BuiltinConstructors.MemoryViewNode memoryViewNode,
308+
@CachedLibrary(limit = "3") PythonObjectLibrary lib,
309+
@Cached MemoryViewNodes.ReadItemAtNode readSelf,
310+
@Cached MemoryViewNodes.ReadItemAtNode readOther) {
311+
IntrinsifiedPMemoryView memoryView;
312+
try {
313+
memoryView = memoryViewNode.execute(other);
314+
} catch (PException e) {
315+
return PNotImplemented.NOT_IMPLEMENTED;
316+
}
317+
return eq(frame, self, memoryView, lib, readSelf, readOther);
318+
}
319+
320+
@Fallback
321+
@SuppressWarnings("unused")
322+
static Object eq(Object self, Object other) {
323+
return PNotImplemented.NOT_IMPLEMENTED;
324+
}
325+
326+
private boolean recursive(PythonObjectLibrary lib, IntrinsifiedPMemoryView self, IntrinsifiedPMemoryView other,
327+
MemoryViewNodes.ReadItemAtNode readSelf, MemoryViewNodes.ReadItemAtNode readOther,
328+
int dim, int ndim, Object selfPtr, int selfOffset, Object otherPtr, int otherOffset) {
329+
for (int i = 0; i < self.getBufferShape()[dim]; i++) {
330+
Object selfXPtr = selfPtr;
331+
int selfXOffset = selfOffset;
332+
Object otherXPtr = otherPtr;
333+
int otherXOffset = otherOffset;
334+
if (self.getBufferSuboffsets() != null && self.getBufferSuboffsets()[dim] >= 0) {
335+
selfXPtr = getCallCapiFunction().call(NativeCAPISymbols.FUN_TRUFFLE_ADD_SUBOFFSET, selfPtr, selfOffset, self.getBufferSuboffsets()[dim], self.getLength());
336+
selfXOffset = 0;
337+
}
338+
if (other.getBufferSuboffsets() != null && other.getBufferSuboffsets()[dim] >= 0) {
339+
otherXPtr = getCallCapiFunction().call(NativeCAPISymbols.FUN_TRUFFLE_ADD_SUBOFFSET, otherPtr, otherOffset, other.getBufferSuboffsets()[dim], other.getLength());
340+
otherXOffset = 0;
341+
}
342+
if (dim == ndim - 1) {
343+
Object selfItem = readSelf.execute(self, selfXPtr, selfXOffset);
344+
Object otherItem = readOther.execute(other, otherXPtr, otherXOffset);
345+
if (!lib.equals(selfItem, otherItem, lib)) {
346+
return false;
347+
}
348+
} else {
349+
if (!recursive(lib, self, other, readSelf, readOther, dim + 1, ndim, selfXPtr, selfXOffset, otherXPtr, otherXOffset)) {
350+
return false;
351+
}
352+
}
353+
selfOffset += self.getBufferStrides()[dim];
354+
otherOffset += other.getBufferStrides()[dim];
355+
}
356+
return true;
357+
}
358+
359+
private CExtNodes.PCallCapiFunction getCallCapiFunction() {
360+
if (callCapiFunction == null) {
361+
CompilerDirectives.transferToInterpreterAndInvalidate();
362+
callCapiFunction = insert(CExtNodes.PCallCapiFunction.create());
363+
}
364+
return callCapiFunction;
365+
}
366+
}
367+
260368
@Builtin(name = __DELITEM__, minNumOfPositionalArgs = 1)
261369
@GenerateNodeFactory
262370
public static abstract class DelItemNode extends PythonUnaryBuiltinNode {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/PGuards.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import com.oracle.graal.python.builtins.objects.ints.PInt;
6464
import com.oracle.graal.python.builtins.objects.iterator.PSequenceIterator;
6565
import com.oracle.graal.python.builtins.objects.list.PList;
66+
import com.oracle.graal.python.builtins.objects.memoryview.IntrinsifiedPMemoryView;
6667
import com.oracle.graal.python.builtins.objects.method.PBuiltinMethod;
6768
import com.oracle.graal.python.builtins.objects.method.PMethod;
6869
import com.oracle.graal.python.builtins.objects.module.PythonModule;
@@ -134,6 +135,10 @@ public static boolean isEllipsis(Object object) {
134135
return object == PEllipsis.INSTANCE;
135136
}
136137

138+
public static boolean isMemoryView(Object object) {
139+
return object instanceof IntrinsifiedPMemoryView;
140+
}
141+
137142
public static boolean isDeleteMarker(Object object) {
138143
return object == DescriptorDeleteMarker.INSTANCE;
139144
}

0 commit comments

Comments
 (0)