Skip to content

Commit adbd7b3

Browse files
committed
Fix slice assignment with memory overlap
1 parent b047e4c commit adbd7b3

File tree

2 files changed

+37
-74
lines changed

2 files changed

+37
-74
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewBuiltins.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,19 @@
5050
import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
5151
import com.oracle.graal.python.nodes.subscript.SliceLiteralNode;
5252
import com.oracle.graal.python.runtime.AsyncHandler;
53-
import com.oracle.graal.python.runtime.ExecutionContext.IndirectCallContext;
5453
import com.oracle.graal.python.runtime.PythonContext;
5554
import com.oracle.graal.python.runtime.PythonCore;
55+
import com.oracle.graal.python.runtime.ExecutionContext.IndirectCallContext;
5656
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
5757
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
5858
import com.oracle.graal.python.util.PythonUtils;
5959
import com.oracle.truffle.api.CompilerDirectives;
6060
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
6161
import com.oracle.truffle.api.dsl.Cached;
62-
import com.oracle.truffle.api.dsl.Cached.Shared;
6362
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
6463
import com.oracle.truffle.api.dsl.NodeFactory;
6564
import com.oracle.truffle.api.dsl.Specialization;
65+
import com.oracle.truffle.api.dsl.Cached.Shared;
6666
import com.oracle.truffle.api.frame.VirtualFrame;
6767
import com.oracle.truffle.api.interop.ArityException;
6868
import com.oracle.truffle.api.interop.InteropLibrary;
@@ -211,7 +211,8 @@ Object setitem(VirtualFrame frame, IntrinsifiedPMemoryView self, PSlice slice, O
211211
@Cached GetItemNode getItemNode,
212212
@Cached BuiltinConstructors.MemoryViewNode createMemoryView,
213213
@Cached MemoryViewNodes.PointerLookupNode pointerLookupNode,
214-
@Cached MemoryViewNodes.CopyBytesNode copyBytesNode) {
214+
@Cached MemoryViewNodes.ToJavaBytesNode toJavaBytesNode,
215+
@Cached MemoryViewNodes.WriteBytesAtNode writeBytesAtNode) {
215216
self.checkReleased(this);
216217
checkReadonly(self);
217218
if (self.getDimensions() != 1) {
@@ -223,11 +224,13 @@ Object setitem(VirtualFrame frame, IntrinsifiedPMemoryView self, PSlice slice, O
223224
if (srcView.getDimensions() != destView.getDimensions() || srcView.getBufferShape()[0] != destView.getBufferShape()[0] || !srcView.getFormatString().equals(destView.getFormatString())) {
224225
throw raise(ValueError, ErrorMessages.MEMORYVIEW_DIFFERENT_STRUCTURES);
225226
}
227+
// The intermediate array is necessary for overlapping views (where src and dest are the
228+
// same buffer)
229+
byte[] srcBytes = toJavaBytesNode.execute(srcView);
230+
int itemsize = srcView.getItemSize();
226231
for (int i = 0; i < destView.getBufferShape()[0]; i++) {
227-
// TODO doesn't look very efficient
228232
MemoryViewNodes.MemoryPointer destPtr = pointerLookupNode.execute(frame, destView, i);
229-
MemoryViewNodes.MemoryPointer srcPtr = pointerLookupNode.execute(frame, srcView, i);
230-
copyBytesNode.execute(destView, destPtr.ptr, destPtr.offset, srcView, srcPtr.ptr, srcPtr.offset, destView.getItemSize());
233+
writeBytesAtNode.execute(srcBytes, i * itemsize, itemsize, self, destPtr.ptr, destPtr.offset);
231234
}
232235
return PNone.NONE;
233236
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/memoryview/MemoryViewNodes.java

Lines changed: 28 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,34 @@ static void doManaged(byte[] dest, int destOffset, int len, IntrinsifiedPMemoryV
208208
}
209209
}
210210

211+
@GenerateUncached
212+
static abstract class WriteBytesAtNode extends Node {
213+
public abstract void execute(byte[] src, int srcOffset, int len, IntrinsifiedPMemoryView self, Object ptr, int offset);
214+
215+
@Specialization(guards = "ptr != null")
216+
static void doNative(byte[] src, int srcOffset, int len, @SuppressWarnings("unused") IntrinsifiedPMemoryView self, Object ptr, int offset,
217+
@CachedLibrary(limit = "1") InteropLibrary lib) {
218+
try {
219+
for (int i = 0; i < len; i++) {
220+
lib.writeArrayElement(ptr, offset + i, src[srcOffset + i]);
221+
}
222+
} catch (UnsupportedMessageException | InvalidArrayIndexException | UnsupportedTypeException e) {
223+
throw CompilerDirectives.shouldNotReachHere("native buffer read failed");
224+
}
225+
}
226+
227+
@Specialization(guards = "ptr == null")
228+
static void doManaged(byte[] src, int srcOffset, int len, IntrinsifiedPMemoryView self, @SuppressWarnings("unused") Object ptr, int offset,
229+
@Cached SequenceNodes.GetSequenceStorageNode getStorageNode,
230+
@Cached SequenceStorageNodes.SetItemScalarNode setItemNode) {
231+
// TODO assumes byte storage
232+
SequenceStorage storage = getStorageNode.execute(self.getOwner());
233+
for (int i = 0; i < len; i++) {
234+
setItemNode.execute(storage, offset + i, src[srcOffset + i]);
235+
}
236+
}
237+
}
238+
211239
static abstract class ReadItemAtNode extends Node {
212240
public abstract Object execute(IntrinsifiedPMemoryView self, Object ptr, int offset);
213241

@@ -274,74 +302,6 @@ static void doManaged(IntrinsifiedPMemoryView self, @SuppressWarnings("unused")
274302
}
275303
}
276304

277-
static abstract class CopyBytesNode extends Node {
278-
public abstract void execute(IntrinsifiedPMemoryView dest, Object destPtr, int destOffset, IntrinsifiedPMemoryView src, Object srcPtr, int srcOffset, int nbytes);
279-
280-
@Specialization(guards = {"destPtr == null", "srcPtr == null"})
281-
@SuppressWarnings("unused")
282-
static void managedToManaged(IntrinsifiedPMemoryView dest, Object destPtr, int destOffset, IntrinsifiedPMemoryView src, Object srcPtr, int srcOffset, int nbytes,
283-
@Cached SequenceNodes.GetSequenceStorageNode getSequenceStorageNode,
284-
@Cached SequenceStorageNodes.MemCopyNode memCopyNode) {
285-
// TODO assumes bytes storage
286-
SequenceStorage destStorage = getSequenceStorageNode.execute(dest.getOwner());
287-
SequenceStorage srcStorage = getSequenceStorageNode.execute(src.getOwner());
288-
memCopyNode.execute(destStorage, destOffset, srcStorage, srcOffset, nbytes);
289-
}
290-
291-
@Specialization(guards = {"destPtr != null", "srcPtr == null"})
292-
@SuppressWarnings("unused")
293-
static void managedToNative(IntrinsifiedPMemoryView dest, Object destPtr, int destOffset, IntrinsifiedPMemoryView src, Object srcPtr, int srcOffset, int nbytes,
294-
@Cached SequenceNodes.GetSequenceStorageNode getSequenceStorageNode,
295-
@Cached SequenceStorageNodes.GetItemScalarNode getItemNode,
296-
@Shared("lib") @CachedLibrary(limit = "1") InteropLibrary lib) {
297-
// TODO assumes bytes storage
298-
// TODO avoid byte->int conversion
299-
// TODO explode?
300-
SequenceStorage srcStorage = getSequenceStorageNode.execute(src.getOwner());
301-
try {
302-
for (int i = 0; i < nbytes; i++) {
303-
lib.writeArrayElement(destPtr, destOffset + i, (byte) getItemNode.executeInt(srcStorage, srcOffset + i));
304-
}
305-
} catch (UnsupportedMessageException | UnsupportedTypeException | InvalidArrayIndexException e) {
306-
throw CompilerDirectives.shouldNotReachHere(e);
307-
}
308-
}
309-
310-
@Specialization(guards = {"destPtr == null", "srcPtr != null"})
311-
@SuppressWarnings("unused")
312-
static void nativeToManaged(IntrinsifiedPMemoryView dest, Object destPtr, int destOffset, IntrinsifiedPMemoryView src, Object srcPtr, int srcOffset, int nbytes,
313-
@Cached SequenceNodes.GetSequenceStorageNode getSequenceStorageNode,
314-
@Cached SequenceStorageNodes.SetItemScalarNode setItemNode,
315-
@Shared("lib") @CachedLibrary(limit = "1") InteropLibrary lib) {
316-
// TODO assumes bytes storage
317-
// TODO avoid byte->int conversion
318-
// TODO explode?
319-
SequenceStorage destStorage = getSequenceStorageNode.execute(dest.getOwner());
320-
try {
321-
for (int i = 0; i < nbytes; i++) {
322-
setItemNode.execute(destStorage, (byte) lib.readArrayElement(srcPtr, srcOffset + i) & 0xFF, destOffset + i);
323-
}
324-
} catch (UnsupportedMessageException | InvalidArrayIndexException e) {
325-
throw CompilerDirectives.shouldNotReachHere(e);
326-
}
327-
}
328-
329-
@Specialization(guards = {"destPtr != null", "srcPtr != null"})
330-
@SuppressWarnings("unused")
331-
static void nativeToNative(IntrinsifiedPMemoryView dest, Object destPtr, int destOffset, IntrinsifiedPMemoryView src, Object srcPtr, int srcOffset, int nbytes,
332-
@Shared("lib") @CachedLibrary(limit = "1") InteropLibrary lib) {
333-
// TODO call native memcpy?
334-
// TODO explode?
335-
try {
336-
for (int i = 0; i < nbytes; i++) {
337-
lib.writeArrayElement(destPtr, destOffset + i, lib.readArrayElement(srcPtr, srcOffset + i));
338-
}
339-
} catch (UnsupportedMessageException | UnsupportedTypeException | InvalidArrayIndexException e) {
340-
throw CompilerDirectives.shouldNotReachHere(e);
341-
}
342-
}
343-
}
344-
345305
@ValueType
346306
static class MemoryPointer {
347307
public Object ptr;

0 commit comments

Comments
 (0)