Skip to content

Commit 3e51dcd

Browse files
committed
Support iterable unpacking in SetLiteralNode.
1 parent 6e95970 commit 3e51dcd

File tree

3 files changed

+85
-9
lines changed

3 files changed

+85
-9
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_set.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,13 @@ def test_set_delete():
293293
assert s == {'a', 'b', 'c'}
294294
s.discard('c')
295295
assert s == {'a', 'b'}
296+
297+
298+
def test_literal():
299+
d = {"a": 1, "b": 2, "c": 3}
300+
e = {"uff": "foo"}
301+
assert {*d, *e} == {"a", "b", "c", "uff"}
302+
303+
d = {}
304+
assert {*d} == set()
305+

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/literal/SetLiteralNode.java

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,35 +28,100 @@
2828
import com.oracle.graal.python.builtins.objects.PNone;
2929
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
3030
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes;
31+
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
3132
import com.oracle.graal.python.builtins.objects.dict.PDict;
33+
import com.oracle.graal.python.builtins.objects.set.PSet;
34+
import com.oracle.graal.python.nodes.PNode;
3235
import com.oracle.graal.python.nodes.expression.ExpressionNode;
3336
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
37+
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
3438
import com.oracle.truffle.api.CompilerDirectives;
3539
import com.oracle.truffle.api.frame.VirtualFrame;
3640
import com.oracle.truffle.api.nodes.ExplodeLoop;
3741

3842
public final class SetLiteralNode extends LiteralNode {
3943
@Child private PythonObjectFactory factory = PythonObjectFactory.create();
4044
@Children private final ExpressionNode[] values;
41-
@Child HashingStorageNodes.SetItemNode setItemNode;
45+
@Child private HashingStorageNodes.SetItemNode setItemNode;
46+
@Child private SequenceStorageNodes.LenNode lenNode;
47+
@Child private SequenceStorageNodes.GetItemNode getItemNode;
48+
49+
private final boolean hasStarredExpressions;
4250

4351
public SetLiteralNode(ExpressionNode[] values) {
4452
this.values = values;
53+
for (PNode v : values) {
54+
if (v instanceof StarredExpressionNode) {
55+
hasStarredExpressions = true;
56+
return;
57+
}
58+
}
59+
hasStarredExpressions = false;
4560
}
4661

4762
@Override
63+
public PSet execute(VirtualFrame frame) {
64+
if (!hasStarredExpressions) {
65+
return directSet(frame);
66+
} else {
67+
return expandingSet(frame);
68+
}
69+
}
70+
4871
@ExplodeLoop
49-
public Object execute(VirtualFrame frame) {
72+
private PSet expandingSet(VirtualFrame frame) {
73+
// we will usually have more than 'values.length' elements
5074
HashingStorage storage = PDict.createNewStorage(true, values.length);
75+
for (ExpressionNode n : values) {
76+
if (n instanceof StarredExpressionNode) {
77+
storage = addAllElement(frame, storage, ((StarredExpressionNode) n).getStorage(frame));
78+
} else {
79+
Object element = n.execute(frame);
80+
storage = ensureSetItemNode().execute(frame, storage, element, PNone.NO_VALUE);
81+
}
82+
}
83+
return factory.createSet(storage);
84+
}
5185

52-
if (setItemNode == null && values.length > 0) {
53-
CompilerDirectives.transferToInterpreterAndInvalidate();
54-
setItemNode = insert(HashingStorageNodes.SetItemNode.create());
86+
private HashingStorage addAllElement(VirtualFrame frame, HashingStorage setStorage, SequenceStorage sequenceStorage) {
87+
int n = ensureLenNode().execute(sequenceStorage);
88+
for (int i = 0; i < n; i++) {
89+
Object element = ensureGetItemNode().execute(sequenceStorage, i);
90+
setStorage = ensureSetItemNode().execute(frame, setStorage, element, PNone.NO_VALUE);
5591
}
92+
return setStorage;
93+
}
94+
95+
@ExplodeLoop
96+
private PSet directSet(VirtualFrame frame) {
97+
HashingStorage storage = PDict.createNewStorage(true, values.length);
5698
for (ExpressionNode v : this.values) {
57-
storage = setItemNode.execute(frame, storage, v.execute(frame), PNone.NO_VALUE);
99+
storage = ensureSetItemNode().execute(frame, storage, v.execute(frame), PNone.NO_VALUE);
58100
}
59-
60101
return factory.createSet(storage);
61102
}
103+
104+
private SequenceStorageNodes.LenNode ensureLenNode() {
105+
if (lenNode == null) {
106+
CompilerDirectives.transferToInterpreterAndInvalidate();
107+
lenNode = insert(SequenceStorageNodes.LenNode.create());
108+
}
109+
return lenNode;
110+
}
111+
112+
private SequenceStorageNodes.GetItemNode ensureGetItemNode() {
113+
if (getItemNode == null) {
114+
CompilerDirectives.transferToInterpreterAndInvalidate();
115+
getItemNode = insert(SequenceStorageNodes.GetItemNode.create());
116+
}
117+
return getItemNode;
118+
}
119+
120+
private HashingStorageNodes.SetItemNode ensureSetItemNode() {
121+
if (setItemNode == null) {
122+
CompilerDirectives.transferToInterpreterAndInvalidate();
123+
setItemNode = insert(HashingStorageNodes.SetItemNode.create());
124+
}
125+
return setItemNode;
126+
}
62127
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/literal/TupleLiteralNode.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
2929
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.NoGeneralizationNode;
30+
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
3031
import com.oracle.graal.python.nodes.PNode;
3132
import com.oracle.graal.python.nodes.expression.ExpressionNode;
3233
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
@@ -68,7 +69,7 @@ public Object execute(VirtualFrame frame) {
6869
}
6970

7071
@ExplodeLoop
71-
private Object expandingTuple(VirtualFrame frame) {
72+
private PTuple expandingTuple(VirtualFrame frame) {
7273
// we will usually have more than 'values.length' elements
7374
SequenceStorage storage = new ObjectSequenceStorage(values.length);
7475
for (ExpressionNode n : values) {
@@ -84,7 +85,7 @@ private Object expandingTuple(VirtualFrame frame) {
8485
}
8586

8687
@ExplodeLoop
87-
private Object directTuple(VirtualFrame frame) {
88+
private PTuple directTuple(VirtualFrame frame) {
8889
final Object[] elements = new Object[values.length];
8990
for (int i = 0; i < values.length; i++) {
9091
elements[i] = values[i].execute(frame);

0 commit comments

Comments
 (0)