Skip to content

Commit 3c55bad

Browse files
committed
Fix: cache length of sequence for destructuring assignment.
1 parent 2ab52b6 commit 3c55bad

File tree

2 files changed

+141
-54
lines changed

2 files changed

+141
-54
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_assign.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -84,7 +84,24 @@ def test_destructuring():
8484
assert a == 'а' and b == 'б' and c == 'в'
8585
# TODO not supported yet
8686
# a, b, c = "\U0001d49c\U0001d49e\U0001d4b5"
87-
# assert a == '𝒜' and b == '𝒞' and c == '𝒵'
87+
# assert a == '𝒜' and b == '𝒞' and c == '𝒵
88+
89+
# starred desctructuring assignment
90+
a, b, *s, c, d = tuple(range(4))
91+
assert a == 0 and b == 1 and c == 2 and d == 3
92+
93+
a, b, *s, c, d = tuple(range(10))
94+
assert a == 0 and b == 1 and s == [2, 3, 4, 5, 6, 7] and c == 8 and d == 9
95+
96+
c = -1
97+
d = -1
98+
a, b, *s = tuple(range(10))
99+
assert a == 0 and b == 1 and s == [2, 3, 4, 5, 6, 7, 8, 9] and c == -1 and d == -1
100+
101+
a = -1
102+
b = -1
103+
*s, c, d = tuple(range(10))
104+
assert a == -1 and b == -1 and s == [0, 1, 2, 3, 4, 5, 6, 7] and c == 8 and d == 9
88105

89106

90107
def test_assigning_hidden_keys():

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/frame/DestructuringAssignmentNode.java

Lines changed: 122 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import com.oracle.graal.python.nodes.PRaiseNode;
3636
import com.oracle.graal.python.nodes.builtins.TupleNodes;
3737
import com.oracle.graal.python.nodes.expression.ExpressionNode;
38+
import com.oracle.graal.python.nodes.frame.DestructuringAssignmentNodeGen.WriteSequenceStorageStarredNodeGen;
3839
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
3940
import com.oracle.graal.python.nodes.statement.StatementNode;
4041
import com.oracle.graal.python.runtime.PythonContext;
@@ -46,14 +47,15 @@
4647
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
4748
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
4849
import com.oracle.truffle.api.dsl.Cached;
50+
import com.oracle.truffle.api.dsl.Cached.Shared;
4951
import com.oracle.truffle.api.dsl.Specialization;
5052
import com.oracle.truffle.api.frame.VirtualFrame;
5153
import com.oracle.truffle.api.nodes.ExplodeLoop;
54+
import com.oracle.truffle.api.nodes.Node;
5255

5356
public abstract class DestructuringAssignmentNode extends StatementNode implements WriteNode {
5457
/* Lazily initialized helpers, also acting as branch profiles */
5558
@Child private PRaiseNode raiseNode;
56-
@Child private PythonObjectFactory factory;
5759
@CompilationFinal private ContextReference<PythonContext> contextRef;
5860

5961
/* Syntactic children */
@@ -102,7 +104,7 @@ protected static boolean isBuiltinTuple(Object object, IsBuiltinClassProfile pro
102104
}
103105

104106
@Specialization(guards = {"isBuiltinList(rhsVal, isBuiltinClass)", "starredIndex < 0"})
105-
public void writeList(VirtualFrame frame, PList rhsVal,
107+
void writeList(VirtualFrame frame, PList rhsVal,
106108
@Cached SequenceStorageNodes.LenNode lenNode,
107109
@Cached SequenceStorageNodes.GetItemNode getItemNode,
108110
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
@@ -112,7 +114,7 @@ public void writeList(VirtualFrame frame, PList rhsVal,
112114
}
113115

114116
@Specialization(guards = {"isBuiltinTuple(rhsVal, isBuiltinClass)", "starredIndex < 0"})
115-
public void writeTuple(VirtualFrame frame, PTuple rhsVal,
117+
void writeTuple(VirtualFrame frame, PTuple rhsVal,
116118
@Cached SequenceStorageNodes.LenNode lenNode,
117119
@Cached SequenceStorageNodes.GetItemNode getItemNode,
118120
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
@@ -137,51 +139,24 @@ private void writeSequenceStorage(VirtualFrame frame, SequenceStorage sequenceSt
137139
}
138140
}
139141

140-
@Specialization(guards = {"isBuiltinList(rhsVal, isBuiltinClass)", "starredIndex >= 0"})
141-
public void writeListStarred(VirtualFrame frame, PList rhsVal,
142-
@Cached SequenceStorageNodes.LenNode lenNode,
143-
@Cached SequenceStorageNodes.GetItemNode getItemNode,
142+
@Specialization(guards = {"isBuiltinList(rhsVal, isBuiltinClass)", "starredIndex >= 0"}, limit = "1")
143+
void writeListStarred(VirtualFrame frame, PList rhsVal,
144+
@Shared("writeStarred") @Cached WriteSequenceStorageStarredNode writeSequenceStorageStarredNode,
144145
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
145146
SequenceStorage sequenceStorage = rhsVal.getSequenceStorage();
146-
writeSequenceStorageStarred(frame, sequenceStorage, lenNode, getItemNode);
147+
writeSequenceStorageStarredNode.execute(frame, sequenceStorage, slots, starredIndex);
147148
performAssignments(frame);
148149
}
149150

150-
@Specialization(guards = {"isBuiltinTuple(rhsVal, isBuiltinClass)", "starredIndex >= 0"})
151-
public void writeTupleStarred(VirtualFrame frame, PTuple rhsVal,
152-
@Cached SequenceStorageNodes.LenNode lenNode,
153-
@Cached SequenceStorageNodes.GetItemNode getItemNode,
151+
@Specialization(guards = {"isBuiltinTuple(rhsVal, isBuiltinClass)", "starredIndex >= 0"}, limit = "1")
152+
void writeTupleStarred(VirtualFrame frame, PTuple rhsVal,
153+
@Shared("writeStarred") @Cached WriteSequenceStorageStarredNode writeSequenceStorageStarredNode,
154154
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
155155
SequenceStorage sequenceStorage = rhsVal.getSequenceStorage();
156-
writeSequenceStorageStarred(frame, sequenceStorage, lenNode, getItemNode);
156+
writeSequenceStorageStarredNode.execute(frame, sequenceStorage, slots, starredIndex);
157157
performAssignments(frame);
158158
}
159159

160-
@ExplodeLoop
161-
private void writeSequenceStorageStarred(VirtualFrame frame, SequenceStorage sequenceStorage, SequenceStorageNodes.LenNode lenNode, SequenceStorageNodes.GetItemNode getItemNode) {
162-
int len = lenNode.execute(sequenceStorage);
163-
if (len < slots.length - 1) {
164-
throw ensureRaiseNode().raise(ValueError, "not enough values to unpack (expected %d, got %d)", slots.length, len);
165-
} else {
166-
for (int i = 0; i < starredIndex; i++) {
167-
Object value = getItemNode.execute(frame, sequenceStorage, i);
168-
slots[i].doWrite(frame, value);
169-
}
170-
final int starredLength = len - (slots.length - 1);
171-
Object[] array = new Object[starredLength];
172-
CompilerAsserts.partialEvaluationConstant(starredLength);
173-
int pos = starredIndex;
174-
for (int i = 0; i < starredLength; i++) {
175-
array[i] = getItemNode.execute(frame, sequenceStorage, pos++);
176-
}
177-
slots[starredIndex].doWrite(frame, factory().createList(array));
178-
for (int i = starredIndex + 1; i < slots.length; i++) {
179-
Object value = getItemNode.execute(frame, sequenceStorage, pos++);
180-
slots[i].doWrite(frame, value);
181-
}
182-
}
183-
}
184-
185160
@ExplodeLoop
186161
private void performAssignments(VirtualFrame frame) {
187162
for (int i = 0; i < assignments.length; i++) {
@@ -190,28 +165,25 @@ private void performAssignments(VirtualFrame frame) {
190165
}
191166

192167
@Specialization(guards = {"!isBuiltinTuple(iterable, tupleProfile)", "!isBuiltinList(iterable, listProfile)", "starredIndex < 0"})
193-
public void writeIterable(VirtualFrame frame, Object iterable,
168+
void writeIterable(VirtualFrame frame, Object iterable,
194169
@Cached TupleNodes.ConstructTupleNode constructTupleNode,
195170
@Cached SequenceStorageNodes.LenNode lenNode,
196171
@Cached SequenceStorageNodes.GetItemNode getItemNode,
197172
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile tupleProfile,
198173
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile listProfile) {
199174
PTuple rhsValue = constructTupleNode.execute(frame, iterable);
200-
SequenceStorage sequenceStorage = rhsValue.getSequenceStorage();
201-
writeSequenceStorage(frame, sequenceStorage, lenNode, getItemNode);
175+
writeSequenceStorage(frame, rhsValue.getSequenceStorage(), lenNode, getItemNode);
202176
performAssignments(frame);
203177
}
204178

205-
@Specialization(guards = {"!isBuiltinTuple(iterable, tupleProfile)", "!isBuiltinList(iterable, listProfile)", "starredIndex >= 0"})
206-
public void writeIterableStarred(VirtualFrame frame, Object iterable,
179+
@Specialization(guards = {"!isBuiltinTuple(iterable, tupleProfile)", "!isBuiltinList(iterable, listProfile)", "starredIndex >= 0"}, limit = "1")
180+
void writeIterableStarred(VirtualFrame frame, Object iterable,
207181
@Cached TupleNodes.ConstructTupleNode constructTupleNode,
208-
@Cached SequenceStorageNodes.LenNode lenNode,
209-
@Cached SequenceStorageNodes.GetItemNode getItemNode,
182+
@Shared("writeStarred") @Cached WriteSequenceStorageStarredNode writeSequenceStorageStarredNode,
210183
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile tupleProfile,
211184
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile listProfile) {
212185
PTuple rhsValue = constructTupleNode.execute(frame, iterable);
213-
SequenceStorage sequenceStorage = rhsValue.getSequenceStorage();
214-
writeSequenceStorageStarred(frame, sequenceStorage, lenNode, getItemNode);
186+
writeSequenceStorageStarredNode.execute(frame, rhsValue.getSequenceStorage(), slots, starredIndex);
215187
performAssignments(frame);
216188
}
217189

@@ -231,11 +203,109 @@ private PRaiseNode ensureRaiseNode() {
231203
return raiseNode;
232204
}
233205

234-
private PythonObjectFactory factory() {
235-
if (factory == null) {
236-
CompilerDirectives.transferToInterpreterAndInvalidate();
237-
factory = insert(PythonObjectFactory.create());
206+
/**
207+
* This node performs assignments in form of
208+
* {@code pre_0, pre_1, ..., pre_i, *starred, post_0, post_1, ..., post_k = sequenceObject}.
209+
* Note that the parameters {@code slots} and {@code starredIndex} must be PE constant!
210+
*/
211+
abstract static class WriteSequenceStorageStarredNode extends Node {
212+
213+
@Child private PythonObjectFactory factory;
214+
@Child private PRaiseNode raiseNode;
215+
216+
abstract void execute(VirtualFrame frame, SequenceStorage storage, WriteNode[] slots, int starredIndex);
217+
218+
@Specialization(guards = {"getLength(lenNode, storage) == cachedLength"}, limit = "1")
219+
void doExploded(VirtualFrame frame, SequenceStorage storage, WriteNode[] slots, int starredIndex,
220+
@Shared("getItemNode") @Cached SequenceStorageNodes.GetItemNode getItemNode,
221+
@Shared("lenNode") @Cached @SuppressWarnings("unused") SequenceStorageNodes.LenNode lenNode,
222+
@Cached("getLength(lenNode, storage)") int cachedLength) {
223+
224+
CompilerAsserts.partialEvaluationConstant(slots);
225+
CompilerAsserts.partialEvaluationConstant(starredIndex);
226+
if (cachedLength < slots.length - 1) {
227+
throw ensureRaiseNode().raise(ValueError, "not enough values to unpack (expected %d, got %d)", slots.length, cachedLength);
228+
} else {
229+
writeSlots(frame, storage, getItemNode, slots, starredIndex);
230+
final int starredLength = cachedLength - (slots.length - 1);
231+
CompilerAsserts.partialEvaluationConstant(starredLength);
232+
Object[] array = consumeStarredItems(frame, storage, starredLength, getItemNode, starredIndex);
233+
assert starredLength == array.length;
234+
slots[starredIndex].doWrite(frame, factory().createList(array));
235+
performAssignmentsAfterStar(frame, storage, starredIndex + starredLength, getItemNode, slots, starredIndex);
236+
}
237+
}
238+
239+
@Specialization(replaces = "doExploded")
240+
void doGeneric(VirtualFrame frame, SequenceStorage storage, WriteNode[] slots, int starredIndex,
241+
@Shared("getItemNode") @Cached SequenceStorageNodes.GetItemNode getItemNode,
242+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode) {
243+
CompilerAsserts.partialEvaluationConstant(slots);
244+
CompilerAsserts.partialEvaluationConstant(starredIndex);
245+
int len = lenNode.execute(storage);
246+
if (len < slots.length - 1) {
247+
throw ensureRaiseNode().raise(ValueError, "not enough values to unpack (expected %d, got %d)", slots.length, len);
248+
} else {
249+
writeSlots(frame, storage, getItemNode, slots, starredIndex);
250+
final int starredLength = len - (slots.length - 1);
251+
Object[] array = new Object[starredLength];
252+
int pos = starredIndex;
253+
for (int i = 0; i < starredLength; i++) {
254+
array[i] = getItemNode.execute(frame, storage, pos++);
255+
}
256+
slots[starredIndex].doWrite(frame, factory().createList(array));
257+
for (int i = starredIndex + 1; i < slots.length; i++) {
258+
Object value = getItemNode.execute(frame, storage, pos++);
259+
slots[i].doWrite(frame, value);
260+
}
261+
}
262+
}
263+
264+
@ExplodeLoop
265+
private void writeSlots(VirtualFrame frame, SequenceStorage storage, SequenceStorageNodes.GetItemNode getItemNode, WriteNode[] slots, int starredIndex) {
266+
for (int i = 0; i < starredIndex; i++) {
267+
Object value = getItemNode.execute(frame, storage, i);
268+
slots[i].doWrite(frame, value);
269+
}
270+
}
271+
272+
@ExplodeLoop
273+
private Object[] consumeStarredItems(VirtualFrame frame, SequenceStorage sequenceStorage, int starredLength, SequenceStorageNodes.GetItemNode getItemNode, int starredIndex) {
274+
Object[] array = new Object[starredLength];
275+
CompilerAsserts.partialEvaluationConstant(starredLength);
276+
for (int i = 0; i < starredLength; i++) {
277+
array[i] = getItemNode.execute(frame, sequenceStorage, starredIndex + i);
278+
}
279+
return array;
280+
}
281+
282+
@ExplodeLoop
283+
private void performAssignmentsAfterStar(VirtualFrame frame, SequenceStorage sequenceStorage, int startPos, SequenceStorageNodes.GetItemNode getItemNode, WriteNode[] slots, int starredIndex) {
284+
for (int i = starredIndex + 1, pos = startPos; i < slots.length; i++, pos++) {
285+
Object value = getItemNode.execute(frame, sequenceStorage, pos);
286+
slots[i].doWrite(frame, value);
287+
}
288+
}
289+
290+
static int getLength(SequenceStorageNodes.LenNode lenNode, SequenceStorage storage) {
291+
return lenNode.execute(storage);
238292
}
239-
return factory;
293+
294+
private PythonObjectFactory factory() {
295+
if (factory == null) {
296+
CompilerDirectives.transferToInterpreterAndInvalidate();
297+
factory = insert(PythonObjectFactory.create());
298+
}
299+
return factory;
300+
}
301+
302+
private PRaiseNode ensureRaiseNode() {
303+
if (raiseNode == null) {
304+
CompilerDirectives.transferToInterpreterAndInvalidate();
305+
raiseNode = insert(PRaiseNode.create());
306+
}
307+
return raiseNode;
308+
}
309+
240310
}
241311
}

0 commit comments

Comments
 (0)