Skip to content

Commit 44f00b3

Browse files
fangerercosminbasca
authored andcommitted
Implement 'set.union': resolve conflicts
1 parent 4483a7a commit 44f00b3

File tree

4 files changed

+201
-30
lines changed

4 files changed

+201
-30
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_set.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,32 @@ def check_pass_thru():
5757
yield 1
5858

5959

60-
def test_set_or():
60+
def test_set_or_union():
6161
s1 = {1, 2, 3}
6262
s2 = {4, 5, 6}
6363
s3 = {1, 2, 4}
6464
s4 = {1, 2, 3}
6565

66-
union = s1 | s2
67-
assert union == {1, 2, 3, 4, 5, 6}
66+
or_result = s1 | s2
67+
union_result = s1.union(s2)
68+
assert or_result == {1, 2, 3, 4, 5, 6}
69+
assert union_result == {1, 2, 3, 4, 5, 6}
6870

69-
union = s1 | s3
70-
assert union == {1, 2, 3, 4}
71+
or_result = s1 | s3
72+
union_result = s1.union(s3)
73+
assert or_result == {1, 2, 3, 4}
74+
assert union_result == {1, 2, 3, 4}
7175

72-
union = s1 | s4
73-
assert union == {1, 2, 3}
76+
or_result = s1 | s4
77+
union_result = s1.union(s4)
78+
assert or_result == {1, 2, 3}
79+
assert union_result == {1, 2, 3}
80+
81+
82+
def test_set_union():
83+
assert {1, 2, 3}.union({1: 'a', 2: 'b', 4: 'd'}) == {1, 2, 3, 4}
84+
assert {1, 2, 3}.union([2, 3, 4, 5]) == {1, 2, 3, 4, 5}
85+
assert {1, 2, 3}.union((3, 4, 5, 6)) == {1, 2, 3, 4, 5, 6}
7486

7587

7688
def test_set_remove():

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageNodes.java

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodesFactory.InitNodeGen;
6262
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodesFactory.KeysEqualsNodeGen;
6363
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodesFactory.SetItemNodeGen;
64+
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodesFactory.UnionNodeGen;
6465
import com.oracle.graal.python.builtins.objects.dict.PDict;
6566
import com.oracle.graal.python.builtins.objects.function.PKeyword;
6667
import com.oracle.graal.python.builtins.objects.ints.PInt;
@@ -1034,6 +1035,7 @@ public static GetItemNode create() {
10341035
}
10351036

10361037
public abstract static class EqualsNode extends DictStorageBaseNode {
1038+
10371039
@Child private GetItemNode getLeftItemNode;
10381040
@Child private GetItemNode getRightItemNode;
10391041

@@ -1056,9 +1058,9 @@ private GetItemNode getRightItemNode() {
10561058
}
10571059

10581060
@Specialization(guards = "selfStorage.length() == other.length()")
1059-
boolean doKeywordsString(LocalsStorage selfStorage, LocalsStorage other) {
1061+
boolean doLocals(LocalsStorage selfStorage, LocalsStorage other) {
10601062
if (selfStorage.getFrame().getFrameDescriptor() == other.getFrame().getFrameDescriptor()) {
1061-
return doKeywordsString(selfStorage, other);
1063+
return doGeneric(selfStorage, other);
10621064
}
10631065
return false;
10641066
}
@@ -1080,7 +1082,7 @@ boolean doKeywordsString(DynamicObjectStorage selfStorage, DynamicObjectStorage
10801082
}
10811083

10821084
@Specialization(guards = "selfStorage.length() == other.length()")
1083-
boolean doKeywordsString(HashingStorage selfStorage, HashingStorage other) {
1085+
boolean doGeneric(HashingStorage selfStorage, HashingStorage other) {
10841086
if (selfStorage.length() == other.length()) {
10851087
Iterable<Object> keys = selfStorage.keys();
10861088
for (Object key : keys) {
@@ -1097,7 +1099,7 @@ boolean doKeywordsString(HashingStorage selfStorage, HashingStorage other) {
10971099

10981100
@SuppressWarnings("unused")
10991101
@Fallback
1100-
boolean doGeneric(HashingStorage selfStorage, HashingStorage other) {
1102+
boolean doFallback(HashingStorage selfStorage, HashingStorage other) {
11011103
return false;
11021104
}
11031105

@@ -1324,17 +1326,42 @@ public static IntersectNode create() {
13241326
}
13251327
}
13261328

1327-
public static class UnionNode extends Node {
1329+
public abstract static class UnionNode extends DictStorageBaseNode {
13281330

1329-
public HashingStorage execute(HashingStorage left, HashingStorage right) {
1330-
EconomicMapStorage newStorage = EconomicMapStorage.create(false);
1331+
protected final boolean setUnion;
1332+
1333+
public UnionNode(boolean setUnion) {
1334+
this.setUnion = setUnion;
1335+
}
1336+
1337+
public abstract HashingStorage execute(HashingStorage left, HashingStorage right);
1338+
1339+
@Specialization(guards = "setUnion")
1340+
public HashingStorage doGenericSet(HashingStorage left, HashingStorage right) {
1341+
EconomicMapStorage newStorage = EconomicMapStorage.create(setUnion);
1342+
for (Object key : left.keys()) {
1343+
newStorage.setItem(key, PNone.NO_VALUE, getEquivalence());
1344+
}
1345+
for (Object key : right.keys()) {
1346+
newStorage.setItem(key, PNone.NO_VALUE, getEquivalence());
1347+
}
1348+
return newStorage;
1349+
}
1350+
1351+
@Specialization(guards = "!setUnion")
1352+
public HashingStorage doGeneric(HashingStorage left, HashingStorage right) {
1353+
EconomicMapStorage newStorage = EconomicMapStorage.create(setUnion);
13311354
newStorage.addAll(left);
13321355
newStorage.addAll(right);
13331356
return newStorage;
13341357
}
13351358

13361359
public static UnionNode create() {
1337-
return new UnionNode();
1360+
return create(false);
1361+
}
1362+
1363+
public static UnionNode create(boolean setUnion) {
1364+
return UnionNodeGen.create(setUnion);
13381365
}
13391366
}
13401367

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java

Lines changed: 144 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,32 @@
3838
import com.oracle.graal.python.builtins.Builtin;
3939
import com.oracle.graal.python.builtins.CoreFunctions;
4040
import com.oracle.graal.python.builtins.PythonBuiltins;
41+
import com.oracle.graal.python.builtins.objects.PNone;
4142
import com.oracle.graal.python.builtins.objects.PNotImplemented;
43+
import com.oracle.graal.python.builtins.objects.common.EconomicMapStorage;
4244
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
45+
import com.oracle.graal.python.builtins.objects.common.HashingStorage.Equivalence;
4346
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes;
47+
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.PythonEquivalence;
48+
import com.oracle.graal.python.builtins.objects.common.PHashingCollection;
49+
import com.oracle.graal.python.builtins.objects.set.FrozenSetBuiltinsFactory.BinaryUnionNodeGen;
50+
import com.oracle.graal.python.nodes.PBaseNode;
51+
import com.oracle.graal.python.nodes.control.GetIteratorNode;
52+
import com.oracle.graal.python.nodes.control.GetNextNode;
4453
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
54+
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4555
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
4656
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
57+
import com.oracle.graal.python.runtime.exception.PException;
58+
import com.oracle.truffle.api.CompilerDirectives;
59+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
4760
import com.oracle.truffle.api.dsl.Cached;
4861
import com.oracle.truffle.api.dsl.Fallback;
4962
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
5063
import com.oracle.truffle.api.dsl.NodeFactory;
5164
import com.oracle.truffle.api.dsl.Specialization;
65+
import com.oracle.truffle.api.profiles.ConditionProfile;
66+
import com.oracle.truffle.api.profiles.ValueProfile;
5267

5368
@CoreFunctions(extendClasses = {PFrozenSet.class, PSet.class})
5469
public final class FrozenSetBuiltins extends PythonBuiltins {
@@ -116,35 +131,51 @@ Object run(PBaseSet self, PBaseSet other) {
116131
@Builtin(name = __AND__, fixedNumOfArguments = 2)
117132
@GenerateNodeFactory
118133
abstract static class AndNode extends PythonBinaryBuiltinNode {
134+
@Child private HashingStorageNodes.IntersectNode intersectNode;
135+
119136
@Specialization
120-
PBaseSet doPBaseSet(PSet left, PBaseSet right,
121-
@Cached("create()") HashingStorageNodes.IntersectNode intersectNode) {
122-
HashingStorage intersectedStorage = intersectNode.execute(left.getDictStorage(), right.getDictStorage());
137+
PBaseSet doPBaseSet(PSet left, PBaseSet right) {
138+
HashingStorage intersectedStorage = getIntersectNode().execute(left.getDictStorage(), right.getDictStorage());
123139
return factory().createSet(intersectedStorage);
124140
}
125141

126142
@Specialization
127-
PBaseSet doPBaseSet(PFrozenSet left, PBaseSet right,
128-
@Cached("create()") HashingStorageNodes.IntersectNode intersectNode) {
129-
HashingStorage intersectedStorage = intersectNode.execute(left.getDictStorage(), right.getDictStorage());
143+
PBaseSet doPBaseSet(PFrozenSet left, PBaseSet right) {
144+
HashingStorage intersectedStorage = getIntersectNode().execute(left.getDictStorage(), right.getDictStorage());
130145
return factory().createFrozenSet(intersectedStorage);
131146
}
147+
148+
private HashingStorageNodes.IntersectNode getIntersectNode() {
149+
if (intersectNode == null) {
150+
CompilerDirectives.transferToInterpreterAndInvalidate();
151+
intersectNode = insert(HashingStorageNodes.IntersectNode.create());
152+
}
153+
return intersectNode;
154+
}
132155
}
133156

134157
@Builtin(name = __SUB__, fixedNumOfArguments = 2)
135158
@GenerateNodeFactory
136159
abstract static class SubNode extends PythonBinaryBuiltinNode {
160+
@Child private HashingStorageNodes.DiffNode diffNode;
161+
162+
private HashingStorageNodes.DiffNode getDiffNode() {
163+
if (diffNode == null) {
164+
CompilerDirectives.transferToInterpreterAndInvalidate();
165+
diffNode = HashingStorageNodes.DiffNode.create();
166+
}
167+
return diffNode;
168+
}
169+
137170
@Specialization
138-
PBaseSet doPBaseSet(PSet left, PBaseSet right,
139-
@Cached("create()") HashingStorageNodes.DiffNode diffNode) {
140-
HashingStorage storage = diffNode.execute(left.getDictStorage(), right.getDictStorage());
171+
PBaseSet doPBaseSet(PSet left, PBaseSet right) {
172+
HashingStorage storage = getDiffNode().execute(left.getDictStorage(), right.getDictStorage());
141173
return factory().createSet(storage);
142174
}
143175

144176
@Specialization
145-
PBaseSet doPBaseSet(PFrozenSet left, PBaseSet right,
146-
@Cached("create()") HashingStorageNodes.DiffNode diffNode) {
147-
HashingStorage storage = diffNode.execute(left.getDictStorage(), right.getDictStorage());
177+
PBaseSet doPBaseSet(PFrozenSet left, PBaseSet right) {
178+
HashingStorage storage = getDiffNode().execute(left.getDictStorage(), right.getDictStorage());
148179
return factory().createSet(storage);
149180
}
150181
}
@@ -158,4 +189,105 @@ boolean contains(PBaseSet self, Object key,
158189
return containsKeyNode.execute(self.getDictStorage(), key);
159190
}
160191
}
192+
193+
@Builtin(name = "union", minNumOfArguments = 1, takesVariableArguments = true)
194+
@GenerateNodeFactory
195+
abstract static class UnionNode extends PythonBuiltinNode {
196+
197+
@Child private BinaryUnionNode binaryUnionNode;
198+
199+
@CompilationFinal private ValueProfile setTypeProfile;
200+
201+
@Specialization(guards = {"args.length == len", "args.length < 32"}, limit = "3")
202+
PBaseSet doCached(PBaseSet self, Object[] args,
203+
@Cached("args.length") int len,
204+
@Cached("create()") HashingStorageNodes.CopyNode copyNode) {
205+
PBaseSet result = create(self, copyNode.execute(self.getDictStorage()));
206+
for (int i = 0; i < len; i++) {
207+
getBinaryUnionNode().execute(result, result.getDictStorage(), args[i]);
208+
}
209+
return result;
210+
}
211+
212+
@Specialization(replaces = "doCached")
213+
PBaseSet doGeneric(PBaseSet self, Object[] args,
214+
@Cached("create()") HashingStorageNodes.CopyNode copyNode) {
215+
PBaseSet result = create(self, copyNode.execute(self.getDictStorage()));
216+
for (int i = 0; i < args.length; i++) {
217+
getBinaryUnionNode().execute(result, result.getDictStorage(), args[i]);
218+
}
219+
return result;
220+
}
221+
222+
private PBaseSet create(PBaseSet left, HashingStorage storage) {
223+
if (getSetTypeProfile().profile(left) instanceof PFrozenSet) {
224+
return factory().createFrozenSet(storage);
225+
}
226+
return factory().createSet(storage);
227+
}
228+
229+
private BinaryUnionNode getBinaryUnionNode() {
230+
if (binaryUnionNode == null) {
231+
CompilerDirectives.transferToInterpreterAndInvalidate();
232+
binaryUnionNode = insert(BinaryUnionNode.create());
233+
}
234+
return binaryUnionNode;
235+
}
236+
237+
private ValueProfile getSetTypeProfile() {
238+
if (setTypeProfile == null) {
239+
CompilerDirectives.transferToInterpreterAndInvalidate();
240+
setTypeProfile = ValueProfile.createClassProfile();
241+
}
242+
return setTypeProfile;
243+
}
244+
245+
}
246+
247+
abstract static class BinaryUnionNode extends PBaseNode {
248+
@Child private Equivalence equivalenceNode;
249+
250+
public abstract PBaseSet execute(PBaseSet container, HashingStorage left, Object right);
251+
252+
@Specialization
253+
PBaseSet doHashingCollection(PBaseSet container, EconomicMapStorage selfStorage, PHashingCollection other) {
254+
for (Object key : other.getDictStorage().keys()) {
255+
selfStorage.setItem(key, PNone.NO_VALUE, getEquivalence());
256+
}
257+
return container;
258+
}
259+
260+
@Specialization
261+
PBaseSet doIterable(PBaseSet container, HashingStorage dictStorage, Object iterable,
262+
@Cached("create()") GetIteratorNode getIteratorNode,
263+
@Cached("create()") GetNextNode next,
264+
@Cached("createBinaryProfile()") ConditionProfile errorProfile,
265+
@Cached("create()") HashingStorageNodes.SetItemNode setItemNode) {
266+
267+
Object iterator = getIteratorNode.executeWith(iterable);
268+
while (true) {
269+
Object value;
270+
try {
271+
value = next.execute(iterator);
272+
} catch (PException e) {
273+
e.expectStopIteration(getCore(), errorProfile);
274+
return container;
275+
}
276+
setItemNode.execute(container, dictStorage, value, PNone.NO_VALUE);
277+
}
278+
}
279+
280+
protected Equivalence getEquivalence() {
281+
if (equivalenceNode == null) {
282+
CompilerDirectives.transferToInterpreterAndInvalidate();
283+
equivalenceNode = insert(new PythonEquivalence());
284+
}
285+
return equivalenceNode;
286+
}
287+
288+
public static BinaryUnionNode create() {
289+
return BinaryUnionNodeGen.create();
290+
}
291+
292+
}
161293
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import com.oracle.graal.python.builtins.objects.PNone;
3838
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes;
3939
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
40-
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4140
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
4241
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
4342
import com.oracle.graal.python.runtime.exception.PythonErrorType;
@@ -80,7 +79,7 @@ public Object add(PSet self, Object o,
8079

8180
@Builtin(name = __HASH__, fixedNumOfArguments = 1)
8281
@GenerateNodeFactory
83-
public abstract static class HashNode extends PythonBuiltinNode {
82+
public abstract static class HashNode extends PythonUnaryBuiltinNode {
8483
@Specialization
8584
Object doGeneric(Object self) {
8685
throw raise(TypeError, "unhashable type: '%p'", self);
@@ -89,7 +88,7 @@ Object doGeneric(Object self) {
8988

9089
@Builtin(name = __OR__, fixedNumOfArguments = 2)
9190
@GenerateNodeFactory
92-
public abstract static class SetOrNode extends PythonBuiltinNode {
91+
public abstract static class SetOrNode extends PythonBinaryBuiltinNode {
9392
@Specialization
9493
Object doSet(PBaseSet self, PBaseSet other,
9594
@Cached("create()") HashingStorageNodes.UnionNode unionNode) {
@@ -122,4 +121,5 @@ Object discard(PBaseSet self, Object other,
122121
return PNone.NONE;
123122
}
124123
}
124+
125125
}

0 commit comments

Comments
 (0)