Skip to content

Commit 035852b

Browse files
committed
[GR-12855] Implement or for frozenset.
PullRequest: graalpython/317
2 parents 0111cdf + 534b42d commit 035852b

File tree

2 files changed

+102
-1
lines changed

2 files changed

+102
-1
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_set.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ def test_set_or_union():
8080
assert or_result == {1, 2, 3}
8181
assert union_result == {1, 2, 3}
8282

83+
assert frozenset((1,2)) | {1:2}.items() == {1, 2, (1, 2)}
84+
assert frozenset((1,2)) | {1:2}.keys() == {1, 2}
85+
86+
def test_set_and():
87+
assert frozenset((1,2)) & {1:2}.items() == set()
88+
assert frozenset((1,2)) & {1:2}.keys() == {1}
8389

8490
def test_set_union():
8591
assert {1, 2, 3}.union({1: 'a', 2: 'b', 4: 'd'}) == {1, 2, 3, 4}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
package com.oracle.graal.python.builtins.objects.set;
2727

2828
import static com.oracle.graal.python.nodes.SpecialMethodNames.__AND__;
29+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__OR__;
2930
import static com.oracle.graal.python.nodes.SpecialMethodNames.__CONTAINS__;
3031
import static com.oracle.graal.python.nodes.SpecialMethodNames.__EQ__;
3132
import static com.oracle.graal.python.nodes.SpecialMethodNames.__GE__;
@@ -51,6 +52,7 @@
5152
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes;
5253
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.PythonEquivalence;
5354
import com.oracle.graal.python.builtins.objects.common.PHashingCollection;
55+
import com.oracle.graal.python.builtins.objects.dict.PDictView;
5456
import com.oracle.graal.python.builtins.objects.set.FrozenSetBuiltinsFactory.BinaryUnionNodeGen;
5557
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
5658
import com.oracle.graal.python.nodes.PNodeWithContext;
@@ -73,6 +75,7 @@
7375
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
7476
import com.oracle.truffle.api.dsl.NodeFactory;
7577
import com.oracle.truffle.api.dsl.Specialization;
78+
import com.oracle.truffle.api.nodes.Node;
7679
import com.oracle.truffle.api.profiles.ConditionProfile;
7780
import com.oracle.truffle.api.profiles.ValueProfile;
7881

@@ -206,9 +209,101 @@ PBaseSet doPBaseSet(PFrozenSet left, PBaseSet right) {
206209
return factory().createFrozenSet(intersectedStorage);
207210
}
208211

212+
@Specialization
213+
PBaseSet doPBaseSet(PSet left, PDictView right,
214+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
215+
PSet rightSet = constructSetNode.executeWith(right);
216+
HashingStorage intersectedStorage = getIntersectNode().execute(left.getDictStorage(), rightSet.getDictStorage());
217+
return factory().createSet(intersectedStorage);
218+
}
219+
220+
@Specialization
221+
PBaseSet doPBaseSet(PFrozenSet left, PDictView right,
222+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
223+
PSet rightSet = constructSetNode.executeWith(right);
224+
HashingStorage intersectedStorage = getIntersectNode().execute(left.getDictStorage(), rightSet.getDictStorage());
225+
return factory().createSet(intersectedStorage);
226+
}
227+
209228
@Fallback
210229
Object doAnd(Object self, Object other) {
211-
throw raise(PythonErrorType.TypeError, "unsupported operand type(s) for &=: '%p' and '%p'", self, other);
230+
throw raise(PythonErrorType.TypeError, "unsupported operand type(s) for &: '%p' and '%p'", self, other);
231+
}
232+
}
233+
234+
@Builtin(name = __OR__, fixedNumOfPositionalArgs = 2)
235+
@GenerateNodeFactory
236+
abstract static class OrNode extends PythonBinaryBuiltinNode {
237+
@Node.Child private HashingStorageNodes.UnionNode unionNode;
238+
@Node.Child private HashingStorageNodes.SetItemNode setItemNode;
239+
240+
private HashingStorageNodes.SetItemNode getSetItemNode() {
241+
if (setItemNode == null) {
242+
CompilerDirectives.transferToInterpreterAndInvalidate();
243+
setItemNode = insert(HashingStorageNodes.SetItemNode.create());
244+
}
245+
return setItemNode;
246+
}
247+
248+
@TruffleBoundary
249+
private HashingStorage getStringAsHashingStorage(String str) {
250+
HashingStorage storage = EconomicMapStorage.create(str.length(), true);
251+
for (int i = 0; i < str.length(); i++) {
252+
String key = String.valueOf(str.charAt(i));
253+
getSetItemNode().execute(storage, key, PNone.NO_VALUE);
254+
}
255+
return storage;
256+
}
257+
258+
@Specialization
259+
PBaseSet doPBaseSet(PSet left, String right) {
260+
return factory().createSet(getUnionNode().execute(left.getDictStorage(), getStringAsHashingStorage(right)));
261+
}
262+
263+
@Specialization
264+
PBaseSet doPBaseSet(PFrozenSet left, String right) {
265+
return factory().createFrozenSet(getUnionNode().execute(left.getDictStorage(), getStringAsHashingStorage(right)));
266+
}
267+
268+
private HashingStorageNodes.UnionNode getUnionNode() {
269+
if (unionNode == null) {
270+
CompilerDirectives.transferToInterpreterAndInvalidate();
271+
unionNode = insert(HashingStorageNodes.UnionNode.create());
272+
}
273+
return unionNode;
274+
}
275+
276+
@Specialization
277+
PBaseSet doPBaseSet(PSet left, PBaseSet right) {
278+
HashingStorage intersectedStorage = getUnionNode().execute(left.getDictStorage(), right.getDictStorage());
279+
return factory().createSet(intersectedStorage);
280+
}
281+
282+
@Specialization
283+
PBaseSet doPBaseSet(PFrozenSet left, PBaseSet right) {
284+
HashingStorage intersectedStorage = getUnionNode().execute(left.getDictStorage(), right.getDictStorage());
285+
return factory().createFrozenSet(intersectedStorage);
286+
}
287+
288+
@Specialization
289+
PBaseSet doPBaseSet(PSet left, PDictView right,
290+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
291+
PSet rightSet = constructSetNode.executeWith(right);
292+
HashingStorage intersectedStorage = getUnionNode().execute(left.getDictStorage(), rightSet.getDictStorage());
293+
return factory().createSet(intersectedStorage);
294+
}
295+
296+
@Specialization
297+
PBaseSet doPBaseSet(PFrozenSet left, PDictView right,
298+
@Cached("create()") SetNodes.ConstructSetNode constructSetNode) {
299+
PSet rightSet = constructSetNode.executeWith(right);
300+
HashingStorage intersectedStorage = getUnionNode().execute(left.getDictStorage(), rightSet.getDictStorage());
301+
return factory().createSet(intersectedStorage);
302+
}
303+
304+
@Fallback
305+
Object doOr(Object self, Object other) {
306+
throw raise(PythonErrorType.TypeError, "unsupported operand type(s) for |: '%p' and '%p'", self, other);
212307
}
213308
}
214309

0 commit comments

Comments
 (0)