Skip to content

Commit 4eed044

Browse files
committed
Implement 'mmap.__setitem__'.
1 parent eae7005 commit 4eed044

File tree

3 files changed

+336
-18
lines changed

3 files changed

+336
-18
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/mmap/MMapBuiltins.java

Lines changed: 143 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959

6060
import java.io.IOException;
6161
import java.nio.ByteBuffer;
62+
import java.nio.channels.Channel;
6263
import java.nio.channels.SeekableByteChannel;
6364
import java.util.List;
6465

@@ -69,6 +70,7 @@
6970
import com.oracle.graal.python.builtins.objects.PNone;
7071
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
7172
import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
73+
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
7274
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
7375
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NoGeneralizationNode;
7476
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NormalizeIndexNode;
@@ -83,14 +85,16 @@
8385
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
8486
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
8587
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
86-
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
8788
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
8889
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
8990
import com.oracle.graal.python.nodes.util.CastToByteNode;
9091
import com.oracle.graal.python.nodes.util.CastToIndexNode;
9192
import com.oracle.graal.python.nodes.util.CastToJavaLongNode;
93+
import com.oracle.graal.python.nodes.util.ChannelNodes;
9294
import com.oracle.graal.python.nodes.util.ChannelNodes.ReadByteFromChannelNode;
9395
import com.oracle.graal.python.nodes.util.ChannelNodes.ReadFromChannelNode;
96+
import com.oracle.graal.python.nodes.util.ChannelNodes.WriteByteToChannelNode;
97+
import com.oracle.graal.python.nodes.util.ChannelNodes.WriteToChannelNode;
9498
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
9599
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
96100
import com.oracle.truffle.api.CompilerDirectives;
@@ -102,10 +106,59 @@
102106
import com.oracle.truffle.api.dsl.TypeSystemReference;
103107
import com.oracle.truffle.api.frame.VirtualFrame;
104108
import com.oracle.truffle.api.profiles.BranchProfile;
109+
import com.oracle.truffle.api.profiles.ConditionProfile;
105110

106111
@CoreFunctions(extendClasses = PythonBuiltinClassType.PMMap)
107112
public class MMapBuiltins extends PythonBuiltins {
108113

114+
protected interface ByteReadingNode {
115+
116+
static ReadByteFromChannelNode createValueError() {
117+
return ReadByteFromChannelNode.create(() -> new ChannelNodes.ReadByteErrorHandler() {
118+
119+
@Override
120+
public int execute(Channel channel) {
121+
throw raise(PythonBuiltinClassType.ValueError, "read byte out of range");
122+
}
123+
});
124+
}
125+
126+
static ReadByteFromChannelNode createIndexError() {
127+
return ReadByteFromChannelNode.create(() -> new ChannelNodes.ReadByteErrorHandler() {
128+
129+
@Override
130+
public int execute(Channel channel) {
131+
throw raise(PythonBuiltinClassType.IndexError, "mmap index out of range");
132+
}
133+
});
134+
135+
}
136+
}
137+
138+
protected interface ByteWritingNode {
139+
140+
static WriteByteToChannelNode createValueError() {
141+
return WriteByteToChannelNode.create(() -> new ChannelNodes.WriteByteErrorHandler() {
142+
143+
@Override
144+
public void execute(Channel channel, byte b) {
145+
throw raise(PythonBuiltinClassType.ValueError, "write byte out of range");
146+
}
147+
});
148+
}
149+
150+
static WriteByteToChannelNode createIndexError() {
151+
return WriteByteToChannelNode.create(() -> new ChannelNodes.WriteByteErrorHandler() {
152+
153+
@Override
154+
public void execute(Channel channel, byte b) {
155+
throw raise(PythonBuiltinClassType.IndexError, "mmap index out of range");
156+
}
157+
});
158+
159+
}
160+
}
161+
109162
@Override
110163
protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFactories() {
111164
return MMapBuiltinsFactory.getFactories();
@@ -173,11 +226,11 @@ abstract static class ReprNode extends StrNode {
173226

174227
@Builtin(name = __GETITEM__, fixedNumOfPositionalArgs = 2)
175228
@GenerateNodeFactory
176-
abstract static class GetItemNode extends PythonBuiltinNode {
229+
abstract static class GetItemNode extends PythonBuiltinNode implements ByteReadingNode {
177230

178231
@Specialization(guards = "!isPSlice(idxObj)")
179232
int doSingle(VirtualFrame frame, PMMap self, Object idxObj,
180-
@Cached("create()") ReadByteFromChannelNode readByteNode,
233+
@Cached("createIndexError()") ReadByteFromChannelNode readByteNode,
181234
@Cached("createExact()") CastToJavaLongNode castToLongNode,
182235
@Cached("create()") InternalLenNode lenNode) {
183236

@@ -226,12 +279,82 @@ Object doSlice(VirtualFrame frame, PMMap self, PSlice idx,
226279
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
227280
}
228281
}
282+
229283
}
230284

231285
@Builtin(name = SpecialMethodNames.__SETITEM__, fixedNumOfPositionalArgs = 3)
232286
@GenerateNodeFactory
233-
abstract static class SetItemNode extends PythonTernaryBuiltinNode {
287+
abstract static class SetItemNode extends PythonBuiltinNode implements ByteWritingNode {
288+
289+
@Specialization(guards = "!isPSlice(idxObj)")
290+
PNone doSingle(VirtualFrame frame, PMMap self, Object idxObj, Object val,
291+
@Cached("createIndexError()") WriteByteToChannelNode writeByteNode,
292+
@Cached("createExact()") CastToJavaLongNode castToLongNode,
293+
@Cached("createCoerce()") CastToByteNode castToByteNode,
294+
@Cached("create()") InternalLenNode lenNode,
295+
@Cached("createBinaryProfile()") ConditionProfile outOfRangeProfile) {
296+
297+
try {
298+
long i = castToLongNode.execute(idxObj);
299+
long len = lenNode.execute(frame, self);
300+
SeekableByteChannel channel = self.getChannel();
301+
long idx = i < 0 ? i + len : i;
302+
303+
if (outOfRangeProfile.profile(idx < 0 || idx >= len)) {
304+
throw raise(PythonBuiltinClassType.IndexError, "mmap index out of range");
305+
}
306+
307+
// save current position
308+
long oldPos = channel.position();
234309

310+
channel.position(idx);
311+
writeByteNode.execute(channel, castToByteNode.execute(val));
312+
313+
// restore position
314+
channel.position(oldPos);
315+
316+
return PNone.NONE;
317+
318+
} catch (IOException e) {
319+
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
320+
}
321+
}
322+
323+
@Specialization
324+
PNone doSlice(VirtualFrame frame, PMMap self, PSlice idx, PIBytesLike val,
325+
@Cached("create()") WriteToChannelNode writeNode,
326+
@Cached("create()") SequenceNodes.GetSequenceStorageNode getStorageNode,
327+
@Cached("create()") InternalLenNode lenNode,
328+
@Cached("createBinaryProfile()") ConditionProfile invalidStepProfile) {
329+
330+
try {
331+
long len = lenNode.execute(frame, self);
332+
SliceInfo info = idx.computeIndices(PInt.intValueExact(len));
333+
SeekableByteChannel channel = self.getChannel();
334+
335+
if (invalidStepProfile.profile(info.step != 1)) {
336+
throw raise(PythonBuiltinClassType.SystemError, "step != 1 not supported");
337+
}
338+
339+
// save current position
340+
long oldPos = channel.position();
341+
342+
channel.position(info.start);
343+
writeNode.execute(channel, getStorageNode.execute(val), info.length);
344+
345+
// restore position
346+
channel.position(oldPos);
347+
348+
return PNone.NONE;
349+
350+
} catch (IOException e) {
351+
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
352+
}
353+
}
354+
355+
protected static CastToByteNode createCoerce() {
356+
return CastToByteNode.create(true);
357+
}
235358
}
236359

237360
@Builtin(name = __LEN__, fixedNumOfPositionalArgs = 1)
@@ -280,6 +403,16 @@ PNone close(PMMap self) {
280403
}
281404
}
282405

406+
@Builtin(name = "closed", fixedNumOfPositionalArgs = 1, isGetter = true)
407+
@GenerateNodeFactory
408+
abstract static class ClosedNode extends PythonUnaryBuiltinNode {
409+
410+
@Specialization
411+
boolean close(PMMap self) {
412+
return !self.getChannel().isOpen();
413+
}
414+
}
415+
283416
@Builtin(name = "size", fixedNumOfPositionalArgs = 1)
284417
@GenerateNodeFactory
285418
abstract static class SizeNode extends PythonBuiltinNode {
@@ -309,11 +442,11 @@ long readline(VirtualFrame frame, PMMap self) {
309442
@Builtin(name = "read_byte", fixedNumOfPositionalArgs = 1)
310443
@GenerateNodeFactory
311444
@TypeSystemReference(PythonArithmeticTypes.class)
312-
abstract static class ReadByteNode extends PythonUnaryBuiltinNode {
445+
abstract static class ReadByteNode extends PythonUnaryBuiltinNode implements ByteReadingNode {
313446

314447
@Specialization
315448
int readByte(PMMap self,
316-
@Cached("create()") ReadByteFromChannelNode readByteNode) {
449+
@Cached("createValueError()") ReadByteFromChannelNode readByteNode) {
317450
return readByteNode.execute(self.getChannel());
318451
}
319452
}
@@ -337,6 +470,7 @@ PBytes read(PMMap self, Object n,
337470
ByteSequenceStorage res = readChannelNode.execute(self.getChannel(), castToIndexNode.execute(n));
338471
return factory().createBytes(res);
339472
}
473+
340474
}
341475

342476
@Builtin(name = "readline", fixedNumOfPositionalArgs = 1)
@@ -433,7 +567,7 @@ private Object doSeek(PMMap self, long dist, int how) throws IOException {
433567
@Builtin(name = "find", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
434568
@GenerateNodeFactory
435569
@TypeSystemReference(PythonArithmeticTypes.class)
436-
public abstract static class FindNode extends PythonBuiltinNode {
570+
public abstract static class FindNode extends PythonBuiltinNode implements ByteReadingNode {
437571

438572
@Child private NormalizeIndexNode normalizeIndexNode;
439573
@Child private SequenceStorageNodes.GetItemNode getLeftItemNode;
@@ -443,7 +577,7 @@ public abstract static class FindNode extends PythonBuiltinNode {
443577

444578
@Specialization
445579
long find(PMMap primary, PIBytesLike sub, Object starting, Object ending,
446-
@Cached("create()") ReadByteFromChannelNode readByteNode) {
580+
@Cached("createValueError()") ReadByteFromChannelNode readByteNode) {
447581
try {
448582
SeekableByteChannel channel = primary.getChannel();
449583
long len1 = channel.size();
@@ -484,7 +618,7 @@ long find(PMMap primary, PIBytesLike sub, Object starting, Object ending,
484618

485619
@Specialization
486620
long find(PMMap primary, int sub, Object starting, @SuppressWarnings("unused") Object ending,
487-
@Cached("create()") ReadByteFromChannelNode readByteNode) {
621+
@Cached("createValueError()") ReadByteFromChannelNode readByteNode) {
488622
try {
489623
SeekableByteChannel channel = primary.getChannel();
490624
long len1 = channel.size();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/util/CastToByteNode.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@
4545

4646
import java.util.function.Function;
4747

48+
import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
49+
import com.oracle.graal.python.builtins.objects.common.SequenceNodes;
50+
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
4851
import com.oracle.graal.python.builtins.objects.ints.PInt;
4952
import com.oracle.graal.python.nodes.PNodeWithContext;
53+
import com.oracle.truffle.api.dsl.Cached;
5054
import com.oracle.truffle.api.dsl.Fallback;
5155
import com.oracle.truffle.api.dsl.Specialization;
5256

@@ -55,10 +59,12 @@ public abstract class CastToByteNode extends PNodeWithContext {
5559

5660
private final Function<Object, Byte> rangeErrorHandler;
5761
private final Function<Object, Byte> typeErrorHandler;
62+
protected final boolean coerce;
5863

59-
protected CastToByteNode(Function<Object, Byte> rangeErrorHandler, Function<Object, Byte> typeErrorHandler) {
64+
protected CastToByteNode(Function<Object, Byte> rangeErrorHandler, Function<Object, Byte> typeErrorHandler, boolean coerce) {
6065
this.rangeErrorHandler = rangeErrorHandler;
6166
this.typeErrorHandler = typeErrorHandler;
67+
this.coerce = coerce;
6268
}
6369

6470
public abstract byte execute(Object val);
@@ -115,6 +121,13 @@ protected byte doBoolean(boolean value) {
115121
return value ? (byte) 1 : (byte) 0;
116122
}
117123

124+
@Specialization(guards = "coerce")
125+
protected byte doBytes(PIBytesLike value,
126+
@Cached("create()") SequenceNodes.GetSequenceStorageNode getStorageNode,
127+
@Cached("create()") SequenceStorageNodes.GetItemNode getItemNode) {
128+
return doIntOvf(getItemNode.executeInt(getStorageNode.execute(value), 0));
129+
}
130+
118131
@Fallback
119132
protected byte doGeneric(@SuppressWarnings("unused") Object val) {
120133
if (typeErrorHandler != null) {
@@ -133,11 +146,18 @@ private byte handleRangeError(Object val) {
133146
}
134147

135148
public static CastToByteNode create() {
136-
return CastToByteNodeGen.create(null, null);
149+
return CastToByteNodeGen.create(null, null, false);
150+
}
151+
152+
public static CastToByteNode create(boolean coerce) {
153+
return CastToByteNodeGen.create(null, null, coerce);
137154
}
138155

139156
public static CastToByteNode create(Function<Object, Byte> rangeErrorHandler, Function<Object, Byte> typeErrorHandler) {
140157
return CastToByteNodeGen.create(rangeErrorHandler, typeErrorHandler);
141158
}
142159

160+
public static CastToByteNode create(Function<Object, Byte> rangeErrorHandler, Function<Object, Byte> typeErrorHandler, boolean coerce) {
161+
return CastToByteNodeGen.create(rangeErrorHandler, typeErrorHandler, coerce);
162+
}
143163
}

0 commit comments

Comments
 (0)