Skip to content

Commit 585e5e4

Browse files
committed
Refactor bytes.(starts|ends)with to work in more generic cases
Based on the version for strings. Fixes #114
1 parent 9fb8a9a commit 585e5e4

File tree

1 file changed

+194
-55
lines changed
  • graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes

1 file changed

+194
-55
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesBuiltins.java

Lines changed: 194 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import com.oracle.graal.python.builtins.objects.PNotImplemented;
6161
import com.oracle.graal.python.builtins.objects.bytes.BytesBuiltinsFactory.BytesLikeNoGeneralizationNodeGen;
6262
import com.oracle.graal.python.builtins.objects.common.IndexNodes.NormalizeIndexNode;
63+
import com.oracle.graal.python.builtins.objects.common.SequenceNodes.GetObjectArrayNode;
64+
import com.oracle.graal.python.builtins.objects.common.SequenceNodesFactory.GetObjectArrayNodeGen;
6365
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
6466
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GenNodeSupplier;
6567
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GeneralizationNode;
@@ -73,22 +75,21 @@
7375
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
7476
import com.oracle.graal.python.builtins.objects.type.LazyPythonClass;
7577
import com.oracle.graal.python.builtins.objects.type.TypeNodes;
78+
import com.oracle.graal.python.nodes.PGuards;
7679
import com.oracle.graal.python.nodes.PRaiseNode;
7780
import com.oracle.graal.python.nodes.SpecialMethodNames;
7881
import com.oracle.graal.python.nodes.argument.ReadArgumentNode;
7982
import com.oracle.graal.python.nodes.builtins.ListNodes.AppendNode;
8083
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
81-
import com.oracle.graal.python.nodes.control.GetIteratorExpressionNode.GetIteratorNode;
82-
import com.oracle.graal.python.nodes.control.GetNextNode;
8384
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
8485
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
8586
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
87+
import com.oracle.graal.python.nodes.function.builtins.PythonQuaternaryBuiltinNode;
8688
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
8789
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
88-
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
90+
import com.oracle.graal.python.nodes.subscript.SliceLiteralNode.CastToSliceComponentNode;
8991
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
9092
import com.oracle.graal.python.nodes.util.CastToByteNode;
91-
import com.oracle.graal.python.runtime.exception.PException;
9293
import com.oracle.graal.python.runtime.exception.PythonErrorType;
9394
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
9495
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
@@ -564,80 +565,218 @@ PSequenceIterator contains(PIBytesLike self) {
564565
}
565566
}
566567

567-
@Builtin(name = "startswith", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
568-
@GenerateNodeFactory
569-
abstract static class StartsWithNode extends PythonBuiltinNode {
570-
@Child private SequenceStorageNodes.LenNode lenNode;
568+
abstract static class PrefixSuffixBaseNode extends PythonQuaternaryBuiltinNode {
569+
570+
@Child private CastToSliceComponentNode castSliceComponentNode;
571+
@Child private GetObjectArrayNode getObjectArrayNode;
572+
573+
// common and specialized cases --------------------
571574

572575
@Specialization
573-
boolean startswith(VirtualFrame frame, PByteArray self, PTuple prefixes, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
574-
@Cached GetIteratorNode getIteratorNode,
575-
@Cached IsBuiltinClassProfile errorProfile,
576-
@Cached GetNextNode getNextNode,
577-
@Cached BytesNodes.FindNode findNode) {
578-
Object iterator = getIteratorNode.executeWith(frame, prefixes);
579-
while (true) {
580-
try {
581-
Object arrayObj = getNextNode.execute(frame, iterator);
582-
if (arrayObj instanceof PIBytesLike) {
583-
PIBytesLike array = (PIBytesLike) arrayObj;
584-
if (startswith(frame, self, array, start, end, findNode)) {
585-
return true;
586-
}
587-
} else {
588-
throw raise(PythonBuiltinClassType.TypeError, "a bytes-like object is required, not '%p'", arrayObj);
589-
}
590-
} catch (PException e) {
591-
e.expectStopIteration(errorProfile);
592-
return false;
593-
}
594-
}
576+
boolean doPrefixStartEnd(PIBytesLike self, PIBytesLike substr, int start, int end,
577+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
578+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
579+
byte[] substrBytes = toByteArrayNode.execute(substr.getSequenceStorage());
580+
int len = bytes.length;
581+
return doIt(bytes, substrBytes, adjustStart(start, len), adjustStart(end, len));
595582
}
596583

597584
@Specialization
598-
boolean startswith(VirtualFrame frame, PIBytesLike self, PIBytesLike prefix, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
599-
@Cached("create()") BytesNodes.FindNode findNode) {
600-
return findNode.execute(frame, self, prefix, 0, getLength(self.getSequenceStorage())) == 0;
585+
boolean doPrefixStart(PIBytesLike self, PIBytesLike substr, int start, @SuppressWarnings("unused") PNone end,
586+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
587+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
588+
byte[] substrBytes = toByteArrayNode.execute(substr.getSequenceStorage());
589+
int len = bytes.length;
590+
return doIt(bytes, substrBytes, adjustStart(start, len), len);
601591
}
602592

603593
@Specialization
604-
boolean startswith(VirtualFrame frame, PIBytesLike self, PIBytesLike prefix, int start, @SuppressWarnings("unused") PNone end,
605-
@Cached("create()") BytesNodes.FindNode findNode) {
606-
return findNode.execute(frame, self, prefix, start, getLength(self.getSequenceStorage())) == start;
594+
boolean doPrefix(PIBytesLike self, PIBytesLike substr, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
595+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
596+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
597+
byte[] substrBytes = toByteArrayNode.execute(substr.getSequenceStorage());
598+
return doIt(bytes, substrBytes, 0, bytes.length);
607599
}
608600

609601
@Specialization
610-
boolean startswith(VirtualFrame frame, PIBytesLike self, PIBytesLike prefix, int start, int end,
611-
@Cached("create()") BytesNodes.FindNode findNode) {
612-
return findNode.execute(frame, self, prefix, start, end) == start;
602+
boolean doTuplePrefixStartEnd(PIBytesLike self, PTuple substrs, int start, int end,
603+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
604+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
605+
int len = bytes.length;
606+
return doIt(bytes, substrs, adjustStart(start, len), adjustStart(end, len), toByteArrayNode);
613607
}
614608

615-
private int getLength(SequenceStorage s) {
616-
if (lenNode == null) {
609+
@Specialization
610+
boolean doTuplePrefixStart(PIBytesLike self, PTuple substrs, int start, @SuppressWarnings("unused") PNone end,
611+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
612+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
613+
int len = bytes.length;
614+
return doIt(bytes, substrs, adjustStart(start, len), len, toByteArrayNode);
615+
}
616+
617+
@Specialization
618+
boolean doTuplePrefix(PIBytesLike self, PTuple substrs, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
619+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
620+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
621+
return doIt(bytes, substrs, 0, bytes.length, toByteArrayNode);
622+
}
623+
624+
// generic cases --------------------
625+
626+
@Specialization(replaces = {"doPrefixStartEnd", "doPrefixStart", "doPrefix"})
627+
boolean doPrefixGeneric(VirtualFrame frame, PIBytesLike self, PIBytesLike substr, Object start, Object end,
628+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
629+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
630+
byte[] substrBytes = toByteArrayNode.execute(substr.getSequenceStorage());
631+
int len = bytes.length;
632+
int istart = PGuards.isPNone(start) ? 0 : castSlicePart(frame, start);
633+
int iend = PGuards.isPNone(end) ? len : adjustEnd(castSlicePart(frame, end), len);
634+
return doIt(bytes, substrBytes, adjustStart(istart, len), adjustStart(iend, len));
635+
}
636+
637+
@Specialization(replaces = {"doTuplePrefixStartEnd", "doTuplePrefixStart", "doTuplePrefix"})
638+
boolean doTuplePrefixGeneric(VirtualFrame frame, PIBytesLike self, PTuple substrs, Object start, Object end,
639+
@Shared("toByteArrayNode") @Cached SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
640+
byte[] bytes = toByteArrayNode.execute(self.getSequenceStorage());
641+
int len = bytes.length;
642+
int istart = PGuards.isPNone(start) ? 0 : castSlicePart(frame, start);
643+
int iend = PGuards.isPNone(end) ? len : adjustEnd(castSlicePart(frame, end), len);
644+
return doIt(bytes, substrs, adjustStart(istart, len), adjustStart(iend, len), toByteArrayNode);
645+
}
646+
647+
@Specialization(guards = {"!isBytes(substr)", "!isPTuple(substr)"})
648+
boolean doGeneric(@SuppressWarnings("unused") PIBytesLike self, Object substr, @SuppressWarnings("unused") Object start,
649+
@SuppressWarnings("unused") Object end) {
650+
throw raise(TypeError, "first arg must be bytes or a tuple of bytes, not %p", substr);
651+
}
652+
653+
@Specialization(guards = "!isBytes(self)")
654+
boolean doGeneric(@SuppressWarnings("unused") Object self, @SuppressWarnings("unused") Object substr,
655+
@SuppressWarnings("unused") Object start, @SuppressWarnings("unused") Object end) {
656+
throw raise(TypeError, "Method requires a 'bytes' object, got '%p'", self);
657+
}
658+
659+
// the actual operation; will be overridden by subclasses
660+
protected boolean doIt(byte[] bytes, byte[] prefix, int start, int end) {
661+
CompilerDirectives.transferToInterpreter();
662+
throw new IllegalStateException("should not reach");
663+
}
664+
665+
private boolean doIt(byte[] self, PTuple substrs, int start, int stop, SequenceStorageNodes.ToByteArrayNode toByteArrayNode) {
666+
for (Object element : ensureGetObjectArrayNode().execute(substrs)) {
667+
if (element instanceof PIBytesLike) {
668+
if (doIt(self, toByteArrayNode.execute(((PIBytesLike) element).getSequenceStorage()), start, stop)) {
669+
return true;
670+
}
671+
} else {
672+
throw raise(TypeError, getErrorMessage(), element);
673+
}
674+
}
675+
return false;
676+
}
677+
678+
protected String getErrorMessage() {
679+
CompilerDirectives.transferToInterpreter();
680+
throw new IllegalStateException("should not reach");
681+
}
682+
683+
// helper methods --------------------
684+
685+
// for semantics, see macro 'ADJUST_INDICES' in CPython's 'unicodeobject.c'
686+
static int adjustStart(int start, int length) {
687+
if (start < 0) {
688+
int adjusted = start + length;
689+
return adjusted < 0 ? 0 : adjusted;
690+
}
691+
return start;
692+
}
693+
694+
// for semantics, see macro 'ADJUST_INDICES' in CPython's 'unicodeobject.c'
695+
static int adjustEnd(int end, int length) {
696+
if (end > length) {
697+
return length;
698+
}
699+
return adjustStart(end, length);
700+
}
701+
702+
private int castSlicePart(VirtualFrame frame, Object idx) {
703+
if (castSliceComponentNode == null) {
617704
CompilerDirectives.transferToInterpreterAndInvalidate();
618-
lenNode = insert(SequenceStorageNodes.LenNode.create());
705+
// None should map to 0, overflow to the maximum integer
706+
castSliceComponentNode = insert(CastToSliceComponentNode.create(0, Integer.MAX_VALUE));
619707
}
620-
return lenNode.execute(s);
708+
return castSliceComponentNode.execute(frame, idx);
709+
}
710+
711+
private GetObjectArrayNode ensureGetObjectArrayNode() {
712+
if (getObjectArrayNode == null) {
713+
CompilerDirectives.transferToInterpreterAndInvalidate();
714+
getObjectArrayNode = insert(GetObjectArrayNodeGen.create());
715+
}
716+
return getObjectArrayNode;
621717
}
622718
}
623719

624-
@Builtin(name = "endswith", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
720+
// bytes.startswith(prefix[, start[, end]])
721+
// bytearray.startswith(prefix[, start[, end]])
722+
@Builtin(name = "startswith", minNumOfPositionalArgs = 2, parameterNames = {"self", "prefix", "start", "end"})
625723
@GenerateNodeFactory
626-
abstract static class EndsWithNode extends PythonBuiltinNode {
627-
@Child private SequenceStorageNodes.LenNode lenNode;
724+
public abstract static class StartsWithNode extends PrefixSuffixBaseNode {
628725

629-
@Specialization
630-
boolean endswith(VirtualFrame frame, PIBytesLike self, PIBytesLike suffix, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
631-
@Cached("create()") BytesNodes.FindNode findNode) {
632-
return findNode.execute(frame, self, suffix, getLength(self.getSequenceStorage()) - getLength(suffix.getSequenceStorage()), getLength(self.getSequenceStorage())) != -1;
726+
private static final String INVALID_ELEMENT_TYPE = "a bytes-like object is required, not '%p'";
727+
728+
@Override
729+
protected boolean doIt(byte[] bytes, byte[] prefix, int start, int end) {
730+
// start and end must be normalized indices for 'bytes'
731+
assert start >= 0;
732+
assert end >= 0 && end <= bytes.length;
733+
734+
if (end - start < prefix.length) {
735+
return false;
736+
}
737+
for (int i = 0; i < prefix.length; i++) {
738+
if (bytes[start + i] != prefix[i]) {
739+
return false;
740+
}
741+
}
742+
return true;
633743
}
634744

635-
private int getLength(SequenceStorage s) {
636-
if (lenNode == null) {
637-
CompilerDirectives.transferToInterpreterAndInvalidate();
638-
lenNode = insert(SequenceStorageNodes.LenNode.create());
745+
@Override
746+
protected String getErrorMessage() {
747+
return INVALID_ELEMENT_TYPE;
748+
}
749+
}
750+
751+
// bytes.endswith(suffix[, start[, end]])
752+
// bytearray.endswith(suffix[, start[, end]])
753+
@Builtin(name = "endswith", minNumOfPositionalArgs = 2, parameterNames = {"self", "suffix", "start", "end"})
754+
@GenerateNodeFactory
755+
public abstract static class EndsWithNode extends PrefixSuffixBaseNode {
756+
757+
private static final String INVALID_ELEMENT_TYPE = "a bytes-like object is required, not '%p'";
758+
759+
@Override
760+
protected boolean doIt(byte[] bytes, byte[] suffix, int start, int end) {
761+
// start and end must be normalized indices for 'bytes'
762+
assert start >= 0;
763+
assert end >= 0 && end <= bytes.length;
764+
765+
int suffixLen = suffix.length;
766+
if (end - start < suffixLen) {
767+
return false;
639768
}
640-
return lenNode.execute(s);
769+
for (int i = 0; i < suffix.length; i++) {
770+
if (bytes[end - suffixLen + i] != suffix[i]) {
771+
return false;
772+
}
773+
}
774+
return true;
775+
}
776+
777+
@Override
778+
protected String getErrorMessage() {
779+
return INVALID_ELEMENT_TYPE;
641780
}
642781
}
643782

0 commit comments

Comments
 (0)