Skip to content

Commit c839cf4

Browse files
committed
[GR-20746] Add boxed specializations of BytesBuiltin.FindNode
PullRequest: graalpython/788
2 parents 5702630 + 44d9701 commit c839cf4

File tree

3 files changed

+97
-23
lines changed

3 files changed

+97
-23
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_bytes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,10 @@ def test_find():
559559
else:
560560
assert False, "should not reach here"
561561

562+
class SubInt(int):
563+
pass
564+
assert ba.find(i, SubInt(6)) == 7
565+
562566

563567
def test_same_id():
564568
empty_ids = set([id(bytes()) for i in range(100)])

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesBuiltins.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
9999
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
100100
import com.oracle.truffle.api.dsl.Cached;
101+
import com.oracle.truffle.api.dsl.Cached.Shared;
101102
import com.oracle.truffle.api.dsl.Fallback;
102103
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
103104
import com.oracle.truffle.api.dsl.GenerateUncached;
@@ -691,38 +692,37 @@ int count(VirtualFrame frame, PIBytesLike byteArray, Object arg,
691692
@Builtin(name = "find", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
692693
@GenerateNodeFactory
693694
abstract static class FindNode extends PythonBuiltinNode {
694-
@Child private BytesNodes.FindNode findNode;
695-
@Child private SequenceStorageNodes.LenNode lenNode;
696-
697695
@Specialization
698-
int find(VirtualFrame frame, PIBytesLike self, Object sub, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end) {
699-
return find(frame, self, sub, 0, getLength(self.getSequenceStorage()));
696+
int find(VirtualFrame frame, PIBytesLike self, Object sub, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
697+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode,
698+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
699+
return find(frame, self, sub, 0, lenNode.execute(self.getSequenceStorage()), findNode);
700700
}
701701

702702
@Specialization
703-
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, @SuppressWarnings("unused") PNone end) {
704-
return find(frame, self, sub, start, getLength(self.getSequenceStorage()));
703+
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, @SuppressWarnings("unused") PNone end,
704+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode,
705+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
706+
return find(frame, self, sub, start, lenNode.execute(self.getSequenceStorage()), findNode);
705707
}
706708

707709
@Specialization
708-
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, int ending) {
709-
return getFindNode().execute(frame, self, sub, start, ending);
710+
int find(VirtualFrame frame, PIBytesLike self, Object sub, Object start, @SuppressWarnings("unused") PNone end,
711+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode,
712+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
713+
return find(frame, self, sub, start, lenNode.execute(self.getSequenceStorage()), findNode);
710714
}
711715

712-
private BytesNodes.FindNode getFindNode() {
713-
if (findNode == null) {
714-
CompilerDirectives.transferToInterpreterAndInvalidate();
715-
findNode = insert(BytesNodes.FindNode.create());
716-
}
717-
return findNode;
716+
@Specialization
717+
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, int ending,
718+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
719+
return findNode.execute(frame, self, sub, start, ending);
718720
}
719721

720-
private int getLength(SequenceStorage s) {
721-
if (lenNode == null) {
722-
CompilerDirectives.transferToInterpreterAndInvalidate();
723-
lenNode = insert(SequenceStorageNodes.LenNode.create());
724-
}
725-
return lenNode.execute(s);
722+
@Specialization
723+
int find(VirtualFrame frame, PIBytesLike self, Object sub, Object start, Object ending,
724+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
725+
return findNode.execute(frame, self, sub, start, ending);
726726
}
727727
}
728728

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesNodes.java

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,40 @@ public abstract static class FindNode extends PNodeWithContext {
216216

217217
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, Object sub, Object starting, Object ending);
218218

219+
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, int sub, Object starting, Object ending);
220+
221+
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, int sub, int starting, Object ending);
222+
223+
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, int sub, int starting, int ending);
224+
225+
@Specialization
226+
int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, int starting, int ending) {
227+
SequenceStorage haystack = primary.getSequenceStorage();
228+
int len1 = haystack.length();
229+
230+
SequenceStorage needle = sub.getSequenceStorage();
231+
int len2 = needle.length();
232+
233+
int start = getNormalizeIndexNode().execute(starting, len1);
234+
int end = getNormalizeIndexNode().execute(ending, len1);
235+
236+
return findSubSequence(frame, haystack, len1, needle, len2, start, end);
237+
}
238+
239+
@Specialization
240+
int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, int starting, Object ending) {
241+
SequenceStorage haystack = primary.getSequenceStorage();
242+
int len1 = haystack.length();
243+
244+
SequenceStorage needle = sub.getSequenceStorage();
245+
int len2 = needle.length();
246+
247+
int start = getNormalizeIndexNode().execute(starting, len1);
248+
int end = getNormalizeIndexNode().execute(ending, len1);
249+
250+
return findSubSequence(frame, haystack, len1, needle, len2, start, end);
251+
}
252+
219253
@Specialization
220254
int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, Object starting, Object ending) {
221255
SequenceStorage haystack = primary.getSequenceStorage();
@@ -227,6 +261,11 @@ int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, Object starti
227261
int start = getNormalizeIndexNode().execute(starting, len1);
228262
int end = getNormalizeIndexNode().execute(ending, len1);
229263

264+
return findSubSequence(frame, haystack, len1, needle, len2, start, end);
265+
}
266+
267+
private int findSubSequence(VirtualFrame frame, SequenceStorage haystack, int len1, SequenceStorage needle, int len2, int start, int endInput) {
268+
int end = endInput;
230269
if (start >= len1 || len1 < len2) {
231270
return -1;
232271
} else if (end > len1) {
@@ -253,16 +292,47 @@ int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, Object starti
253292
}
254293

255294
@Specialization
256-
int find(VirtualFrame frame, PIBytesLike primary, int sub, Object starting, @SuppressWarnings("unused") Object ending) {
295+
int find(VirtualFrame frame, PIBytesLike primary, int sub, int starting, int ending) {
296+
SequenceStorage haystack = primary.getSequenceStorage();
297+
int len1 = haystack.length();
298+
299+
int start = getNormalizeIndexNode().execute(starting, len1);
300+
int end = getNormalizeIndexNode().execute(ending, len1);
301+
302+
return findElement(frame, sub, haystack, len1, start, end);
303+
}
304+
305+
@Specialization
306+
int find(VirtualFrame frame, PIBytesLike primary, int sub, int starting, Object ending) {
307+
SequenceStorage haystack = primary.getSequenceStorage();
308+
int len1 = haystack.length();
309+
310+
int start = getNormalizeIndexNode().execute(starting, len1);
311+
int end = getNormalizeIndexNode().execute(ending, len1);
312+
313+
return findElement(frame, sub, haystack, len1, start, end);
314+
}
315+
316+
@Specialization
317+
int find(VirtualFrame frame, PIBytesLike primary, int sub, Object starting, Object ending) {
257318
SequenceStorage haystack = primary.getSequenceStorage();
258319
int len1 = haystack.length();
259320

260321
int start = getNormalizeIndexNode().execute(starting, len1);
322+
int end = getNormalizeIndexNode().execute(ending, len1);
323+
324+
return findElement(frame, sub, haystack, len1, start, end);
325+
}
326+
327+
private int findElement(VirtualFrame frame, int sub, SequenceStorage haystack, int len1, int start, int endInput) {
328+
int end = endInput;
261329
if (start >= len1) {
262330
return -1;
331+
} else if (end > len1) {
332+
end = len1;
263333
}
264334

265-
for (int i = start; i < len1; i++) {
335+
for (int i = start; i < end; i++) {
266336
int hb = getGetLeftItemNode().executeInt(frame, haystack, i);
267337
if (hb == sub) {
268338
return i;

0 commit comments

Comments
 (0)