Skip to content

Commit 68aa438

Browse files
committed
[GR-11981] str.count() returns wrong results.
PullRequest: graalpython/219
2 parents 4d97ba6 + b8add24 commit 68aa438

File tree

3 files changed

+190
-70
lines changed

3 files changed

+190
-70
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_string.py

Lines changed: 115 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,12 @@
77
import sys
88

99

10+
class MyIndexable(object):
11+
def __init__(self, value):
12+
self.value = value
13+
def __index__(self):
14+
return self.value
15+
1016
def test_find():
1117
assert "teststring".find("test") == 0
1218
assert "teststring".find("string") == 4
@@ -23,6 +29,31 @@ def test_find():
2329
assert "teststring".find("tst", None, 2) == -1
2430
assert "teststring".find("st", None, 4) == 2
2531

32+
s = 'ahoj cau nazadar ahoj'
33+
assert s.find('ahoj') == 0
34+
assert s.find('ahoj', 4) == 17
35+
assert s.find('ahoj', -3) == -1
36+
assert s.find('ahoj', -21) == 0
37+
assert s.find('cau', -21) == 5
38+
assert s.find('cau', -36, -10) == 5
39+
assert s.find('cau', None) == 5
40+
assert s.find('ahoj', None) == 0
41+
assert s.find('cau', None, 8) == 5
42+
assert s.find('cau', None, 7) == -1
43+
assert s.find('u', 3) == 7
44+
assert s.find('u', 3, 7) == -1
45+
assert s.find('u', 3, 8) == 7
46+
assert s.find('u', -18, -13) == 7
47+
assert s.find('u', -18, -12) == 7
48+
assert s.find('u', -18, -14) == -1
49+
assert s.find('u', -14, -13) == 7
50+
assert s.find('u', -12, -13) == -1
51+
assert s.find('cau', MyIndexable(4)) == 5
52+
assert s.find('cau', MyIndexable(5)) == 5
53+
assert s.find('cau', MyIndexable(5), None) == 5
54+
assert s.find('cau', MyIndexable(5), MyIndexable(8)) == 5
55+
assert s.find('cau', None, MyIndexable(8)) == 5
56+
2657

2758
def test_rfind():
2859
assert "test string test".rfind("test") == 12
@@ -34,6 +65,20 @@ def test_rfind():
3465
assert "test string test".rfind("test", 4, 14) == -1
3566
assert "test string test".rfind("test", None, 14) == 0
3667

68+
s = 'ahoj cau nazdar ahoj'
69+
assert s.rfind('cau', None, None) == 5
70+
assert s.rfind('cau', -25, None) == 5
71+
assert s.rfind('cau', -25, -3) == 5
72+
assert s.rfind('cau', -25, -12) == 5
73+
assert s.rfind('cau', -25, -13) == -1
74+
assert s.rfind('cau', -15, -12) == 5
75+
assert s.rfind('cau', -14, -12) == -1
76+
assert s.rfind('ahoj', -14) == 16
77+
assert s.rfind('ahoj', -4) == 16
78+
assert s.rfind('ahoj', -3) == -1
79+
assert s.rfind('ahoj', 16) == 16
80+
assert s.rfind('ahoj', 16, 20) == 16
81+
assert s.rfind('ahoj', 16, 19) == -1
3782

3883
def test_format():
3984
assert "{}.{}".format("part1", "part2") == "part1.part2"
@@ -707,13 +752,6 @@ def test_zfill(self):
707752

708753
def test_zfill_specialization(self):
709754
self.checkequal('123', '123', 'zfill', True)
710-
711-
class MyIndexable(object):
712-
def __init__(self, value):
713-
self.value = value
714-
def __index__(self):
715-
return self.value
716-
717755
self.checkequal('0123', '123', 'zfill', MyIndexable(4))
718756

719757
def test_title(self):
@@ -774,6 +812,76 @@ def test_center_uni(self):
774812
self.assertEqual('x'.center(4, '\U0010FFFF'),
775813
'\U0010FFFFx\U0010FFFF\U0010FFFF')
776814

815+
# Whether the "contained items" of the container are integers in
816+
# range(0, 256) (i.e. bytes, bytearray) or strings of length 1
817+
# (str)
818+
contains_bytes = False
819+
820+
def test_count(self):
821+
self.checkequal(3, 'aaa', 'count', 'a')
822+
self.checkequal(0, 'aaa', 'count', 'b')
823+
self.checkequal(3, 'aaa', 'count', 'a')
824+
self.checkequal(0, 'aaa', 'count', 'b')
825+
self.checkequal(3, 'aaa', 'count', 'a')
826+
self.checkequal(0, 'aaa', 'count', 'b')
827+
self.checkequal(0, 'aaa', 'count', 'b')
828+
self.checkequal(2, 'aaa', 'count', 'a', 1)
829+
self.checkequal(0, 'aaa', 'count', 'a', 10)
830+
self.checkequal(1, 'aaa', 'count', 'a', -1)
831+
self.checkequal(3, 'aaa', 'count', 'a', -10)
832+
self.checkequal(1, 'aaa', 'count', 'a', 0, 1)
833+
self.checkequal(3, 'aaa', 'count', 'a', 0, 10)
834+
self.checkequal(2, 'aaa', 'count', 'a', 0, -1)
835+
self.checkequal(0, 'aaa', 'count', 'a', 0, -10)
836+
self.checkequal(3, 'aaa', 'count', '', 1)
837+
self.checkequal(1, 'aaa', 'count', '', 3)
838+
self.checkequal(0, 'aaa', 'count', '', 10)
839+
self.checkequal(2, 'aaa', 'count', '', -1)
840+
self.checkequal(4, 'aaa', 'count', '', -10)
841+
842+
self.checkequal(1, '', 'count', '')
843+
self.checkequal(0, '', 'count', '', 1, 1)
844+
self.checkequal(0, '', 'count', '', sys.maxsize, 0)
845+
846+
self.checkequal(0, '', 'count', 'xx')
847+
self.checkequal(0, '', 'count', 'xx', 1, 1)
848+
self.checkequal(0, '', 'count', 'xx', sys.maxsize, 0)
849+
850+
self.checkraises(TypeError, 'hello', 'count')
851+
852+
if self.contains_bytes:
853+
self.checkequal(0, 'hello', 'count', 42)
854+
else:
855+
self.checkraises(TypeError, 'hello', 'count', 42)
856+
857+
# For a variety of combinations,
858+
# verify that str.count() matches an equivalent function
859+
# replacing all occurrences and then differencing the string lengths
860+
charset = ['', 'a', 'b']
861+
digits = 7
862+
base = len(charset)
863+
teststrings = set()
864+
for i in range(base ** digits):
865+
entry = []
866+
for j in range(digits):
867+
i, m = divmod(i, base)
868+
entry.append(charset[m])
869+
teststrings.add(''.join(entry))
870+
teststrings = [self.fixtype(ts) for ts in teststrings]
871+
for i in teststrings:
872+
n = len(i)
873+
for j in teststrings:
874+
r1 = i.count(j)
875+
if j:
876+
r2, rem = divmod(n - len(i.replace(j, self.fixtype(''))),
877+
len(j))
878+
else:
879+
r2, rem = len(i)+1, 0
880+
if rem or r1 != r2:
881+
self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i))
882+
self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i))
883+
884+
777885
def test_same_id():
778886
empty_ids = set([id(str()) for i in range(100)])
779887
assert len(empty_ids) == 1

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/str/StringBuiltins.java

Lines changed: 56 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -522,42 +522,77 @@ public Object endsWith(Object self, Object prefix) {
522522
@TypeSystemReference(PythonArithmeticTypes.class)
523523
abstract static class FindBaseNode extends PythonBuiltinNode {
524524

525+
private @Child CastToIndexNode startNode;
526+
private @Child CastToIndexNode endNode;
527+
528+
private CastToIndexNode getStartNode() {
529+
if (startNode == null) {
530+
CompilerDirectives.transferToInterpreterAndInvalidate();
531+
startNode = insert(CastToIndexNode.createOverflow());
532+
}
533+
return startNode;
534+
}
535+
536+
private CastToIndexNode getEndNode() {
537+
if (endNode == null) {
538+
CompilerDirectives.transferToInterpreterAndInvalidate();
539+
endNode = insert(CastToIndexNode.createOverflow());
540+
}
541+
return endNode;
542+
}
543+
544+
private SliceInfo computeSlice(int length, long start, long end) {
545+
PSlice tmpSlice = factory().createSlice(getStartNode().execute(start), getEndNode().execute(end), 1);
546+
return tmpSlice.computeIndices(length);
547+
}
548+
549+
private SliceInfo computeSlice(int length, Object startO, Object endO) {
550+
int start = startO == PNone.NO_VALUE || startO == PNone.NONE ? 0 : getStartNode().execute(startO);
551+
int end = endO == PNone.NO_VALUE || endO == PNone.NONE ? length : getEndNode().execute(endO);
552+
PSlice tmpSlice = factory().createSlice(start, end, 1);
553+
return tmpSlice.computeIndices(length);
554+
}
555+
525556
@Specialization
526557
Object find(String self, String str, @SuppressWarnings("unused") PNone start, @SuppressWarnings("unused") PNone end) {
527558
return find(self, str);
528559
}
529560

530561
@Specialization
531562
Object find(String self, String str, long start, @SuppressWarnings("unused") PNone end) {
532-
return findGeneric(self, str, start, -1);
563+
int len = self.length();
564+
SliceInfo info = computeSlice(len, start, len);
565+
if (info.length == 0) {
566+
return -1;
567+
}
568+
return findWithBounds(self, str, info.start, info.stop);
533569
}
534570

535571
@Specialization
536572
Object find(String self, String str, @SuppressWarnings("unused") PNone start, long end) {
537-
return findGeneric(self, str, -1, end);
573+
SliceInfo info = computeSlice(self.length(), 0, end);
574+
if (info.length == 0) {
575+
return -1;
576+
}
577+
return findWithBounds(self, str, info.start, info.stop);
538578
}
539579

540580
@Specialization
541581
Object find(String self, String str, long start, long end) {
542-
return findGeneric(self, str, start, end);
582+
SliceInfo info = computeSlice(self.length(), start, end);
583+
if (info.length == 0) {
584+
return -1;
585+
}
586+
return findWithBounds(self, str, info.start, info.stop);
543587
}
544588

545-
@Specialization(guards = {"isNumberOrNone(start)", "isNumberOrNone(end)"}, rewriteOn = ArithmeticException.class)
589+
@Specialization
546590
Object findGeneric(String self, String str, Object start, Object end) throws ArithmeticException {
547-
int startInt = getIntValue(start);
548-
int endInt = getIntValue(end);
549-
return findWithBounds(self, str, startInt, endInt);
550-
}
551-
552-
@Specialization(guards = {"isNumberOrNone(start)", "isNumberOrNone(end)"}, replaces = "findGeneric")
553-
Object findGenericOvf(String self, String str, Object start, Object end) {
554-
try {
555-
int startInt = getIntValue(start);
556-
int endInt = getIntValue(end);
557-
return findWithBounds(self, str, startInt, endInt);
558-
} catch (ArithmeticException e) {
559-
throw raise(ValueError, "cannot fit 'int' into an index-sized integer");
591+
SliceInfo info = computeSlice(self.length(), start, end);
592+
if (info.length == 0) {
593+
return -1;
560594
}
595+
return findWithBounds(self, str, info.start, info.stop);
561596
}
562597

563598
@Fallback
@@ -566,25 +601,6 @@ Object findFail(Object self, Object str, Object start, Object end) {
566601
throw raise(TypeError, "must be str, not %p", str);
567602
}
568603

569-
protected static boolean isNumberOrNone(Object o) {
570-
return o instanceof PInt || o instanceof PNone;
571-
}
572-
573-
private static int getIntValue(Object o) throws ArithmeticException {
574-
if (o instanceof Integer) {
575-
return (int) o;
576-
} else if (o instanceof Long) {
577-
return PInt.intValueExact((long) o);
578-
} else if (o instanceof Boolean) {
579-
return PInt.intValue((boolean) o);
580-
} else if (o instanceof PInt) {
581-
return ((PInt) o).intValueExact();
582-
} else if (o instanceof PNone) {
583-
return -1;
584-
}
585-
throw new IllegalArgumentException();
586-
}
587-
588604
@SuppressWarnings("unused")
589605
protected int find(String self, String findStr) {
590606
throw new AssertionError("must not be reached");
@@ -610,16 +626,8 @@ protected int find(String self, String findStr) {
610626
@Override
611627
@TruffleBoundary
612628
protected int findWithBounds(String self, String str, int start, int end) {
613-
if (start != -1 && end != -1) {
614-
int idx = self.lastIndexOf(str, end - str.length() - 1);
615-
return idx >= start ? idx : -1;
616-
} else if (start != -1) {
617-
int idx = self.lastIndexOf(str);
618-
return idx >= start ? idx : -1;
619-
} else {
620-
assert end != -1;
621-
return self.lastIndexOf(str, end - str.length() - 1);
622-
}
629+
int idx = self.lastIndexOf(str, end - str.length());
630+
return idx >= start ? idx : -1;
623631
}
624632
}
625633

@@ -637,16 +645,8 @@ protected int find(String self, String findStr) {
637645
@Override
638646
@TruffleBoundary
639647
protected int findWithBounds(String self, String str, int start, int end) {
640-
if (start != -1 && end != -1) {
641-
int idx = self.indexOf(str, start);
642-
return idx + str.length() <= end ? idx : -1;
643-
} else if (start != -1) {
644-
return self.indexOf(str, start);
645-
} else {
646-
assert end != -1;
647-
int idx = self.indexOf(str);
648-
return idx + str.length() <= end ? idx : -1;
649-
}
648+
int idx = self.indexOf(str, start);
649+
return idx + str.length() <= end ? idx : -1;
650650
}
651651
}
652652

graalpython/lib-graalpython/str.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -341,17 +341,29 @@ def __iter__(self):
341341
str.__iter__ = __iter__
342342

343343

344-
def strcount(self, sub, start=0, end=-1):
345-
if len(self) == 0:
344+
def strcount(self, sub, start=None, end=None):
345+
selfLeng = len(self)
346+
subLeng = len(sub)
347+
if start == None:
348+
start = 0
349+
if selfLeng == 0:
350+
if subLeng == 0 and start <= 0:
351+
return 1
346352
return 0
347-
if end < 0:
348-
end = (len(self) + end) % len(self)
349-
cnt = 0
353+
if end == None:
354+
end = selfLeng
355+
if subLeng == 0:
356+
if start <= selfLeng:
357+
return len(self[start:end]) + 1
358+
return 0
359+
350360
idx = self.find(sub, start, end)
351361
if idx < 0:
352362
return 0
353-
while idx < end and idx >= 0:
354-
start = idx + 1
363+
364+
cnt = 1
365+
while idx < selfLeng and idx >= 0 and cnt < selfLeng:
366+
start = idx + subLeng
355367
idx = self.find(sub, start, end)
356368
if idx >= 0:
357369
cnt += 1

0 commit comments

Comments
 (0)