Skip to content

Commit bf5b153

Browse files
committed
[GR-10568] early evaluation of generator comprehension iterator expression
PullRequest: graalpython/102
2 parents 59bbe5d + 94dc19f commit bf5b153

File tree

11 files changed

+260
-67
lines changed

11 files changed

+260
-67
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_scope.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,3 +850,24 @@ def register():
850850
register()
851851

852852
assert MyClass.a_counter == 24
853+
854+
855+
def test_generator_scope():
856+
my_obj = [1, 2, 3, 4]
857+
my_obj = (i for i in my_obj for j in y)
858+
y = [1, 2]
859+
860+
assert set(my_obj.gi_code.co_cellvars) == set()
861+
assert set(my_obj.gi_code.co_freevars) == {'y'}
862+
863+
864+
def test_func_scope():
865+
my_obj = [1, 2, 3, 4]
866+
867+
def my_obj():
868+
return [i for i in my_obj for j in y]
869+
870+
y = [1, 2]
871+
872+
assert set(my_obj.__code__.co_cellvars) == set()
873+
assert set(my_obj.__code__.co_freevars) == {'my_obj', 'y'}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/generator/GeneratorBuiltins.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ITER__;
2929
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NEXT__;
30+
import static com.oracle.graal.python.runtime.exception.PythonErrorType.NotImplementedError;
3031
import static com.oracle.graal.python.runtime.exception.PythonErrorType.StopIteration;
3132
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
3233

@@ -45,6 +46,7 @@
4546
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
4647
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
4748
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
49+
import com.oracle.graal.python.runtime.PythonParseResult;
4850
import com.oracle.graal.python.runtime.exception.PException;
4951
import com.oracle.truffle.api.dsl.Cached;
5052
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -173,4 +175,19 @@ Object sendThrow(PGenerator self, @SuppressWarnings("unused") PythonClass typ, P
173175
return resumeGenerator(self);
174176
}
175177
}
178+
179+
@Builtin(name = "gi_code", minNumOfArguments = 1, maxNumOfArguments = 2, isGetter = true, isSetter = true)
180+
@GenerateNodeFactory
181+
public abstract static class GetCodeNode extends PythonBuiltinNode {
182+
@Specialization(guards = {"isNoValue(none)"})
183+
Object getCode(PGenerator self, @SuppressWarnings("unused") PNone none) {
184+
return new PythonParseResult(self.getGeneratorRootNode(), getCore());
185+
}
186+
187+
@SuppressWarnings("unused")
188+
@Specialization
189+
Object setCode(PGenerator self, PythonParseResult code) {
190+
throw raise(NotImplementedError, "setting gi_code");
191+
}
192+
}
176193
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/generator/PGenerator.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import com.oracle.truffle.api.frame.FrameDescriptor;
3737
import com.oracle.truffle.api.frame.FrameSlot;
3838
import com.oracle.truffle.api.frame.MaterializedFrame;
39+
import com.oracle.truffle.api.nodes.RootNode;
3940

4041
public final class PGenerator extends PythonBuiltinObject {
4142

@@ -92,6 +93,10 @@ public RootCallTarget getCallTarget() {
9293
return callTarget;
9394
}
9495

96+
public RootNode getGeneratorRootNode() {
97+
return callTarget.getRootNode();
98+
}
99+
95100
public Object[] getArguments() {
96101
return arguments;
97102
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/PNodeUtil.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,11 @@ public static void clearSourceSections(PNode node) {
8484
}
8585
}
8686

87+
public static <T extends PNode> T replace(PNode oldNode, T node) {
88+
if (oldNode.isStatement()) {
89+
node.markAsStatement();
90+
}
91+
node.assignSourceSection(oldNode.getSourceSection());
92+
return oldNode.replace(node);
93+
}
8794
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/GeneratorExpressionNode.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import com.oracle.graal.python.builtins.objects.cell.PCell;
2929
import com.oracle.graal.python.builtins.objects.function.PArguments;
30+
import com.oracle.graal.python.nodes.PNode;
3031
import com.oracle.graal.python.parser.DefinitionCellSlots;
3132
import com.oracle.graal.python.parser.ExecutionCellSlots;
3233
import com.oracle.truffle.api.CompilerAsserts;
@@ -50,12 +51,15 @@ public final class GeneratorExpressionNode extends ExpressionDefinitionNode {
5051
@CompilationFinal private FrameDescriptor enclosingFrameDescriptor;
5152
@CompilationFinal private boolean isEnclosingFrameGenerator;
5253
@CompilationFinal private boolean isOptimized;
54+
@Child private PNode getIterator;
5355

54-
public GeneratorExpressionNode(String name, RootCallTarget callTarget, FrameDescriptor descriptor, DefinitionCellSlots definitionCellSlots, ExecutionCellSlots executionCellSlots,
56+
public GeneratorExpressionNode(String name, RootCallTarget callTarget, PNode getIterator, FrameDescriptor descriptor, DefinitionCellSlots definitionCellSlots,
57+
ExecutionCellSlots executionCellSlots,
5558
int numOfActiveFlags, int numOfGeneratorBlockNode, int numOfGeneratorForNode) {
5659
super(definitionCellSlots, executionCellSlots);
5760
this.name = name;
5861
this.callTarget = callTarget;
62+
this.getIterator = getIterator;
5963
this.frameDescriptor = descriptor;
6064
this.numOfActiveFlags = numOfActiveFlags;
6165
this.numOfGeneratorBlockNode = numOfGeneratorBlockNode;
@@ -117,7 +121,13 @@ public RootNode getFunctionRootNode() {
117121

118122
@Override
119123
public Object execute(VirtualFrame frame) {
120-
Object[] arguments = PArguments.create();
124+
Object[] arguments;
125+
if (getIterator == null) {
126+
arguments = PArguments.create(0);
127+
} else {
128+
arguments = PArguments.create(1);
129+
PArguments.setArgument(arguments, 0, getIterator.execute(frame));
130+
}
121131
PArguments.setGlobals(arguments, PArguments.getGlobals(frame));
122132

123133
PCell[] closure;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/GeneratorTranslator.java

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,21 @@
2525
*/
2626
package com.oracle.graal.python.parser;
2727

28+
import static com.oracle.graal.python.nodes.PNodeUtil.replace;
29+
2830
import java.util.ArrayList;
2931
import java.util.List;
3032

3133
import com.oracle.graal.python.nodes.EmptyNode;
3234
import com.oracle.graal.python.nodes.PNode;
3335
import com.oracle.graal.python.nodes.PNodeUtil;
36+
import com.oracle.graal.python.nodes.argument.ReadIndexedArgumentNode;
3437
import com.oracle.graal.python.nodes.control.BlockNode;
3538
import com.oracle.graal.python.nodes.control.BreakNode;
3639
import com.oracle.graal.python.nodes.control.BreakTargetNode;
3740
import com.oracle.graal.python.nodes.control.ContinueNode;
3841
import com.oracle.graal.python.nodes.control.ContinueTargetNode;
3942
import com.oracle.graal.python.nodes.control.ForNode;
40-
import com.oracle.graal.python.nodes.control.GetIteratorNode;
4143
import com.oracle.graal.python.nodes.control.IfNode;
4244
import com.oracle.graal.python.nodes.control.LoopNode;
4345
import com.oracle.graal.python.nodes.control.ReturnTargetNode;
@@ -79,17 +81,12 @@ public class GeneratorTranslator {
7981
private int numOfGeneratorBlockNode;
8082
private int numOfGeneratorForNode;
8183
private boolean needToHandleComplicatedYieldExpression;
84+
private PNode getOuterMostLoopIterator;
85+
private final boolean inGeneratorExpression;
8286

83-
public GeneratorTranslator(FunctionRootNode root) {
87+
public GeneratorTranslator(FunctionRootNode root, boolean inGeneratorExpression) {
8488
this.root = root;
85-
}
86-
87-
private static <T extends PNode> T replace(PNode oldNode, T node) {
88-
if (oldNode.isStatement()) {
89-
node.markAsStatement();
90-
}
91-
node.assignSourceSection(oldNode.getSourceSection());
92-
return oldNode.replace(node);
89+
this.inGeneratorExpression = inGeneratorExpression;
9390
}
9491

9592
public RootCallTarget translate() {
@@ -105,6 +102,11 @@ public RootCallTarget translate() {
105102
/**
106103
* Redirect local variable accesses to materialized persistent frame.
107104
*/
105+
ForNode outerMostLoop = NodeUtil.findFirstNodeInstance(root, ForNode.class);
106+
if (outerMostLoop != null && inGeneratorExpression) {
107+
replaceOuterMostForNode(outerMostLoop);
108+
}
109+
108110
for (WriteLocalVariableNode write : NodeUtil.findAllNodeInstances(root, WriteLocalVariableNode.class)) {
109111
replace(write, WriteGeneratorFrameVariableNode.create(write.getSlot(), write.getRhs()));
110112
}
@@ -307,6 +309,20 @@ private void splitArgumentLoads(ReturnTargetNode returnTarget) {
307309
}
308310
}
309311

312+
private void replaceForNode(ForNode forNode) {
313+
WriteNode target = (WriteNode) forNode.getTarget();
314+
PNode getIter = forNode.getIterator();
315+
replace(forNode, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
316+
}
317+
318+
private void replaceOuterMostForNode(ForNode forNode) {
319+
WriteNode target = (WriteNode) forNode.getTarget();
320+
PNode getIter = forNode.getIterator();
321+
getOuterMostLoopIterator = getIter;
322+
getIter = ReadIndexedArgumentNode.create(0);
323+
replace(forNode, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
324+
}
325+
310326
private void replaceControl(PNode node, YieldNode yield) {
311327
/**
312328
* Has it been replaced already?
@@ -328,11 +344,11 @@ private void replaceControl(PNode node, YieldNode yield) {
328344
int ifFlag = nextActiveFlagSlot();
329345
int elseFlag = nextActiveFlagSlot();
330346
replace(node, GeneratorIfNode.create(ifNode.getCondition(), ifNode.getThen(), ifNode.getElse(), ifFlag, elseFlag));
347+
331348
} else if (node instanceof ForNode) {
332349
ForNode forNode = (ForNode) node;
333-
WriteNode target = (WriteNode) forNode.getTarget();
334-
GetIteratorNode getIter = (GetIteratorNode) forNode.getIterator();
335-
replace(node, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
350+
replaceForNode(forNode);
351+
336352
} else if (node instanceof BlockNode) {
337353
BlockNode block = (BlockNode) node;
338354
int slotOfBlockIndex = nextGeneratorBlockIndexSlot();
@@ -342,16 +358,19 @@ private void replaceControl(PNode node, YieldNode yield) {
342358
}
343359

344360
replace(node, new GeneratorBlockNode(block.getStatements(), slotOfBlockIndex));
361+
345362
} else if (node instanceof TryExceptNode) {
346363
TryExceptNode tryExceptNode = (TryExceptNode) node;
347364
int exceptFlag = nextActiveFlagSlot();
348365
int elseFlag = nextActiveFlagSlot();
349366
int exceptIndex = nextGeneratorBlockIndexSlot();
350367
replace(node, new GeneratorTryExceptNode(tryExceptNode.getBody(), tryExceptNode.getExceptNodes(), tryExceptNode.getOrelse(), exceptFlag, elseFlag, exceptIndex));
368+
351369
} else if (node instanceof TryFinallyNode) {
352370
TryFinallyNode tryFinally = (TryFinallyNode) node;
353371
int finallyFlag = nextActiveFlagSlot();
354372
replace(node, new GeneratorTryFinallyNode(tryFinally.getBody(), tryFinally.getFinalbody(), finallyFlag));
373+
355374
} else if (node instanceof StatementNode) {
356375
// do nothing for now
357376
} else {
@@ -420,4 +439,7 @@ public int getNumOfGeneratorForNode() {
420439
return numOfGeneratorForNode;
421440
}
422441

442+
public PNode getGetOuterMostLoopIterator() {
443+
return getOuterMostLoopIterator;
444+
}
423445
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/PythonBaseTreeTranslator.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,10 @@ private GeneratorExpressionNode createGeneratorExpressionDefinition(PNode body,
459459
FrameDescriptor fd = environment.getCurrentFrame();
460460
String generatorName = "generator_exp:" + lineNum;
461461
FunctionRootNode funcRoot = factory.createFunctionRoot(body.getSourceSection(), generatorName, true, fd, body, environment.getExecutionCellSlots());
462-
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
463-
return new GeneratorExpressionNode(generatorName, gtran.translate(), fd, environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
462+
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot, true);
463+
RootCallTarget callTarget = gtran.translate();
464+
PNode loopIterator = gtran.getGetOuterMostLoopIterator();
465+
return new GeneratorExpressionNode(generatorName, callTarget, loopIterator, fd, environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
464466
gtran.getNumOfActiveFlags(),
465467
gtran.getNumOfGeneratorBlockNode(),
466468
gtran.getNumOfGeneratorForNode());
@@ -1414,8 +1416,9 @@ public Object visitFuncdef(Python3Parser.FuncdefContext ctx) {
14141416
*/
14151417
PNode funcDef;
14161418
if (environment.isInGeneratorScope()) {
1417-
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
1418-
funcDef = GeneratorFunctionDefinitionNode.create(funcName, enclosingClassName, core, arity, defaults, gtran.translate(), fd,
1419+
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot, false);
1420+
RootCallTarget callTarget = gtran.translate();
1421+
funcDef = GeneratorFunctionDefinitionNode.create(funcName, enclosingClassName, core, arity, defaults, callTarget, fd,
14191422
environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
14201423
gtran.getNumOfActiveFlags(), gtran.getNumOfGeneratorBlockNode(), gtran.getNumOfGeneratorForNode());
14211424
} else {
@@ -1567,8 +1570,9 @@ public Object visitLambdef(Python3Parser.LambdefContext ctx) {
15671570
*/
15681571
PNode funcDef;
15691572
if (environment.isInGeneratorScope()) {
1570-
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
1571-
funcDef = GeneratorFunctionDefinitionNode.create(funcname, null, core, arity, defaults, gtran.translate(), fd,
1573+
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot, false);
1574+
RootCallTarget callTarget = gtran.translate();
1575+
funcDef = GeneratorFunctionDefinitionNode.create(funcname, null, core, arity, defaults, callTarget, fd,
15721576
environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
15731577
gtran.getNumOfActiveFlags(), gtran.getNumOfGeneratorBlockNode(), gtran.getNumOfGeneratorForNode());
15741578
} else {
@@ -1740,10 +1744,21 @@ public Object visitTestlist_comp(Python3Parser.Testlist_compContext ctx) {
17401744
}
17411745

17421746
private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield) {
1747+
return createGeneratorExpression(comp_for, yield, true);
1748+
}
1749+
1750+
private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield, boolean iteratorInParentScope) {
17431751
// TODO: async
1752+
if (iteratorInParentScope) {
1753+
environment.pushCurentScope();
1754+
}
1755+
PNode iterator = asBlockOrPNode(comp_for.or_test().accept(this));
1756+
if (iteratorInParentScope) {
1757+
environment.popCurrentScope();
1758+
}
1759+
17441760
PNode targets = assigns.translate(comp_for.exprlist());
17451761
PNode myBody = yield;
1746-
PNode iterator = asBlockOrPNode(comp_for.or_test().accept(this));
17471762
PNode condition = null;
17481763
Python3Parser.Comp_iterContext comp_iter = comp_for.comp_iter();
17491764
while (comp_iter != null && comp_iter.comp_if() != null) {
@@ -1756,7 +1771,7 @@ private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for,
17561771
comp_iter = comp_iter.comp_if().comp_iter();
17571772
}
17581773
if (comp_iter != null && comp_iter.comp_for() != null) {
1759-
myBody = createGeneratorExpression(comp_iter.comp_for(), yield);
1774+
myBody = createGeneratorExpression(comp_iter.comp_for(), yield, false);
17601775
}
17611776
if (condition != null) {
17621777
myBody = factory.createIf(factory.createYesNode(condition), myBody, EmptyNode.create());

0 commit comments

Comments
 (0)