Skip to content

Commit 463562b

Browse files
committed
fix bugs in range and reverse iteration
1 parent 28d8659 commit 463562b

File tree

5 files changed

+91
-40
lines changed

5 files changed

+91
-40
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,28 @@ public void enumerate(@SuppressWarnings("unused") PythonClass cls, @SuppressWarn
354354
public abstract static class ReversedNode extends PythonBuiltinNode {
355355

356356
@Specialization
357-
public PythonObject reversed(@SuppressWarnings("unused") PythonClass cls, PRange range) {
358-
return factory().createRangeReverseIterator(range);
357+
public PythonObject reversed(@SuppressWarnings("unused") PythonClass cls, PRange range,
358+
@Cached("createBinaryProfile()") ConditionProfile stepOneProfile,
359+
@Cached("createBinaryProfile()") ConditionProfile stepMinusOneProfile) {
360+
int stop;
361+
int start;
362+
int step = range.getStep();
363+
if (stepOneProfile.profile(step == 1)) {
364+
start = range.getStop() - 1;
365+
stop = range.getStart() - 1;
366+
step = -1;
367+
} else if (stepMinusOneProfile.profile(step == -1)) {
368+
start = range.getStop() + 1;
369+
stop = range.getStart() + 1;
370+
step = 1;
371+
} else {
372+
assert step != 0;
373+
long delta = (range.getStop() - (long) range.getStart() - (step > 0 ? -1 : 1)) / step * step;
374+
start = (int) (range.getStart() + delta);
375+
stop = range.getStart() - step;
376+
step = -step;
377+
}
378+
return factory().createRangeIterator(start, stop, step);
359379
}
360380

361381
@Specialization

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/iterator/PRangeIterator.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
*/
2626
package com.oracle.graal.python.builtins.objects.iterator;
2727

28-
import com.oracle.graal.python.builtins.objects.range.PRange;
2928
import com.oracle.graal.python.builtins.objects.type.PythonClass;
3029
import com.oracle.truffle.api.CompilerDirectives;
3130

@@ -35,11 +34,11 @@ public final class PRangeIterator extends PIntegerIterator {
3534
final int step;
3635
int index;
3736

38-
public PRangeIterator(PythonClass clazz, PRange range) {
37+
public PRangeIterator(PythonClass clazz, int start, int stop, int step) {
3938
super(clazz);
40-
this.index = range.getStart();
41-
this.stop = range.getStop();
42-
this.step = range.getStep();
39+
index = start;
40+
this.stop = stop;
41+
this.step = step;
4342
}
4443

4544
public int getStart() {
@@ -72,13 +71,6 @@ public static final class PRangeReverseIterator extends PIntegerIterator {
7271
final int step;
7372
int index;
7473

75-
public PRangeReverseIterator(PythonClass clazz, PRange range) {
76-
super(clazz);
77-
this.index = range.getStop() - 1;
78-
this.stop = range.getStart() - 1;
79-
this.step = range.getStep();
80-
}
81-
8274
public PRangeReverseIterator(PythonClass clazz, int index, int stop, int step) {
8375
super(clazz);
8476
this.index = index;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/range/RangeBuiltins.java

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,20 @@
3737
import com.oracle.graal.python.builtins.CoreFunctions;
3838
import com.oracle.graal.python.builtins.PythonBuiltins;
3939
import com.oracle.graal.python.builtins.objects.PNotImplemented;
40-
import com.oracle.graal.python.builtins.objects.iterator.PRangeIterator;
40+
import com.oracle.graal.python.builtins.objects.ints.PInt;
41+
import com.oracle.graal.python.builtins.objects.iterator.PIntegerIterator;
4142
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
4243
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4344
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
4445
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
45-
import com.oracle.graal.python.runtime.sequence.PSequence;
46+
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
4647
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
4748
import com.oracle.truffle.api.dsl.Fallback;
4849
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
4950
import com.oracle.truffle.api.dsl.NodeFactory;
5051
import com.oracle.truffle.api.dsl.Specialization;
52+
import com.oracle.truffle.api.dsl.TypeSystemReference;
53+
import com.oracle.truffle.api.profiles.ConditionProfile;
5154

5255
@CoreFunctions(extendClasses = PRange.class)
5356
public class RangeBuiltins extends PythonBuiltins {
@@ -94,19 +97,65 @@ Object doGeneric(Object left, Object right) {
9497

9598
@Builtin(name = __CONTAINS__, fixedNumOfArguments = 2)
9699
@GenerateNodeFactory
100+
@TypeSystemReference(PythonArithmeticTypes.class)
97101
abstract static class ContainsNode extends PythonBinaryBuiltinNode {
102+
private final ConditionProfile stepOneProfile = ConditionProfile.createBinaryProfile();
103+
private final ConditionProfile stepMinusOneProfile = ConditionProfile.createBinaryProfile();
104+
105+
@Specialization
106+
boolean contains(PRange self, long other) {
107+
int step = self.getStep();
108+
int start = self.getStart();
109+
int stop = self.getStop();
110+
111+
if (stepOneProfile.profile(step == 1)) {
112+
return other >= start && other < stop;
113+
} else if (stepMinusOneProfile.profile(step == -1)) {
114+
return other <= start && other > stop;
115+
} else {
116+
assert step != 0;
117+
if (step > 0) {
118+
if (other >= start && other < stop) {
119+
// discard based on range
120+
return false;
121+
}
122+
} else {
123+
if (other <= start && other > stop) {
124+
// discard based on range
125+
return false;
126+
}
127+
}
128+
return (other - start) % step == 0;
129+
}
130+
}
131+
132+
@Specialization
133+
boolean contains(PRange self, double other) {
134+
return (long) other == other ? contains(self, (long) other) : false;
135+
}
136+
98137
@Specialization
99-
boolean contains(PSequence self, Object other) {
100-
return self.index(other) != -1;
138+
boolean contains(PRange self, PInt other) {
139+
try {
140+
return contains(self, other.longValueExact());
141+
} catch (ArithmeticException e) {
142+
return false;
143+
}
144+
}
145+
146+
@SuppressWarnings("unused")
147+
@Fallback
148+
boolean containsFallback(Object self, Object other) {
149+
return false;
101150
}
102151
}
103152

104153
@Builtin(name = __ITER__, fixedNumOfArguments = 1)
105154
@GenerateNodeFactory
106155
abstract static class IterNode extends PythonUnaryBuiltinNode {
107156
@Specialization
108-
PRangeIterator iter(PRange self) {
109-
return factory().createRangeIterator(self);
157+
PIntegerIterator iter(PRange self) {
158+
return factory().createRangeIterator(self.getStart(), self.getStop(), self.getStep());
110159
}
111160
}
112161
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/control/GetIteratorNode.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,7 @@ protected boolean iterCannotBeOverridden(Object value) {
9292

9393
@Specialization(guards = "iterCannotBeOverridden(value)")
9494
public PythonObject doPRange(PRange value) {
95-
if (value.getStep() > 0) {
96-
return factory().createRangeIterator(value);
97-
} else {
98-
return factory().createRangeReverseIterator(value.getStart(), value.getStop(), -value.getStep());
99-
}
95+
return factory().createRangeIterator(value.getStart(), value.getStop(), value.getStep());
10096
}
10197

10298
@Specialization(guards = "iterCannotBeOverridden(value)")
@@ -150,11 +146,6 @@ public PythonObject doPZip(PZip value) {
150146
return value;
151147
}
152148

153-
@Specialization(guards = "iterCannotBeOverridden(range)")
154-
public PythonObject doRange(PRange range) {
155-
return factory().createRangeIterator(range);
156-
}
157-
158149
@Specialization(guards = {"!isNoValue(value)"})
159150
public Object doGeneric(Object value,
160151
@Cached("createIdentityProfile()") ValueProfile getattributeProfile,

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/object/PythonObjectFactory.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import com.oracle.graal.python.builtins.objects.iterator.PDoubleSequenceIterator;
7171
import com.oracle.graal.python.builtins.objects.iterator.PForeignArrayIterator;
7272
import com.oracle.graal.python.builtins.objects.iterator.PIntArrayIterator;
73+
import com.oracle.graal.python.builtins.objects.iterator.PIntegerIterator;
7374
import com.oracle.graal.python.builtins.objects.iterator.PIntegerSequenceIterator;
7475
import com.oracle.graal.python.builtins.objects.iterator.PLongArrayIterator;
7576
import com.oracle.graal.python.builtins.objects.iterator.PLongSequenceIterator;
@@ -606,16 +607,14 @@ public PSequenceReverseIterator createSequenceReverseIterator(PythonClass cls, O
606607
return trace(new PSequenceReverseIterator(cls, sequence, lengthHint));
607608
}
608609

609-
public PRangeIterator createRangeIterator(PRange range) {
610-
return trace(new PRangeIterator(lookupClass(PythonBuiltinClassType.PRangeIterator), range));
611-
}
612-
613-
public PRangeReverseIterator createRangeReverseIterator(PRange range) {
614-
return trace(new PRangeReverseIterator(lookupClass(PythonBuiltinClassType.PRangeReverseIterator), range));
615-
}
616-
617-
public PRangeReverseIterator createRangeReverseIterator(int start, int stop, int step) {
618-
return trace(new PRangeReverseIterator(lookupClass(PythonBuiltinClassType.PRangeReverseIterator), start, stop, step));
610+
public PIntegerIterator createRangeIterator(int start, int stop, int step) {
611+
PIntegerIterator object;
612+
if (step > 0) {
613+
object = new PRangeIterator(lookupClass(PythonBuiltinClassType.PRangeIterator), start, stop, step);
614+
} else {
615+
object = new PRangeReverseIterator(lookupClass(PythonBuiltinClassType.PRangeReverseIterator), start, stop, -step);
616+
}
617+
return trace(object);
619618
}
620619

621620
public PIntArrayIterator createIntArrayIterator(PIntArray array) {

0 commit comments

Comments
 (0)