Skip to content

Commit 2607224

Browse files
committed
Fix find/index with empty substring
Fixes #324
1 parent 213304c commit 2607224

File tree

2 files changed

+53
-50
lines changed

2 files changed

+53
-50
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_string.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2022, Oracle and/or its affiliates.
1+
# Copyright (c) 2018, 2023, Oracle and/or its affiliates.
22
# Copyright (C) 1996-2017 Python Software Foundation
33
#
44
# Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
@@ -14,6 +14,16 @@ def __index__(self):
1414
return self.value
1515

1616

17+
def test_index():
18+
assert 'bla'.index('a') == 2
19+
assertRaises(ValueError, 'bla'.index, 'c')
20+
assert ''.index('') == 0
21+
assert 'adsf'.index('') == 0
22+
assert 'adsf'.index('', 2) == 2
23+
assert 'adsf'.index('', 1, 1) == 1
24+
assertRaises(ValueError, 'asdf'.index, '', 2, 1)
25+
26+
1727
def test_find():
1828
assert "teststring".find("test") == 0
1929
assert "teststring".find("string") == 4
@@ -56,6 +66,12 @@ def test_find():
5666
assert s.find('cau', None, MyIndexable(8)) == 5
5767
assert s.find('cau', 2**100) == -1
5868

69+
assert ''.find('') == 0
70+
assert 'adsf'.find('') == 0
71+
assert 'adsf'.find('', 2) == 2
72+
assert 'adsf'.find('', 1, 1) == 1
73+
assert 'asdf'.find('', 2, 1) == -1
74+
5975

6076
def test_rfind():
6177
assert "test string test".rfind("test") == 12

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/str/StringBuiltins.java

Lines changed: 36 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@
186186
import com.oracle.truffle.api.strings.TruffleString.CodePointLengthNode;
187187
import com.oracle.truffle.api.strings.TruffleString.CodeRange;
188188
import com.oracle.truffle.api.strings.TruffleString.Encoding;
189+
import com.oracle.truffle.api.strings.TruffleString.IndexOfStringNode;
190+
import com.oracle.truffle.api.strings.TruffleString.LastIndexOfStringNode;
189191
import com.oracle.truffle.api.strings.TruffleStringBuilder;
190192
import com.oracle.truffle.api.strings.TruffleStringIterator;
191193

@@ -752,20 +754,7 @@ protected ArgumentClinicProvider getArgumentClinic() {
752754
static int rfind(TruffleString self, TruffleString sub, int start, int end,
753755
@Shared("cpLen") @Cached TruffleString.CodePointLengthNode codePointLengthNode,
754756
@Shared("lastIndexOf") @Cached TruffleString.LastIndexOfStringNode lastIndexOfStringNode) {
755-
int cpLen = codePointLengthNode.execute(self, TS_ENCODING);
756-
int cpStart = adjustStartIndex(start, cpLen);
757-
int cpEnd = adjustEndIndex(end, cpLen);
758-
if (sub.isEmpty() && cpStart == cpLen) {
759-
return cpLen;
760-
}
761-
if (cpStart >= cpLen) {
762-
return -1;
763-
}
764-
int idx = lastIndexOfStringNode.execute(self, sub, cpEnd, cpStart, TS_ENCODING);
765-
if (idx < 0) {
766-
return -1;
767-
}
768-
return idx;
757+
return lastIndexOf(self, sub, start, end, codePointLengthNode, lastIndexOfStringNode);
769758
}
770759

771760
@Specialization
@@ -795,20 +784,7 @@ protected ArgumentClinicProvider getArgumentClinic() {
795784
static int find(TruffleString self, TruffleString sub, int start, int end,
796785
@Shared("cpLen") @Cached TruffleString.CodePointLengthNode codePointLengthNode,
797786
@Shared("indexOf") @Cached TruffleString.IndexOfStringNode indexOfStringNode) {
798-
int cpLen = codePointLengthNode.execute(self, TS_ENCODING);
799-
int cpStart = adjustStartIndex(start, cpLen);
800-
int cpEnd = adjustEndIndex(end, cpLen);
801-
if (sub.isEmpty() && cpStart == cpLen) {
802-
return cpLen;
803-
}
804-
if (cpStart >= cpLen) {
805-
return -1;
806-
}
807-
int idx = indexOfStringNode.execute(self, sub, cpStart, cpEnd, TS_ENCODING);
808-
if (idx < 0) {
809-
return -1;
810-
}
811-
return idx;
787+
return indexOf(self, sub, start, end, codePointLengthNode, indexOfStringNode);
812788
}
813789

814790
@Specialization
@@ -1731,6 +1707,30 @@ static int len(Object self,
17311707
}
17321708
}
17331709

1710+
private static int indexOf(TruffleString self, TruffleString sub, int start, int end, CodePointLengthNode codePointLengthNode, IndexOfStringNode indexOfStringNode) {
1711+
int cpLen = codePointLengthNode.execute(self, TS_ENCODING);
1712+
int cpStart = adjustStartIndex(start, cpLen);
1713+
int cpEnd = adjustEndIndex(end, cpLen);
1714+
if (cpStart < cpEnd) {
1715+
return indexOfStringNode.execute(self, sub, cpStart, cpEnd, TS_ENCODING);
1716+
} else if (sub.isEmpty() && cpStart == cpEnd && cpStart <= cpLen) {
1717+
return cpStart;
1718+
}
1719+
return -1;
1720+
}
1721+
1722+
private static int lastIndexOf(TruffleString self, TruffleString sub, int start, int end, CodePointLengthNode codePointLengthNode, LastIndexOfStringNode lastIndexOfStringNode) {
1723+
int cpLen = codePointLengthNode.execute(self, TS_ENCODING);
1724+
int cpStart = adjustStartIndex(start, cpLen);
1725+
int cpEnd = adjustEndIndex(end, cpLen);
1726+
if (cpStart < cpEnd) {
1727+
return lastIndexOfStringNode.execute(self, sub, cpEnd, cpStart, TS_ENCODING);
1728+
} else if (sub.isEmpty() && cpStart == cpEnd && cpStart <= cpLen) {
1729+
return cpStart;
1730+
}
1731+
return -1;
1732+
}
1733+
17341734
@Builtin(name = "index", minNumOfPositionalArgs = 2, parameterNames = {"$self", "sub", "start", "end"})
17351735
@ArgumentClinic(name = "start", conversion = ArgumentClinic.ClinicConversion.SliceIndex, defaultValue = "0", useDefaultForNone = true)
17361736
@ArgumentClinic(name = "end", conversion = ArgumentClinic.ClinicConversion.SliceIndex, defaultValue = "Integer.MAX_VALUE", useDefaultForNone = true)
@@ -1745,16 +1745,11 @@ protected ArgumentClinicProvider getArgumentClinic() {
17451745
public int index(TruffleString self, TruffleString sub, int start, int end,
17461746
@Shared("cpLen") @Cached TruffleString.CodePointLengthNode codePointLengthNode,
17471747
@Shared("indexOf") @Cached TruffleString.IndexOfStringNode indexOfStringNode) {
1748-
int cpLen = codePointLengthNode.execute(self, TS_ENCODING);
1749-
int cpStart = adjustStartIndex(start, cpLen);
1750-
int cpEnd = adjustEndIndex(end, cpLen);
1751-
if (cpStart < cpLen) {
1752-
int idx = indexOfStringNode.execute(self, sub, cpStart, cpEnd, TS_ENCODING);
1753-
if (idx >= 0) {
1754-
return idx;
1755-
}
1748+
int idx = indexOf(self, sub, start, end, codePointLengthNode, indexOfStringNode);
1749+
if (idx < 0) {
1750+
throw raise(ValueError, ErrorMessages.SUBSTRING_NOT_FOUND);
17561751
}
1757-
throw raise(ValueError, ErrorMessages.SUBSTRING_NOT_FOUND);
1752+
return idx;
17581753
}
17591754

17601755
@Specialization
@@ -1782,19 +1777,11 @@ protected ArgumentClinicProvider getArgumentClinic() {
17821777
public int rindex(TruffleString self, TruffleString sub, int start, int end,
17831778
@Shared("cpLen") @Cached TruffleString.CodePointLengthNode codePointLengthNode,
17841779
@Shared("lastIndexOf") @Cached TruffleString.LastIndexOfStringNode lastIndexOfStringNode) {
1785-
int cpLen = codePointLengthNode.execute(self, TS_ENCODING);
1786-
int cpStart = adjustStartIndex(start, cpLen);
1787-
int cpEnd = adjustEndIndex(end, cpLen);
1788-
if (sub.isEmpty() && cpStart == cpLen) {
1789-
return cpLen;
1790-
}
1791-
if (cpStart < cpLen) {
1792-
int idx = lastIndexOfStringNode.execute(self, sub, cpEnd, cpStart, TS_ENCODING);
1793-
if (idx >= 0) {
1794-
return idx;
1795-
}
1780+
int idx = lastIndexOf(self, sub, start, end, codePointLengthNode, lastIndexOfStringNode);
1781+
if (idx < 0) {
1782+
throw raise(ValueError, ErrorMessages.SUBSTRING_NOT_FOUND);
17961783
}
1797-
throw raise(ValueError, ErrorMessages.SUBSTRING_NOT_FOUND);
1784+
return idx;
17981785
}
17991786

18001787
@Specialization

0 commit comments

Comments
 (0)