Skip to content

Commit 538c787

Browse files
committed
support "in" keyword for iterators if the container does not define __contains__
1 parent dec5fb5 commit 538c787

File tree

3 files changed

+205
-2
lines changed

3 files changed

+205
-2
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_iterator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,10 @@ def test_iter_try_except():
115115
break
116116

117117
assert exit_via_break
118+
119+
120+
def test_iterator_in():
121+
assert 1 not in (i for i in range(1))
122+
assert 1 in (i for i in range(2))
123+
assert 1 in iter(range(2))
124+
assert 1 not in iter(range(1))

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/NodeFactory.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import com.oracle.graal.python.nodes.expression.BinaryArithmetic;
6060
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
6161
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
62+
import com.oracle.graal.python.nodes.expression.ContainsNode;
6263
import com.oracle.graal.python.nodes.expression.ExpressionNode;
6364
import com.oracle.graal.python.nodes.expression.InplaceArithmetic;
6465
import com.oracle.graal.python.nodes.expression.IsNode;
@@ -389,9 +390,9 @@ public ExpressionNode createComparisonOperation(String operator, ExpressionNode
389390
case "!=":
390391
return BinaryComparisonNode.create(SpecialMethodNames.__NE__, SpecialMethodNames.__NE__, operator, left, right);
391392
case "in":
392-
return BinaryComparisonNode.create(SpecialMethodNames.__CONTAINS__, null, operator, right, left);
393+
return ContainsNode.create(right, left);
393394
case "notin":
394-
return CastToBooleanNode.createIfFalseNode(BinaryComparisonNode.create(SpecialMethodNames.__CONTAINS__, null, operator, right, left));
395+
return CastToBooleanNode.createIfFalseNode(ContainsNode.create(right, left));
395396
case "is":
396397
return IsNode.create(left, right);
397398
case "isnot":
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
package com.oracle.graal.python.nodes.expression;
2+
3+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__CONTAINS__;
4+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
5+
6+
import com.oracle.graal.python.builtins.objects.PNotImplemented;
7+
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
8+
import com.oracle.graal.python.nodes.control.GetIteratorNode;
9+
import com.oracle.graal.python.nodes.control.GetNextNode;
10+
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
11+
import com.oracle.graal.python.runtime.exception.PException;
12+
import com.oracle.truffle.api.CompilerDirectives;
13+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
14+
import com.oracle.truffle.api.dsl.Specialization;
15+
import com.oracle.truffle.api.nodes.UnexpectedResultException;
16+
17+
public abstract class ContainsNode extends BinaryOpNode {
18+
@Child private LookupAndCallBinaryNode callNode = LookupAndCallBinaryNode.create(__CONTAINS__, null);
19+
@Child private CastToBooleanNode castBool = CastToBooleanNode.createIfTrueNode();
20+
21+
@Child private GetIteratorNode getIterator;
22+
@Child private GetNextNode next;
23+
@Child private BinaryComparisonNode eqNode;
24+
@CompilationFinal private IsBuiltinClassProfile errorProfile;
25+
26+
public static ExpressionNode create(ExpressionNode right, ExpressionNode left) {
27+
return ContainsNodeGen.create(right, left);
28+
}
29+
30+
@Specialization(rewriteOn = UnexpectedResultException.class)
31+
boolean doBoolean(Object iter, boolean item) throws UnexpectedResultException {
32+
Object result = callNode.executeObject(iter, item);
33+
if (result == PNotImplemented.NOT_IMPLEMENTED) {
34+
Object iterator = getGetIterator().executeWith(iter);
35+
return sequenceContains(iterator, item);
36+
}
37+
return castBool.executeWith(result);
38+
}
39+
40+
@Specialization(rewriteOn = UnexpectedResultException.class)
41+
boolean doInt(Object iter, int item) throws UnexpectedResultException {
42+
Object result = callNode.executeObject(iter, item);
43+
if (result == PNotImplemented.NOT_IMPLEMENTED) {
44+
return sequenceContains(getGetIterator().executeWith(iter), item);
45+
}
46+
return castBool.executeWith(result);
47+
}
48+
49+
@Specialization(rewriteOn = UnexpectedResultException.class)
50+
boolean doLong(Object iter, long item) throws UnexpectedResultException {
51+
Object result = callNode.executeObject(iter, item);
52+
if (result == PNotImplemented.NOT_IMPLEMENTED) {
53+
return sequenceContains(getGetIterator().executeWith(iter), item);
54+
}
55+
return castBool.executeWith(result);
56+
}
57+
58+
@Specialization(rewriteOn = UnexpectedResultException.class)
59+
boolean doDouble(Object iter, double item) throws UnexpectedResultException {
60+
Object result = callNode.executeObject(iter, item);
61+
if (result == PNotImplemented.NOT_IMPLEMENTED) {
62+
return sequenceContains(getGetIterator().executeWith(iter), item);
63+
}
64+
return castBool.executeWith(result);
65+
}
66+
67+
@Specialization
68+
boolean doGeneric(Object iter, Object item) {
69+
Object result = callNode.executeObject(iter, item);
70+
if (result == PNotImplemented.NOT_IMPLEMENTED) {
71+
return sequenceContainsObject(getGetIterator().executeWith(iter), item);
72+
}
73+
return castBool.executeWith(result);
74+
}
75+
76+
private void handleUnexpectedResult(Object iterator, Object item, UnexpectedResultException e) throws UnexpectedResultException {
77+
// If we got an unexpected (non-primitive) result from the iterator, we need to compare it
78+
// and continue iterating with "next" through the generic case. However, we also want the
79+
// specialization to go away, so we wrap the boolean result in a new
80+
// UnexpectedResultException. This will cause the DSL to disable the specialization with the
81+
// primitive value and return the result we got, without iterating again.
82+
Object result = e.getResult();
83+
if (getEqNode().executeBool(result, item)) {
84+
result = true;
85+
} else {
86+
result = sequenceContainsObject(iterator, item);
87+
}
88+
throw new UnexpectedResultException(result);
89+
}
90+
91+
private boolean sequenceContains(Object iterator, boolean item) throws UnexpectedResultException {
92+
while (true) {
93+
try {
94+
if (getNext().executeBoolean(iterator) == item) {
95+
return true;
96+
}
97+
} catch (PException e) {
98+
e.expectStopIteration(getErrorProfile());
99+
return false;
100+
} catch (UnexpectedResultException e) {
101+
handleUnexpectedResult(iterator, item, e);
102+
}
103+
}
104+
}
105+
106+
private boolean sequenceContains(Object iterator, int item) throws UnexpectedResultException {
107+
while (true) {
108+
try {
109+
if (getNext().executeInt(iterator) == item) {
110+
return true;
111+
}
112+
} catch (PException e) {
113+
e.expectStopIteration(getErrorProfile());
114+
return false;
115+
} catch (UnexpectedResultException e) {
116+
handleUnexpectedResult(iterator, item, e);
117+
}
118+
}
119+
}
120+
121+
private boolean sequenceContains(Object iterator, long item) throws UnexpectedResultException {
122+
while (true) {
123+
try {
124+
if (getNext().executeLong(iterator) == item) {
125+
return true;
126+
}
127+
} catch (PException e) {
128+
e.expectStopIteration(getErrorProfile());
129+
return false;
130+
} catch (UnexpectedResultException e) {
131+
handleUnexpectedResult(iterator, item, e);
132+
}
133+
}
134+
}
135+
136+
private boolean sequenceContains(Object iterator, double item) throws UnexpectedResultException {
137+
while (true) {
138+
try {
139+
if (getNext().executeDouble(iterator) == item) {
140+
return true;
141+
}
142+
} catch (PException e) {
143+
e.expectStopIteration(getErrorProfile());
144+
return false;
145+
} catch (UnexpectedResultException e) {
146+
handleUnexpectedResult(iterator, item, e);
147+
}
148+
}
149+
}
150+
151+
private boolean sequenceContainsObject(Object iterator, Object item) {
152+
while (true) {
153+
try {
154+
if (getEqNode().executeBool(getNext().execute(iterator), item)) {
155+
return true;
156+
}
157+
} catch (PException e) {
158+
e.expectStopIteration(getErrorProfile());
159+
return false;
160+
}
161+
}
162+
}
163+
164+
private BinaryComparisonNode getEqNode() {
165+
if (eqNode == null) {
166+
CompilerDirectives.transferToInterpreterAndInvalidate();
167+
eqNode = insert(BinaryComparisonNode.create(__EQ__, __EQ__, "=="));
168+
}
169+
return eqNode;
170+
}
171+
172+
private IsBuiltinClassProfile getErrorProfile() {
173+
if (errorProfile == null) {
174+
CompilerDirectives.transferToInterpreterAndInvalidate();
175+
errorProfile = IsBuiltinClassProfile.create();
176+
}
177+
return errorProfile;
178+
}
179+
180+
private GetNextNode getNext() {
181+
if (next == null) {
182+
CompilerDirectives.transferToInterpreterAndInvalidate();
183+
next = insert(GetNextNode.create());
184+
}
185+
return next;
186+
}
187+
188+
private GetIteratorNode getGetIterator() {
189+
if (getIterator == null) {
190+
CompilerDirectives.transferToInterpreterAndInvalidate();
191+
getIterator = insert(GetIteratorNode.create());
192+
}
193+
return getIterator;
194+
}
195+
}

0 commit comments

Comments
 (0)