Skip to content

Commit fad312f

Browse files
committed
[GR-20244] Cache length of sequence for destructuring assignment.
PullRequest: graalpython/772
2 parents a7dc545 + a738191 commit fad312f

File tree

2 files changed

+141
-55
lines changed

2 files changed

+141
-55
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_assign.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -84,7 +84,24 @@ def test_destructuring():
8484
assert a == 'а' and b == 'б' and c == 'в'
8585
# TODO not supported yet
8686
# a, b, c = "\U0001d49c\U0001d49e\U0001d4b5"
87-
# assert a == '𝒜' and b == '𝒞' and c == '𝒵'
87+
# assert a == '𝒜' and b == '𝒞' and c == '𝒵
88+
89+
# starred desctructuring assignment
90+
a, b, *s, c, d = tuple(range(4))
91+
assert a == 0 and b == 1 and c == 2 and d == 3
92+
93+
a, b, *s, c, d = tuple(range(10))
94+
assert a == 0 and b == 1 and s == [2, 3, 4, 5, 6, 7] and c == 8 and d == 9
95+
96+
c = -1
97+
d = -1
98+
a, b, *s = tuple(range(10))
99+
assert a == 0 and b == 1 and s == [2, 3, 4, 5, 6, 7, 8, 9] and c == -1 and d == -1
100+
101+
a = -1
102+
b = -1
103+
*s, c, d = tuple(range(10))
104+
assert a == -1 and b == -1 and s == [0, 1, 2, 3, 4, 5, 6, 7] and c == 8 and d == 9
88105

89106

90107
def test_assigning_hidden_keys():

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/frame/DestructuringAssignmentNode.java

Lines changed: 122 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2017, 2019, Oracle and/or its affiliates.
2+
* Copyright (c) 2017, 2020, Oracle and/or its affiliates.
33
* Copyright (c) 2013, Regents of the University of California
44
*
55
* All rights reserved.
@@ -46,14 +46,15 @@
4646
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
4747
import com.oracle.truffle.api.TruffleLanguage.ContextReference;
4848
import com.oracle.truffle.api.dsl.Cached;
49+
import com.oracle.truffle.api.dsl.Cached.Shared;
4950
import com.oracle.truffle.api.dsl.Specialization;
5051
import com.oracle.truffle.api.frame.VirtualFrame;
5152
import com.oracle.truffle.api.nodes.ExplodeLoop;
53+
import com.oracle.truffle.api.nodes.Node;
5254

5355
public abstract class DestructuringAssignmentNode extends StatementNode implements WriteNode {
5456
/* Lazily initialized helpers, also acting as branch profiles */
5557
@Child private PRaiseNode raiseNode;
56-
@Child private PythonObjectFactory factory;
5758
@CompilationFinal private ContextReference<PythonContext> contextRef;
5859

5960
/* Syntactic children */
@@ -102,7 +103,7 @@ protected static boolean isBuiltinTuple(Object object, IsBuiltinClassProfile pro
102103
}
103104

104105
@Specialization(guards = {"isBuiltinList(rhsVal, isBuiltinClass)", "starredIndex < 0"})
105-
public void writeList(VirtualFrame frame, PList rhsVal,
106+
void writeList(VirtualFrame frame, PList rhsVal,
106107
@Cached SequenceStorageNodes.LenNode lenNode,
107108
@Cached SequenceStorageNodes.GetItemNode getItemNode,
108109
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
@@ -112,7 +113,7 @@ public void writeList(VirtualFrame frame, PList rhsVal,
112113
}
113114

114115
@Specialization(guards = {"isBuiltinTuple(rhsVal, isBuiltinClass)", "starredIndex < 0"})
115-
public void writeTuple(VirtualFrame frame, PTuple rhsVal,
116+
void writeTuple(VirtualFrame frame, PTuple rhsVal,
116117
@Cached SequenceStorageNodes.LenNode lenNode,
117118
@Cached SequenceStorageNodes.GetItemNode getItemNode,
118119
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
@@ -137,51 +138,24 @@ private void writeSequenceStorage(VirtualFrame frame, SequenceStorage sequenceSt
137138
}
138139
}
139140

140-
@Specialization(guards = {"isBuiltinList(rhsVal, isBuiltinClass)", "starredIndex >= 0"})
141-
public void writeListStarred(VirtualFrame frame, PList rhsVal,
142-
@Cached SequenceStorageNodes.LenNode lenNode,
143-
@Cached SequenceStorageNodes.GetItemNode getItemNode,
141+
@Specialization(guards = {"isBuiltinList(rhsVal, isBuiltinClass)", "starredIndex >= 0"}, limit = "1")
142+
void writeListStarred(VirtualFrame frame, PList rhsVal,
143+
@Shared("writeStarred") @Cached WriteSequenceStorageStarredNode writeSequenceStorageStarredNode,
144144
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
145145
SequenceStorage sequenceStorage = rhsVal.getSequenceStorage();
146-
writeSequenceStorageStarred(frame, sequenceStorage, lenNode, getItemNode);
146+
writeSequenceStorageStarredNode.execute(frame, sequenceStorage, slots, starredIndex);
147147
performAssignments(frame);
148148
}
149149

150-
@Specialization(guards = {"isBuiltinTuple(rhsVal, isBuiltinClass)", "starredIndex >= 0"})
151-
public void writeTupleStarred(VirtualFrame frame, PTuple rhsVal,
152-
@Cached SequenceStorageNodes.LenNode lenNode,
153-
@Cached SequenceStorageNodes.GetItemNode getItemNode,
150+
@Specialization(guards = {"isBuiltinTuple(rhsVal, isBuiltinClass)", "starredIndex >= 0"}, limit = "1")
151+
void writeTupleStarred(VirtualFrame frame, PTuple rhsVal,
152+
@Shared("writeStarred") @Cached WriteSequenceStorageStarredNode writeSequenceStorageStarredNode,
154153
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile isBuiltinClass) {
155154
SequenceStorage sequenceStorage = rhsVal.getSequenceStorage();
156-
writeSequenceStorageStarred(frame, sequenceStorage, lenNode, getItemNode);
155+
writeSequenceStorageStarredNode.execute(frame, sequenceStorage, slots, starredIndex);
157156
performAssignments(frame);
158157
}
159158

160-
@ExplodeLoop
161-
private void writeSequenceStorageStarred(VirtualFrame frame, SequenceStorage sequenceStorage, SequenceStorageNodes.LenNode lenNode, SequenceStorageNodes.GetItemNode getItemNode) {
162-
int len = lenNode.execute(sequenceStorage);
163-
if (len < slots.length - 1) {
164-
throw ensureRaiseNode().raise(ValueError, "not enough values to unpack (expected %d, got %d)", slots.length, len);
165-
} else {
166-
for (int i = 0; i < starredIndex; i++) {
167-
Object value = getItemNode.execute(frame, sequenceStorage, i);
168-
slots[i].doWrite(frame, value);
169-
}
170-
final int starredLength = len - (slots.length - 1);
171-
Object[] array = new Object[starredLength];
172-
CompilerAsserts.partialEvaluationConstant(starredLength);
173-
int pos = starredIndex;
174-
for (int i = 0; i < starredLength; i++) {
175-
array[i] = getItemNode.execute(frame, sequenceStorage, pos++);
176-
}
177-
slots[starredIndex].doWrite(frame, factory().createList(array));
178-
for (int i = starredIndex + 1; i < slots.length; i++) {
179-
Object value = getItemNode.execute(frame, sequenceStorage, pos++);
180-
slots[i].doWrite(frame, value);
181-
}
182-
}
183-
}
184-
185159
@ExplodeLoop
186160
private void performAssignments(VirtualFrame frame) {
187161
for (int i = 0; i < assignments.length; i++) {
@@ -190,28 +164,25 @@ private void performAssignments(VirtualFrame frame) {
190164
}
191165

192166
@Specialization(guards = {"!isBuiltinTuple(iterable, tupleProfile)", "!isBuiltinList(iterable, listProfile)", "starredIndex < 0"})
193-
public void writeIterable(VirtualFrame frame, Object iterable,
167+
void writeIterable(VirtualFrame frame, Object iterable,
194168
@Cached TupleNodes.ConstructTupleNode constructTupleNode,
195169
@Cached SequenceStorageNodes.LenNode lenNode,
196170
@Cached SequenceStorageNodes.GetItemNode getItemNode,
197171
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile tupleProfile,
198172
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile listProfile) {
199173
PTuple rhsValue = constructTupleNode.execute(frame, iterable);
200-
SequenceStorage sequenceStorage = rhsValue.getSequenceStorage();
201-
writeSequenceStorage(frame, sequenceStorage, lenNode, getItemNode);
174+
writeSequenceStorage(frame, rhsValue.getSequenceStorage(), lenNode, getItemNode);
202175
performAssignments(frame);
203176
}
204177

205-
@Specialization(guards = {"!isBuiltinTuple(iterable, tupleProfile)", "!isBuiltinList(iterable, listProfile)", "starredIndex >= 0"})
206-
public void writeIterableStarred(VirtualFrame frame, Object iterable,
178+
@Specialization(guards = {"!isBuiltinTuple(iterable, tupleProfile)", "!isBuiltinList(iterable, listProfile)", "starredIndex >= 0"}, limit = "1")
179+
void writeIterableStarred(VirtualFrame frame, Object iterable,
207180
@Cached TupleNodes.ConstructTupleNode constructTupleNode,
208-
@Cached SequenceStorageNodes.LenNode lenNode,
209-
@Cached SequenceStorageNodes.GetItemNode getItemNode,
181+
@Shared("writeStarred") @Cached WriteSequenceStorageStarredNode writeSequenceStorageStarredNode,
210182
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile tupleProfile,
211183
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile listProfile) {
212184
PTuple rhsValue = constructTupleNode.execute(frame, iterable);
213-
SequenceStorage sequenceStorage = rhsValue.getSequenceStorage();
214-
writeSequenceStorageStarred(frame, sequenceStorage, lenNode, getItemNode);
185+
writeSequenceStorageStarredNode.execute(frame, rhsValue.getSequenceStorage(), slots, starredIndex);
215186
performAssignments(frame);
216187
}
217188

@@ -231,11 +202,109 @@ private PRaiseNode ensureRaiseNode() {
231202
return raiseNode;
232203
}
233204

234-
private PythonObjectFactory factory() {
235-
if (factory == null) {
236-
CompilerDirectives.transferToInterpreterAndInvalidate();
237-
factory = insert(PythonObjectFactory.create());
205+
/**
206+
* This node performs assignments in form of
207+
* {@code pre_0, pre_1, ..., pre_i, *starred, post_0, post_1, ..., post_k = sequenceObject}.
208+
* Note that the parameters {@code slots} and {@code starredIndex} must be PE constant!
209+
*/
210+
abstract static class WriteSequenceStorageStarredNode extends Node {
211+
212+
@Child private PythonObjectFactory factory;
213+
@Child private PRaiseNode raiseNode;
214+
215+
abstract void execute(VirtualFrame frame, SequenceStorage storage, WriteNode[] slots, int starredIndex);
216+
217+
@Specialization(guards = {"getLength(lenNode, storage) == cachedLength"}, limit = "1")
218+
void doExploded(VirtualFrame frame, SequenceStorage storage, WriteNode[] slots, int starredIndex,
219+
@Shared("getItemNode") @Cached SequenceStorageNodes.GetItemNode getItemNode,
220+
@Shared("lenNode") @Cached @SuppressWarnings("unused") SequenceStorageNodes.LenNode lenNode,
221+
@Cached("getLength(lenNode, storage)") int cachedLength) {
222+
223+
CompilerAsserts.partialEvaluationConstant(slots);
224+
CompilerAsserts.partialEvaluationConstant(starredIndex);
225+
if (cachedLength < slots.length - 1) {
226+
throw ensureRaiseNode().raise(ValueError, "not enough values to unpack (expected %d, got %d)", slots.length, cachedLength);
227+
} else {
228+
writeSlots(frame, storage, getItemNode, slots, starredIndex);
229+
final int starredLength = cachedLength - (slots.length - 1);
230+
CompilerAsserts.partialEvaluationConstant(starredLength);
231+
Object[] array = consumeStarredItems(frame, storage, starredLength, getItemNode, starredIndex);
232+
assert starredLength == array.length;
233+
slots[starredIndex].doWrite(frame, factory().createList(array));
234+
performAssignmentsAfterStar(frame, storage, starredIndex + starredLength, getItemNode, slots, starredIndex);
235+
}
236+
}
237+
238+
@Specialization(replaces = "doExploded")
239+
void doGeneric(VirtualFrame frame, SequenceStorage storage, WriteNode[] slots, int starredIndex,
240+
@Shared("getItemNode") @Cached SequenceStorageNodes.GetItemNode getItemNode,
241+
@Shared("lenNode") @Cached SequenceStorageNodes.LenNode lenNode) {
242+
CompilerAsserts.partialEvaluationConstant(slots);
243+
CompilerAsserts.partialEvaluationConstant(starredIndex);
244+
int len = lenNode.execute(storage);
245+
if (len < slots.length - 1) {
246+
throw ensureRaiseNode().raise(ValueError, "not enough values to unpack (expected %d, got %d)", slots.length, len);
247+
} else {
248+
writeSlots(frame, storage, getItemNode, slots, starredIndex);
249+
final int starredLength = len - (slots.length - 1);
250+
Object[] array = new Object[starredLength];
251+
int pos = starredIndex;
252+
for (int i = 0; i < starredLength; i++) {
253+
array[i] = getItemNode.execute(frame, storage, pos++);
254+
}
255+
slots[starredIndex].doWrite(frame, factory().createList(array));
256+
for (int i = starredIndex + 1; i < slots.length; i++) {
257+
Object value = getItemNode.execute(frame, storage, pos++);
258+
slots[i].doWrite(frame, value);
259+
}
260+
}
261+
}
262+
263+
@ExplodeLoop
264+
private void writeSlots(VirtualFrame frame, SequenceStorage storage, SequenceStorageNodes.GetItemNode getItemNode, WriteNode[] slots, int starredIndex) {
265+
for (int i = 0; i < starredIndex; i++) {
266+
Object value = getItemNode.execute(frame, storage, i);
267+
slots[i].doWrite(frame, value);
268+
}
269+
}
270+
271+
@ExplodeLoop
272+
private Object[] consumeStarredItems(VirtualFrame frame, SequenceStorage sequenceStorage, int starredLength, SequenceStorageNodes.GetItemNode getItemNode, int starredIndex) {
273+
Object[] array = new Object[starredLength];
274+
CompilerAsserts.partialEvaluationConstant(starredLength);
275+
for (int i = 0; i < starredLength; i++) {
276+
array[i] = getItemNode.execute(frame, sequenceStorage, starredIndex + i);
277+
}
278+
return array;
279+
}
280+
281+
@ExplodeLoop
282+
private void performAssignmentsAfterStar(VirtualFrame frame, SequenceStorage sequenceStorage, int startPos, SequenceStorageNodes.GetItemNode getItemNode, WriteNode[] slots, int starredIndex) {
283+
for (int i = starredIndex + 1, pos = startPos; i < slots.length; i++, pos++) {
284+
Object value = getItemNode.execute(frame, sequenceStorage, pos);
285+
slots[i].doWrite(frame, value);
286+
}
287+
}
288+
289+
static int getLength(SequenceStorageNodes.LenNode lenNode, SequenceStorage storage) {
290+
return lenNode.execute(storage);
238291
}
239-
return factory;
292+
293+
private PythonObjectFactory factory() {
294+
if (factory == null) {
295+
CompilerDirectives.transferToInterpreterAndInvalidate();
296+
factory = insert(PythonObjectFactory.create());
297+
}
298+
return factory;
299+
}
300+
301+
private PRaiseNode ensureRaiseNode() {
302+
if (raiseNode == null) {
303+
CompilerDirectives.transferToInterpreterAndInvalidate();
304+
raiseNode = insert(PRaiseNode.create());
305+
}
306+
return raiseNode;
307+
}
308+
240309
}
241310
}

0 commit comments

Comments
 (0)