Skip to content

Commit 1678234

Browse files
committed
cleanup unboxed specializations for bytes.find
1 parent b0906cf commit 1678234

File tree

2 files changed

+91
-31
lines changed

2 files changed

+91
-31
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesBuiltins.java

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
9999
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
100100
import com.oracle.truffle.api.dsl.Cached;
101+
import com.oracle.truffle.api.dsl.Cached.Shared;
101102
import com.oracle.truffle.api.dsl.Fallback;
102103
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
103104
import com.oracle.truffle.api.dsl.GenerateUncached;
@@ -691,48 +692,37 @@ int count(VirtualFrame frame, PIBytesLike byteArray, Object arg,
691692
@Builtin(name = "find", minNumOfPositionalArgs = 2, maxNumOfPositionalArgs = 4)
692693
@GenerateNodeFactory
693694
abstract static class FindNode extends PythonBuiltinNode {
694-
@Child private BytesNodes.FindNode findNode;
695-
@Child private SequenceStorageNodes.LenNode lenNode;
696-
697695
@Specialization
698-
int find(VirtualFrame frame, PIBytesLike self, Object sub, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end) {
699-
return find(frame, self, sub, 0, getLength(self.getSequenceStorage()));
696+
int find(VirtualFrame frame, PIBytesLike self, Object sub, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end,
697+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode,
698+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
699+
return find(frame, self, sub, 0, lenNode.execute(self.getSequenceStorage()), findNode);
700700
}
701701

702702
@Specialization
703-
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, @SuppressWarnings("unused") PNone end) {
704-
return find(frame, self, sub, start, getLength(self.getSequenceStorage()));
703+
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, @SuppressWarnings("unused") PNone end,
704+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode,
705+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
706+
return find(frame, self, sub, start, lenNode.execute(self.getSequenceStorage()), findNode);
705707
}
706708

707709
@Specialization
708-
int find (VirtualFrame frame, PIBytesLike self, Object sub, Number start, @SuppressWarnings("unused") PNone end) {
709-
return find(frame, self, sub, start, getLength(self.getSequenceStorage()));
710+
int find (VirtualFrame frame, PIBytesLike self, Object sub, Object start, @SuppressWarnings("unused") PNone end,
711+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode,
712+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
713+
return find(frame, self, sub, start, lenNode.execute(self.getSequenceStorage()), findNode);
710714
}
711715

712716
@Specialization
713-
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, int ending) {
714-
return getFindNode().execute(frame, self, sub, start, ending);
717+
int find(VirtualFrame frame, PIBytesLike self, Object sub, int start, int ending,
718+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
719+
return findNode.execute(frame, self, sub, start, ending);
715720
}
716721

717722
@Specialization
718-
int find(VirtualFrame frame, PIBytesLike self, Object sub, Number start, Number ending) {
719-
return getFindNode().execute(frame, self, sub, start, ending);
720-
}
721-
722-
private BytesNodes.FindNode getFindNode() {
723-
if (findNode == null) {
724-
CompilerDirectives.transferToInterpreterAndInvalidate();
725-
findNode = insert(BytesNodes.FindNode.create());
726-
}
727-
return findNode;
728-
}
729-
730-
private int getLength(SequenceStorage s) {
731-
if (lenNode == null) {
732-
CompilerDirectives.transferToInterpreterAndInvalidate();
733-
lenNode = insert(SequenceStorageNodes.LenNode.create());
734-
}
735-
return lenNode.execute(s);
723+
int find(VirtualFrame frame, PIBytesLike self, Object sub, Object start, Object ending,
724+
@Shared("findNode") @Cached BytesNodes.FindNode findNode) {
725+
return findNode.execute(frame, self, sub, start, ending);
736726
}
737727
}
738728

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesNodes.java

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,40 @@ public abstract static class FindNode extends PNodeWithContext {
216216

217217
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, Object sub, Object starting, Object ending);
218218

219+
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, int sub, Object starting, Object ending);
220+
221+
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, int sub, int starting, Object ending);
222+
223+
public abstract int execute(VirtualFrame frame, PIBytesLike bytes, int sub, int starting, int ending);
224+
225+
@Specialization
226+
int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, int starting, int ending) {
227+
SequenceStorage haystack = primary.getSequenceStorage();
228+
int len1 = haystack.length();
229+
230+
SequenceStorage needle = sub.getSequenceStorage();
231+
int len2 = needle.length();
232+
233+
int start = getNormalizeIndexNode().execute(starting, len1);
234+
int end = getNormalizeIndexNode().execute(ending, len1);
235+
236+
return findSubSequence(frame, haystack, len1, needle, len2, start, end);
237+
}
238+
239+
@Specialization
240+
int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, int starting, Object ending) {
241+
SequenceStorage haystack = primary.getSequenceStorage();
242+
int len1 = haystack.length();
243+
244+
SequenceStorage needle = sub.getSequenceStorage();
245+
int len2 = needle.length();
246+
247+
int start = getNormalizeIndexNode().execute(starting, len1);
248+
int end = getNormalizeIndexNode().execute(ending, len1);
249+
250+
return findSubSequence(frame, haystack, len1, needle, len2, start, end);
251+
}
252+
219253
@Specialization
220254
int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, Object starting, Object ending) {
221255
SequenceStorage haystack = primary.getSequenceStorage();
@@ -227,6 +261,11 @@ int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, Object starti
227261
int start = getNormalizeIndexNode().execute(starting, len1);
228262
int end = getNormalizeIndexNode().execute(ending, len1);
229263

264+
return findSubSequence(frame, haystack, len1, needle, len2, start, end);
265+
}
266+
267+
private int findSubSequence(VirtualFrame frame, SequenceStorage haystack, int len1, SequenceStorage needle, int len2, int start, int endInput) {
268+
int end = endInput;
230269
if (start >= len1 || len1 < len2) {
231270
return -1;
232271
} else if (end > len1) {
@@ -253,16 +292,47 @@ int find(VirtualFrame frame, PIBytesLike primary, PIBytesLike sub, Object starti
253292
}
254293

255294
@Specialization
256-
int find(VirtualFrame frame, PIBytesLike primary, int sub, Object starting, @SuppressWarnings("unused") Object ending) {
295+
int find(VirtualFrame frame, PIBytesLike primary, int sub, int starting, int ending) {
296+
SequenceStorage haystack = primary.getSequenceStorage();
297+
int len1 = haystack.length();
298+
299+
int start = getNormalizeIndexNode().execute(starting, len1);
300+
int end = getNormalizeIndexNode().execute(ending, len1);
301+
302+
return findElement(frame, sub, haystack, len1, start, end);
303+
}
304+
305+
@Specialization
306+
int find(VirtualFrame frame, PIBytesLike primary, int sub, int starting, Object ending) {
307+
SequenceStorage haystack = primary.getSequenceStorage();
308+
int len1 = haystack.length();
309+
310+
int start = getNormalizeIndexNode().execute(starting, len1);
311+
int end = getNormalizeIndexNode().execute(ending, len1);
312+
313+
return findElement(frame, sub, haystack, len1, start, end);
314+
}
315+
316+
@Specialization
317+
int find(VirtualFrame frame, PIBytesLike primary, int sub, Object starting, Object ending) {
257318
SequenceStorage haystack = primary.getSequenceStorage();
258319
int len1 = haystack.length();
259320

260321
int start = getNormalizeIndexNode().execute(starting, len1);
322+
int end = getNormalizeIndexNode().execute(ending, len1);
323+
324+
return findElement(frame, sub, haystack, len1, start, end);
325+
}
326+
327+
private int findElement(VirtualFrame frame, int sub, SequenceStorage haystack, int len1, int start, int endInput) {
328+
int end = endInput;
261329
if (start >= len1) {
262330
return -1;
331+
} else if (end > len1) {
332+
end = len1;
263333
}
264334

265-
for (int i = start; i < len1; i++) {
335+
for (int i = start; i < end; i++) {
266336
int hb = getGetLeftItemNode().executeInt(frame, haystack, i);
267337
if (hb == sub) {
268338
return i;

0 commit comments

Comments
 (0)