Skip to content

Commit eae7005

Browse files
committed
Implement 'mmap.find' and mman subscript.
1 parent 86bfe1f commit eae7005

File tree

2 files changed

+209
-17
lines changed

2 files changed

+209
-17
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/mmap/MMapBuiltins.java

Lines changed: 208 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,16 @@
6868
import com.oracle.graal.python.builtins.PythonBuiltins;
6969
import com.oracle.graal.python.builtins.objects.PNone;
7070
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
71+
import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
7172
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
7273
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NoGeneralizationNode;
74+
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NormalizeIndexNode;
7375
import com.oracle.graal.python.builtins.objects.exception.OSErrorEnum;
76+
import com.oracle.graal.python.builtins.objects.ints.PInt;
77+
import com.oracle.graal.python.builtins.objects.mmap.MMapBuiltinsFactory.InternalLenNodeGen;
78+
import com.oracle.graal.python.builtins.objects.slice.PSlice;
79+
import com.oracle.graal.python.builtins.objects.slice.PSlice.SliceInfo;
80+
import com.oracle.graal.python.nodes.PNodeWithContext;
7481
import com.oracle.graal.python.nodes.SpecialMethodNames;
7582
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
7683
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
@@ -81,17 +88,20 @@
8188
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
8289
import com.oracle.graal.python.nodes.util.CastToByteNode;
8390
import com.oracle.graal.python.nodes.util.CastToIndexNode;
91+
import com.oracle.graal.python.nodes.util.CastToJavaLongNode;
8492
import com.oracle.graal.python.nodes.util.ChannelNodes.ReadByteFromChannelNode;
8593
import com.oracle.graal.python.nodes.util.ChannelNodes.ReadFromChannelNode;
8694
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
95+
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
96+
import com.oracle.truffle.api.CompilerDirectives;
8797
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
8898
import com.oracle.truffle.api.dsl.Cached;
8999
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
90100
import com.oracle.truffle.api.dsl.NodeFactory;
91101
import com.oracle.truffle.api.dsl.Specialization;
92102
import com.oracle.truffle.api.dsl.TypeSystemReference;
93103
import com.oracle.truffle.api.frame.VirtualFrame;
94-
import com.oracle.truffle.api.profiles.ConditionProfile;
104+
import com.oracle.truffle.api.profiles.BranchProfile;
95105

96106
@CoreFunctions(extendClasses = PythonBuiltinClassType.PMMap)
97107
public class MMapBuiltins extends PythonBuiltins {
@@ -163,7 +173,59 @@ abstract static class ReprNode extends StrNode {
163173

164174
@Builtin(name = __GETITEM__, fixedNumOfPositionalArgs = 2)
165175
@GenerateNodeFactory
166-
abstract static class GetItemNode extends PythonBinaryBuiltinNode {
176+
abstract static class GetItemNode extends PythonBuiltinNode {
177+
178+
@Specialization(guards = "!isPSlice(idxObj)")
179+
int doSingle(VirtualFrame frame, PMMap self, Object idxObj,
180+
@Cached("create()") ReadByteFromChannelNode readByteNode,
181+
@Cached("createExact()") CastToJavaLongNode castToLongNode,
182+
@Cached("create()") InternalLenNode lenNode) {
183+
184+
try {
185+
long i = castToLongNode.execute(idxObj);
186+
long len = lenNode.execute(frame, self);
187+
SeekableByteChannel channel = self.getChannel();
188+
long idx = i < 0 ? i + len : i;
189+
190+
// save current position
191+
long oldPos = channel.position();
192+
193+
channel.position(idx);
194+
int res = readByteNode.execute(channel);
195+
196+
// restore position
197+
channel.position(oldPos);
198+
199+
return res;
200+
201+
} catch (IOException e) {
202+
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
203+
}
204+
}
205+
206+
@Specialization
207+
Object doSlice(VirtualFrame frame, PMMap self, PSlice idx,
208+
@Cached("create()") ReadFromChannelNode readNode,
209+
@Cached("create()") InternalLenNode lenNode) {
210+
try {
211+
long len = lenNode.execute(frame, self);
212+
SliceInfo info = idx.computeIndices(PInt.intValueExact(len));
213+
SeekableByteChannel channel = self.getChannel();
214+
215+
// save current position
216+
long oldPos = channel.position();
217+
218+
channel.position(info.start);
219+
ByteSequenceStorage s = readNode.execute(channel, info.length);
220+
221+
// restore position
222+
channel.position(oldPos);
223+
224+
return factory().createBytes(s);
225+
} catch (IOException e) {
226+
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
227+
}
228+
}
167229
}
168230

169231
@Builtin(name = SpecialMethodNames.__SETITEM__, fixedNumOfPositionalArgs = 3)
@@ -176,12 +238,9 @@ abstract static class SetItemNode extends PythonTernaryBuiltinNode {
176238
@GenerateNodeFactory
177239
public abstract static class LenNode extends PythonBuiltinNode {
178240
@Specialization
179-
long len(VirtualFrame frame, PMMap self) {
180-
try {
181-
return self.getChannel().size();
182-
} catch (IOException e) {
183-
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
184-
}
241+
long len(VirtualFrame frame, PMMap self,
242+
@Cached("create()") InternalLenNode lenNode) {
243+
return lenNode.execute(frame, self);
185244
}
186245
}
187246

@@ -227,15 +286,8 @@ abstract static class SizeNode extends PythonBuiltinNode {
227286

228287
@Specialization
229288
long size(VirtualFrame frame, PMMap self,
230-
@Cached("createBinaryProfile()") ConditionProfile profile) {
231-
if (profile.profile(self.getLength() == 0)) {
232-
try {
233-
return self.getChannel().size() - self.getOffset();
234-
} catch (IOException e) {
235-
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
236-
}
237-
}
238-
return self.getLength();
289+
@Cached("create()") InternalLenNode lenNode) {
290+
return lenNode.execute(frame, self);
239291
}
240292
}
241293

@@ -378,4 +430,143 @@ private Object doSeek(PMMap self, long dist, int how) throws IOException {
378430
}
379431
}
380432

433+
@Builtin(name = "find", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
434+
@GenerateNodeFactory
435+
@TypeSystemReference(PythonArithmeticTypes.class)
436+
public abstract static class FindNode extends PythonBuiltinNode {
437+
438+
@Child private NormalizeIndexNode normalizeIndexNode;
439+
@Child private SequenceStorageNodes.GetItemNode getLeftItemNode;
440+
@Child private SequenceStorageNodes.GetItemNode getRightItemNode;
441+
442+
public abstract long execute(PMMap bytes, Object sub, Object starting, Object ending);
443+
444+
@Specialization
445+
long find(PMMap primary, PIBytesLike sub, Object starting, Object ending,
446+
@Cached("create()") ReadByteFromChannelNode readByteNode) {
447+
try {
448+
SeekableByteChannel channel = primary.getChannel();
449+
long len1 = channel.size();
450+
451+
SequenceStorage needle = sub.getSequenceStorage();
452+
int len2 = needle.length();
453+
454+
long s = castToLong(starting, 0);
455+
long e = castToLong(ending, len1);
456+
457+
long start = s < 0 ? s + len1 : s;
458+
long end = e < 0 ? e + len1 : e;
459+
460+
if (start >= len1 || len1 < len2) {
461+
return -1;
462+
} else if (end > len1) {
463+
end = len1;
464+
}
465+
466+
// TODO implement a more efficient algorithm
467+
outer: for (long i = start; i < end; i++) {
468+
// TODO(fa) don't seek but use circular buffer
469+
channel.position(i);
470+
for (int j = 0; j < len2; j++) {
471+
int hb = readByteNode.execute(channel);
472+
int nb = getGetRightItemNode().executeInt(needle, j);
473+
if (nb != hb || i + j >= end) {
474+
continue outer;
475+
}
476+
}
477+
return i;
478+
}
479+
return -1;
480+
} catch (IOException e) {
481+
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
482+
}
483+
}
484+
485+
@Specialization
486+
long find(PMMap primary, int sub, Object starting, @SuppressWarnings("unused") Object ending,
487+
@Cached("create()") ReadByteFromChannelNode readByteNode) {
488+
try {
489+
SeekableByteChannel channel = primary.getChannel();
490+
long len1 = channel.size();
491+
492+
long s = castToLong(starting, 0);
493+
long e = castToLong(ending, len1);
494+
495+
long start = s < 0 ? s + len1 : s;
496+
long end = Math.max(e < 0 ? e + len1 : e, len1);
497+
498+
channel.position(start);
499+
500+
for (long i = start; i < end; i++) {
501+
int hb = readByteNode.execute(channel);
502+
if (hb == sub) {
503+
return i;
504+
}
505+
}
506+
return -1;
507+
} catch (IOException e) {
508+
throw raise(PythonBuiltinClassType.OSError, e.getMessage());
509+
}
510+
}
511+
512+
// TODO(fa): use node
513+
private static long castToLong(Object obj, long defaultVal) {
514+
if (obj instanceof Integer || obj instanceof Long) {
515+
return ((Number) obj).longValue();
516+
} else if (obj instanceof PInt) {
517+
try {
518+
return ((PInt) obj).longValueExact();
519+
} catch (ArithmeticException e) {
520+
return defaultVal;
521+
}
522+
}
523+
return defaultVal;
524+
}
525+
526+
private SequenceStorageNodes.GetItemNode getGetRightItemNode() {
527+
if (getRightItemNode == null) {
528+
CompilerDirectives.transferToInterpreterAndInvalidate();
529+
getRightItemNode = insert(SequenceStorageNodes.GetItemNode.create());
530+
}
531+
return getRightItemNode;
532+
}
533+
}
534+
535+
abstract static class InternalLenNode extends PNodeWithContext {
536+
537+
public abstract long execute(VirtualFrame frame, PMMap self);
538+
539+
@Specialization(guards = "self.getLength() == 0")
540+
long doFull(VirtualFrame frame, PMMap self,
541+
@Cached("create()") BranchProfile profile) {
542+
try {
543+
return self.getChannel().size() - self.getOffset();
544+
} catch (IOException e) {
545+
profile.enter();
546+
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
547+
}
548+
}
549+
550+
@Specialization(guards = "self.getLength() > 0")
551+
long doWindow(@SuppressWarnings("unused") VirtualFrame frame, PMMap self) {
552+
return self.getLength();
553+
}
554+
555+
@Specialization
556+
long doGeneric(VirtualFrame frame, PMMap self) {
557+
if (self.getLength() == 0) {
558+
try {
559+
return self.getChannel().size() - self.getOffset();
560+
} catch (IOException e) {
561+
throw raiseOSError(frame, OSErrorEnum.EIO, e.getMessage());
562+
}
563+
}
564+
return self.getLength();
565+
}
566+
567+
public static InternalLenNode create() {
568+
return InternalLenNodeGen.create();
569+
}
570+
}
571+
381572
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/util/ChannelNodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ int readByte(ReadableByteChannel channel,
133133

134134
@TruffleBoundary(allowInlining = true)
135135
private static int get(ByteBuffer buf) {
136+
buf.flip();
136137
return buf.get();
137138
}
138139

0 commit comments

Comments
 (0)