Skip to content

Commit ecc779c

Browse files
committed
[GR-21355] Refactor bytes.(starts|ends)with to work in more generic cases
PullRequest: graalpython/826
2 parents 96458bc + db5c968 commit ecc779c

File tree

2 files changed

+198
-58
lines changed

2 files changed

+198
-58
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesBuiltins.java

Lines changed: 197 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import com.oracle.graal.python.builtins.objects.PNotImplemented;
6161
import com.oracle.graal.python.builtins.objects.bytes.BytesBuiltinsFactory.BytesLikeNoGeneralizationNodeGen;
6262
import com.oracle.graal.python.builtins.objects.common.IndexNodes.NormalizeIndexNode;
63+
import com.oracle.graal.python.builtins.objects.common.SequenceNodes.GetObjectArrayNode;
64+
import com.oracle.graal.python.builtins.objects.common.SequenceNodesFactory.GetObjectArrayNodeGen;
6365
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
6466
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GenNodeSupplier;
6567
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GeneralizationNode;
@@ -73,22 +75,21 @@
7375
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
7476
import com.oracle.graal.python.builtins.objects.type.LazyPythonClass;
7577
import com.oracle.graal.python.builtins.objects.type.TypeNodes;
78+
import com.oracle.graal.python.nodes.PGuards;
7679
import com.oracle.graal.python.nodes.PRaiseNode;
7780
import com.oracle.graal.python.nodes.SpecialMethodNames;
7881
import com.oracle.graal.python.nodes.argument.ReadArgumentNode;
7982
import com.oracle.graal.python.nodes.builtins.ListNodes.AppendNode;
8083
import com.oracle.graal.python.nodes.call.special.LookupAndCallUnaryNode;
81-
import com.oracle.graal.python.nodes.control.GetIteratorExpressionNode.GetIteratorNode;
82-
import com.oracle.graal.python.nodes.control.GetNextNode;
8384
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
8485
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
8586
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
87+
import com.oracle.graal.python.nodes.function.builtins.PythonQuaternaryBuiltinNode;
8688
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
8789
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
88-
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
90+
import com.oracle.graal.python.nodes.subscript.SliceLiteralNode.CastToSliceComponentNode;
8991
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
9092
import com.oracle.graal.python.nodes.util.CastToByteNode;
91-
import com.oracle.graal.python.runtime.exception.PException;
9293
import com.oracle.graal.python.runtime.exception.PythonErrorType;
9394
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
9495
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
@@ -106,6 +107,7 @@
106107
import com.oracle.truffle.api.dsl.Specialization;
107108
import com.oracle.truffle.api.dsl.TypeSystemReference;
108109
import com.oracle.truffle.api.frame.VirtualFrame;
110+
import com.oracle.truffle.api.interop.UnsupportedMessageException;
109111
import com.oracle.truffle.api.library.CachedLibrary;
110112
import com.oracle.truffle.api.nodes.Node;
111113
import com.oracle.truffle.api.profiles.BranchProfile;
@@ -564,80 +566,218 @@ PSequenceIterator contains(PIBytesLike self) {
564566
}
565567
}
566568

567-
@Builtin(name = "startswith", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
568-
@GenerateNodeFactory
569-
abstract static class StartsWithNode extends PythonBuiltinNode {
570-
@Child private SequenceStorageNodes.LenNode lenNode;
569+
abstract static class PrefixSuffixBaseNode extends PythonQuaternaryBuiltinNode {
571570

572-
@Specialization
573-
boolean startswith(VirtualFrame frame, PByteArray self, PTuple prefixes, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
574-
@Cached GetIteratorNode getIteratorNode,
575-
@Cached IsBuiltinClassProfile errorProfile,
576-
@Cached GetNextNode getNextNode,
577-
@Cached BytesNodes.FindNode findNode) {
578-
Object iterator = getIteratorNode.executeWith(frame, prefixes);
579-
while (true) {
580-
try {
581-
Object arrayObj = getNextNode.execute(frame, iterator);
582-
if (arrayObj instanceof PIBytesLike) {
583-
PIBytesLike array = (PIBytesLike) arrayObj;
584-
if (startswith(frame, self, array, start, end, findNode)) {
585-
return true;
586-
}
587-
} else {
588-
throw raise(PythonBuiltinClassType.TypeError, "a bytes-like object is required, not '%p'", arrayObj);
589-
}
590-
} catch (PException e) {
591-
e.expectStopIteration(errorProfile);
592-
return false;
571+
@Child private CastToSliceComponentNode castSliceComponentNode;
572+
@Child private GetObjectArrayNode getObjectArrayNode;
573+
574+
private static final String INVALID_RECEIVER = "Method requires a 'bytes' object, got '%p'";
575+
private static final String INVALID_FIRST_ARG = "first arg must be bytes or a tuple of bytes, not %p";
576+
private static final String INVALID_ELEMENT_TYPE = "a bytes-like object is required, not '%p'";
577+
578+
// common and specialized cases --------------------
579+
580+
@Specialization(guards = "!isPTuple(substr)", limit = "2")
581+
boolean doPrefixStartEnd(PIBytesLike self, Object substr, int start, int end,
582+
@CachedLibrary("self") PythonObjectLibrary lib,
583+
@CachedLibrary("substr") PythonObjectLibrary substrLib) {
584+
byte[] bytes = getBytes(lib, self);
585+
byte[] substrBytes = getBytes(substrLib, substr, INVALID_FIRST_ARG);
586+
int len = bytes.length;
587+
return doIt(bytes, substrBytes, adjustStart(start, len), adjustStart(end, len));
588+
}
589+
590+
@Specialization(guards = "!isPTuple(substr)", limit = "2")
591+
boolean doPrefixStart(PIBytesLike self, Object substr, int start, @SuppressWarnings("unused") PNone end,
592+
@CachedLibrary("self") PythonObjectLibrary lib,
593+
@CachedLibrary("substr") PythonObjectLibrary substrLib) {
594+
byte[] bytes = getBytes(lib, self);
595+
byte[] substrBytes = getBytes(substrLib, substr, INVALID_FIRST_ARG);
596+
int len = bytes.length;
597+
return doIt(bytes, substrBytes, adjustStart(start, len), len);
598+
}
599+
600+
@Specialization(guards = "!isPTuple(substr)", limit = "2")
601+
boolean doPrefix(PIBytesLike self, Object substr, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
602+
@CachedLibrary("self") PythonObjectLibrary lib,
603+
@CachedLibrary("substr") PythonObjectLibrary substrLib) {
604+
byte[] bytes = getBytes(lib, self);
605+
byte[] substrBytes = getBytes(substrLib, substr, INVALID_FIRST_ARG);
606+
return doIt(bytes, substrBytes, 0, bytes.length);
607+
}
608+
609+
@Specialization(limit = "2")
610+
boolean doTuplePrefixStartEnd(PIBytesLike self, PTuple substrs, int start, int end,
611+
@CachedLibrary("self") PythonObjectLibrary lib,
612+
@CachedLibrary(limit = "16") PythonObjectLibrary substrLib) {
613+
byte[] bytes = getBytes(lib, self);
614+
int len = bytes.length;
615+
return doIt(bytes, substrs, adjustStart(start, len), adjustStart(end, len), substrLib);
616+
}
617+
618+
@Specialization(limit = "2")
619+
boolean doTuplePrefixStart(PIBytesLike self, PTuple substrs, int start, @SuppressWarnings("unused") PNone end,
620+
@CachedLibrary("self") PythonObjectLibrary lib,
621+
@CachedLibrary(limit = "16") PythonObjectLibrary substrLib) {
622+
byte[] bytes = getBytes(lib, self);
623+
int len = bytes.length;
624+
return doIt(bytes, substrs, adjustStart(start, len), len, substrLib);
625+
}
626+
627+
@Specialization(limit = "2")
628+
boolean doTuplePrefix(PIBytesLike self, PTuple substrs, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
629+
@CachedLibrary("self") PythonObjectLibrary lib,
630+
@CachedLibrary(limit = "16") PythonObjectLibrary substrLib) {
631+
byte[] bytes = getBytes(lib, self);
632+
return doIt(bytes, substrs, 0, bytes.length, substrLib);
633+
}
634+
635+
// generic cases --------------------
636+
637+
@Specialization(guards = "!isPTuple(substr)", replaces = {"doPrefixStartEnd", "doPrefixStart", "doPrefix"}, limit = "200")
638+
boolean doPrefixGeneric(VirtualFrame frame, PIBytesLike self, Object substr, Object start, Object end,
639+
@CachedLibrary("self") PythonObjectLibrary lib,
640+
@CachedLibrary("substr") PythonObjectLibrary substrLib) {
641+
byte[] bytes = getBytes(lib, self);
642+
byte[] substrBytes = getBytes(substrLib, substr, INVALID_FIRST_ARG);
643+
int len = bytes.length;
644+
int istart = PGuards.isPNone(start) ? 0 : castSlicePart(frame, start);
645+
int iend = PGuards.isPNone(end) ? len : adjustEnd(castSlicePart(frame, end), len);
646+
return doIt(bytes, substrBytes, adjustStart(istart, len), adjustStart(iend, len));
647+
}
648+
649+
@Specialization(replaces = {"doTuplePrefixStartEnd", "doTuplePrefixStart", "doTuplePrefix"}, limit = "2")
650+
boolean doTuplePrefixGeneric(VirtualFrame frame, PIBytesLike self, PTuple substrs, Object start, Object end,
651+
@CachedLibrary("self") PythonObjectLibrary lib,
652+
@CachedLibrary(limit = "16") PythonObjectLibrary substrLib) {
653+
byte[] bytes = getBytes(lib, self);
654+
int len = bytes.length;
655+
int istart = PGuards.isPNone(start) ? 0 : castSlicePart(frame, start);
656+
int iend = PGuards.isPNone(end) ? len : adjustEnd(castSlicePart(frame, end), len);
657+
return doIt(bytes, substrs, adjustStart(istart, len), adjustStart(iend, len), substrLib);
658+
}
659+
660+
@Specialization(guards = "!isBytes(self)")
661+
boolean doGeneric(@SuppressWarnings("unused") Object self, @SuppressWarnings("unused") Object substr,
662+
@SuppressWarnings("unused") Object start, @SuppressWarnings("unused") Object end) {
663+
throw raise(TypeError, INVALID_RECEIVER, self);
664+
}
665+
666+
// the actual operation; will be overridden by subclasses
667+
protected boolean doIt(byte[] bytes, byte[] prefix, int start, int end) {
668+
CompilerDirectives.transferToInterpreter();
669+
throw new IllegalStateException("should not reach");
670+
}
671+
672+
private boolean doIt(byte[] self, PTuple substrs, int start, int stop, PythonObjectLibrary lib) {
673+
for (Object element : ensureGetObjectArrayNode().execute(substrs)) {
674+
byte[] bytes = getBytes(lib, element, INVALID_ELEMENT_TYPE);
675+
if (doIt(self, bytes, start, stop)) {
676+
return true;
593677
}
594678
}
679+
return false;
595680
}
596681

597-
@Specialization
598-
boolean startswith(VirtualFrame frame, PIBytesLike self, PIBytesLike prefix, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
599-
@Cached("create()") BytesNodes.FindNode findNode) {
600-
return findNode.execute(frame, self, prefix, 0, getLength(self.getSequenceStorage())) == 0;
682+
private byte[] getBytes(PythonObjectLibrary lib, Object object) {
683+
try {
684+
return lib.getBufferBytes(object);
685+
} catch (UnsupportedMessageException e) {
686+
CompilerDirectives.transferToInterpreter();
687+
throw new IllegalStateException(e);
688+
}
601689
}
602690

603-
@Specialization
604-
boolean startswith(VirtualFrame frame, PIBytesLike self, PIBytesLike prefix, int start, @SuppressWarnings("unused") PNone end,
605-
@Cached("create()") BytesNodes.FindNode findNode) {
606-
return findNode.execute(frame, self, prefix, start, getLength(self.getSequenceStorage())) == start;
691+
private byte[] getBytes(PythonObjectLibrary lib, Object object, String errorMessage) {
692+
if (!lib.isBuffer(object)) {
693+
throw raise(TypeError, errorMessage, object);
694+
}
695+
return getBytes(lib, object);
607696
}
608697

609-
@Specialization
610-
boolean startswith(VirtualFrame frame, PIBytesLike self, PIBytesLike prefix, int start, int end,
611-
@Cached("create()") BytesNodes.FindNode findNode) {
612-
return findNode.execute(frame, self, prefix, start, end) == start;
698+
// helper methods --------------------
699+
700+
// for semantics, see macro 'ADJUST_INDICES' in CPython's 'unicodeobject.c'
701+
static int adjustStart(int start, int length) {
702+
if (start < 0) {
703+
int adjusted = start + length;
704+
return adjusted < 0 ? 0 : adjusted;
705+
}
706+
return start;
613707
}
614708

615-
private int getLength(SequenceStorage s) {
616-
if (lenNode == null) {
709+
// for semantics, see macro 'ADJUST_INDICES' in CPython's 'unicodeobject.c'
710+
static int adjustEnd(int end, int length) {
711+
if (end > length) {
712+
return length;
713+
}
714+
return adjustStart(end, length);
715+
}
716+
717+
private int castSlicePart(VirtualFrame frame, Object idx) {
718+
if (castSliceComponentNode == null) {
617719
CompilerDirectives.transferToInterpreterAndInvalidate();
618-
lenNode = insert(SequenceStorageNodes.LenNode.create());
720+
// None should map to 0, overflow to the maximum integer
721+
castSliceComponentNode = insert(CastToSliceComponentNode.create(0, Integer.MAX_VALUE));
619722
}
620-
return lenNode.execute(s);
723+
return castSliceComponentNode.execute(frame, idx);
724+
}
725+
726+
private GetObjectArrayNode ensureGetObjectArrayNode() {
727+
if (getObjectArrayNode == null) {
728+
CompilerDirectives.transferToInterpreterAndInvalidate();
729+
getObjectArrayNode = insert(GetObjectArrayNodeGen.create());
730+
}
731+
return getObjectArrayNode;
621732
}
622733
}
623734

624-
@Builtin(name = "endswith", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
735+
// bytes.startswith(prefix[, start[, end]])
736+
// bytearray.startswith(prefix[, start[, end]])
737+
@Builtin(name = "startswith", minNumOfPositionalArgs = 2, parameterNames = {"self", "prefix", "start", "end"})
625738
@GenerateNodeFactory
626-
abstract static class EndsWithNode extends PythonBuiltinNode {
627-
@Child private SequenceStorageNodes.LenNode lenNode;
739+
public abstract static class StartsWithNode extends PrefixSuffixBaseNode {
628740

629-
@Specialization
630-
boolean endswith(VirtualFrame frame, PIBytesLike self, PIBytesLike suffix, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
631-
@Cached("create()") BytesNodes.FindNode findNode) {
632-
return findNode.execute(frame, self, suffix, getLength(self.getSequenceStorage()) - getLength(suffix.getSequenceStorage()), getLength(self.getSequenceStorage())) != -1;
741+
@Override
742+
protected boolean doIt(byte[] bytes, byte[] prefix, int start, int end) {
743+
// start and end must be normalized indices for 'bytes'
744+
assert start >= 0;
745+
assert end >= 0 && end <= bytes.length;
746+
747+
if (end - start < prefix.length) {
748+
return false;
749+
}
750+
for (int i = 0; i < prefix.length; i++) {
751+
if (bytes[start + i] != prefix[i]) {
752+
return false;
753+
}
754+
}
755+
return true;
633756
}
757+
}
634758

635-
private int getLength(SequenceStorage s) {
636-
if (lenNode == null) {
637-
CompilerDirectives.transferToInterpreterAndInvalidate();
638-
lenNode = insert(SequenceStorageNodes.LenNode.create());
759+
// bytes.endswith(suffix[, start[, end]])
760+
// bytearray.endswith(suffix[, start[, end]])
761+
@Builtin(name = "endswith", minNumOfPositionalArgs = 2, parameterNames = {"self", "suffix", "start", "end"})
762+
@GenerateNodeFactory
763+
public abstract static class EndsWithNode extends PrefixSuffixBaseNode {
764+
765+
@Override
766+
protected boolean doIt(byte[] bytes, byte[] suffix, int start, int end) {
767+
// start and end must be normalized indices for 'bytes'
768+
assert start >= 0;
769+
assert end >= 0 && end <= bytes.length;
770+
771+
int suffixLen = suffix.length;
772+
if (end - start < suffixLen) {
773+
return false;
639774
}
640-
return lenNode.execute(s);
775+
for (int i = 0; i < suffix.length; i++) {
776+
if (bytes[end - suffixLen + i] != suffix[i]) {
777+
return false;
778+
}
779+
}
780+
return true;
641781
}
642782
}
643783

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/object/DefaultPythonStringExports.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static LazyPythonClass getLazyPythonClass(@SuppressWarnings("unused") String val
7171

7272
@ExportMessage
7373
static boolean isBuffer(@SuppressWarnings("unused") String str) {
74-
return true;
74+
return false;
7575
}
7676

7777
@ExportMessage

0 commit comments

Comments
 (0)