Skip to content

Commit 3368b97

Browse files
committed
translation environment reset tracked scopes loop counters (for comprehension expressions)
- GeneratorTranslator: single point of entry for handling the outer most loop iterator expression - generator translator outer most loop handling case only for comprehension expressions - fix PythonBaseTreeTranslator issue of placing the outermost loop iterator expression in the parent scope
1 parent 243a47c commit 3368b97

File tree

7 files changed

+91
-43
lines changed

7 files changed

+91
-43
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/PNodeUtil.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,11 @@ public static void clearSourceSections(PNode node) {
8484
}
8585
}
8686

87+
public static <T extends PNode> T replace(PNode oldNode, T node) {
88+
if (oldNode.isStatement()) {
89+
node.markAsStatement();
90+
}
91+
node.assignSourceSection(oldNode.getSourceSection());
92+
return oldNode.replace(node);
93+
}
8794
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/GeneratorExpressionNode.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import com.oracle.truffle.api.frame.FrameDescriptor;
3838
import com.oracle.truffle.api.frame.VirtualFrame;
3939
import com.oracle.truffle.api.nodes.RootNode;
40+
import com.oracle.truffle.api.profiles.ConditionProfile;
4041

4142
public final class GeneratorExpressionNode extends ExpressionDefinitionNode {
4243

@@ -53,6 +54,8 @@ public final class GeneratorExpressionNode extends ExpressionDefinitionNode {
5354
@CompilationFinal private boolean isOptimized;
5455
@Child private PNode getIterator;
5556

57+
@CompilationFinal private ConditionProfile iteratorProfile = ConditionProfile.createBinaryProfile();
58+
5659
public GeneratorExpressionNode(String name, RootCallTarget callTarget, PNode getIterator, FrameDescriptor descriptor, DefinitionCellSlots definitionCellSlots,
5760
ExecutionCellSlots executionCellSlots,
5861
int numOfActiveFlags, int numOfGeneratorBlockNode, int numOfGeneratorForNode) {
@@ -121,8 +124,13 @@ public RootNode getFunctionRootNode() {
121124

122125
@Override
123126
public Object execute(VirtualFrame frame) {
124-
Object[] arguments = PArguments.create(1);
125-
PArguments.setArgument(arguments, 0, getIterator.execute(frame));
127+
Object[] arguments;
128+
if (iteratorProfile.profile(getIterator == null)) {
129+
arguments = PArguments.create(0);
130+
} else {
131+
arguments = PArguments.create(1);
132+
PArguments.setArgument(arguments, 0, getIterator.execute(frame));
133+
}
126134
PArguments.setGlobals(arguments, PArguments.getGlobals(frame));
127135

128136
PCell[] closure;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/GeneratorTranslator.java

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
*/
2626
package com.oracle.graal.python.parser;
2727

28+
import static com.oracle.graal.python.nodes.PNodeUtil.replace;
29+
2830
import java.util.ArrayList;
2931
import java.util.List;
3032

@@ -79,18 +81,12 @@ public class GeneratorTranslator {
7981
private int numOfGeneratorBlockNode;
8082
private int numOfGeneratorForNode;
8183
private boolean needToHandleComplicatedYieldExpression;
82-
private boolean replacedOuterMostLoopIterator = false;
84+
private PNode getOuterMostLoopIterator;
85+
private final boolean inGeneratorExpression;
8386

84-
public GeneratorTranslator(FunctionRootNode root) {
87+
public GeneratorTranslator(FunctionRootNode root, boolean inGeneratorExpression) {
8588
this.root = root;
86-
}
87-
88-
private static <T extends PNode> T replace(PNode oldNode, T node) {
89-
if (oldNode.isStatement()) {
90-
node.markAsStatement();
91-
}
92-
node.assignSourceSection(oldNode.getSourceSection());
93-
return oldNode.replace(node);
89+
this.inGeneratorExpression = inGeneratorExpression;
9490
}
9591

9692
public RootCallTarget translate() {
@@ -106,6 +102,11 @@ public RootCallTarget translate() {
106102
/**
107103
* Redirect local variable accesses to materialized persistent frame.
108104
*/
105+
ForNode outerMostLoop = NodeUtil.findFirstNodeInstance(root, ForNode.class);
106+
if (outerMostLoop != null && inGeneratorExpression) {
107+
replaceOuterMostForNode(outerMostLoop);
108+
}
109+
109110
for (WriteLocalVariableNode write : NodeUtil.findAllNodeInstances(root, WriteLocalVariableNode.class)) {
110111
replace(write, WriteGeneratorFrameVariableNode.create(write.getSlot(), write.getRhs()));
111112
}
@@ -308,6 +309,20 @@ private void splitArgumentLoads(ReturnTargetNode returnTarget) {
308309
}
309310
}
310311

312+
private void replaceForNode(ForNode forNode) {
313+
WriteNode target = (WriteNode) forNode.getTarget();
314+
PNode getIter = forNode.getIterator();
315+
replace(forNode, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
316+
}
317+
318+
private void replaceOuterMostForNode(ForNode forNode) {
319+
WriteNode target = (WriteNode) forNode.getTarget();
320+
PNode getIter = forNode.getIterator();
321+
getOuterMostLoopIterator = getIter;
322+
getIter = ReadIndexedArgumentNode.create(0);
323+
replace(forNode, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
324+
}
325+
311326
private void replaceControl(PNode node, YieldNode yield) {
312327
/**
313328
* Has it been replaced already?
@@ -329,15 +344,11 @@ private void replaceControl(PNode node, YieldNode yield) {
329344
int ifFlag = nextActiveFlagSlot();
330345
int elseFlag = nextActiveFlagSlot();
331346
replace(node, GeneratorIfNode.create(ifNode.getCondition(), ifNode.getThen(), ifNode.getElse(), ifFlag, elseFlag));
347+
332348
} else if (node instanceof ForNode) {
333349
ForNode forNode = (ForNode) node;
334-
WriteNode target = (WriteNode) forNode.getTarget();
335-
PNode getIter = forNode.getIterator();
336-
if (!replacedOuterMostLoopIterator) {
337-
replacedOuterMostLoopIterator = true;
338-
getIter = ReadIndexedArgumentNode.create(0);
339-
}
340-
replace(node, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
350+
replaceForNode(forNode);
351+
341352
} else if (node instanceof BlockNode) {
342353
BlockNode block = (BlockNode) node;
343354
int slotOfBlockIndex = nextGeneratorBlockIndexSlot();
@@ -347,16 +358,19 @@ private void replaceControl(PNode node, YieldNode yield) {
347358
}
348359

349360
replace(node, new GeneratorBlockNode(block.getStatements(), slotOfBlockIndex));
361+
350362
} else if (node instanceof TryExceptNode) {
351363
TryExceptNode tryExceptNode = (TryExceptNode) node;
352364
int exceptFlag = nextActiveFlagSlot();
353365
int elseFlag = nextActiveFlagSlot();
354366
int exceptIndex = nextGeneratorBlockIndexSlot();
355367
replace(node, new GeneratorTryExceptNode(tryExceptNode.getBody(), tryExceptNode.getExceptNodes(), tryExceptNode.getOrelse(), exceptFlag, elseFlag, exceptIndex));
368+
356369
} else if (node instanceof TryFinallyNode) {
357370
TryFinallyNode tryFinally = (TryFinallyNode) node;
358371
int finallyFlag = nextActiveFlagSlot();
359372
replace(node, new GeneratorTryFinallyNode(tryFinally.getBody(), tryFinally.getFinalbody(), finallyFlag));
373+
360374
} else if (node instanceof StatementNode) {
361375
// do nothing for now
362376
} else {
@@ -424,4 +438,8 @@ private int nextGeneratorForNodeSlot() {
424438
public int getNumOfGeneratorForNode() {
425439
return numOfGeneratorForNode;
426440
}
441+
442+
public PNode getGetOuterMostLoopIterator() {
443+
return getOuterMostLoopIterator;
444+
}
427445
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/PythonBaseTreeTranslator.java

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import com.oracle.graal.python.nodes.control.BlockNode;
7171
import com.oracle.graal.python.nodes.control.ForNode;
7272
import com.oracle.graal.python.nodes.control.GetIteratorNode;
73+
import com.oracle.graal.python.nodes.control.LoopNode;
7374
import com.oracle.graal.python.nodes.control.ReturnTargetNode;
7475
import com.oracle.graal.python.nodes.expression.AndNode;
7576
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
@@ -440,13 +441,12 @@ private PNode createComprehensionExpression(ParserRuleContext ctx, Function<Pars
440441
PNode block = getBlock.apply(ctx);
441442
PNode yield = factory.createYield(block, environment.getReturnSlot());
442443
yield.assignSourceSection(block.getSourceSection());
443-
ForNode genFor = createGeneratorExpression(ctx.getChild(Python3Parser.Comp_forContext.class, 0), yield);
444-
SourceSection srcSection = genFor.getSourceSection();
445-
PNode getIterator = genFor.getIterator();
446-
PNode body = new ReturnTargetNode(genFor, factory.createReadLocal(environment.getReturnSlot()));
444+
PNode body = createGeneratorExpression(ctx.getChild(Python3Parser.Comp_forContext.class, 0), yield);
445+
SourceSection srcSection = body.getSourceSection();
446+
body = new ReturnTargetNode(body, factory.createReadLocal(environment.getReturnSlot()));
447447
body.assignSourceSection(srcSection);
448448
int lineNum = ctx.getStart().getLine();
449-
GeneratorExpressionNode genExprDef = createGeneratorExpressionDefinition(body, getIterator, lineNum);
449+
GeneratorExpressionNode genExprDef = createGeneratorExpressionDefinition(body, lineNum);
450450
genExprDef.setEnclosingFrameDescriptor(environment.getEnclosingFrame());
451451
genExprDef.assignSourceSection(srcSection);
452452
return genExprDef;
@@ -455,13 +455,14 @@ private PNode createComprehensionExpression(ParserRuleContext ctx, Function<Pars
455455
}
456456
}
457457

458-
private GeneratorExpressionNode createGeneratorExpressionDefinition(PNode body, PNode getIterator, int lineNum) {
458+
private GeneratorExpressionNode createGeneratorExpressionDefinition(PNode body, int lineNum) {
459459
FrameDescriptor fd = environment.getCurrentFrame();
460460
String generatorName = "generator_exp:" + lineNum;
461461
FunctionRootNode funcRoot = factory.createFunctionRoot(body.getSourceSection(), generatorName, true, fd, body, environment.getExecutionCellSlots());
462-
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
462+
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot, true);
463463
RootCallTarget callTarget = gtran.translate();
464-
return new GeneratorExpressionNode(generatorName, callTarget, getIterator, fd, environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
464+
PNode loopIterator = gtran.getGetOuterMostLoopIterator();
465+
return new GeneratorExpressionNode(generatorName, callTarget, loopIterator, fd, environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
465466
gtran.getNumOfActiveFlags(),
466467
gtran.getNumOfGeneratorBlockNode(),
467468
gtran.getNumOfGeneratorForNode());
@@ -1201,7 +1202,7 @@ private PNode createForNode(PNode target, PNode iter, PNode body, PNode orelse,
12011202
}
12021203
}
12031204

1204-
private ForNode createForInScope(PNode target, PNode iterator, PNode body) {
1205+
private LoopNode createForInScope(PNode target, PNode iterator, PNode body) {
12051206
GetIteratorNode getIterator = factory.createGetIterator(iterator);
12061207
getIterator.assignSourceSection(iterator.getSourceSection());
12071208
return new ForNode(body, target, getIterator);
@@ -1415,8 +1416,9 @@ public Object visitFuncdef(Python3Parser.FuncdefContext ctx) {
14151416
*/
14161417
PNode funcDef;
14171418
if (environment.isInGeneratorScope()) {
1418-
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
1419-
funcDef = GeneratorFunctionDefinitionNode.create(funcName, enclosingClassName, core, arity, defaults, gtran.translate(), fd,
1419+
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot, false);
1420+
RootCallTarget callTarget = gtran.translate();
1421+
funcDef = GeneratorFunctionDefinitionNode.create(funcName, enclosingClassName, core, arity, defaults, callTarget, fd,
14201422
environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
14211423
gtran.getNumOfActiveFlags(), gtran.getNumOfGeneratorBlockNode(), gtran.getNumOfGeneratorForNode());
14221424
} else {
@@ -1568,8 +1570,9 @@ public Object visitLambdef(Python3Parser.LambdefContext ctx) {
15681570
*/
15691571
PNode funcDef;
15701572
if (environment.isInGeneratorScope()) {
1571-
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
1572-
funcDef = GeneratorFunctionDefinitionNode.create(funcname, null, core, arity, defaults, gtran.translate(), fd,
1573+
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot, false);
1574+
RootCallTarget callTarget = gtran.translate();
1575+
funcDef = GeneratorFunctionDefinitionNode.create(funcname, null, core, arity, defaults, callTarget, fd,
15731576
environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
15741577
gtran.getNumOfActiveFlags(), gtran.getNumOfGeneratorBlockNode(), gtran.getNumOfGeneratorForNode());
15751578
} else {
@@ -1740,11 +1743,19 @@ public Object visitTestlist_comp(Python3Parser.Testlist_compContext ctx) {
17401743
}
17411744
}
17421745

1743-
private ForNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield) {
1746+
private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield) {
1747+
return createGeneratorExpression(comp_for, yield, true);
1748+
}
1749+
1750+
private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield, boolean iteratorInParentScope) {
17441751
// TODO: async
1745-
environment.pushCurentScope();
1752+
if (iteratorInParentScope) {
1753+
environment.pushCurentScope();
1754+
}
17461755
PNode iterator = asBlockOrPNode(comp_for.or_test().accept(this));
1747-
environment.popCurrentScope();
1756+
if (iteratorInParentScope) {
1757+
environment.popCurrentScope();
1758+
}
17481759

17491760
PNode targets = assigns.translate(comp_for.exprlist());
17501761
PNode myBody = yield;
@@ -1760,12 +1771,12 @@ private ForNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for
17601771
comp_iter = comp_iter.comp_if().comp_iter();
17611772
}
17621773
if (comp_iter != null && comp_iter.comp_for() != null) {
1763-
myBody = createGeneratorExpression(comp_iter.comp_for(), yield);
1774+
myBody = createGeneratorExpression(comp_iter.comp_for(), yield, false);
17641775
}
17651776
if (condition != null) {
17661777
myBody = factory.createIf(factory.createYesNode(condition), myBody, EmptyNode.create());
17671778
}
1768-
ForNode loop = createForInScope(targets, iterator, myBody);
1779+
LoopNode loop = createForInScope(targets, iterator, myBody);
17691780
deriveSourceSection(comp_for, loop);
17701781
return loop;
17711782
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/ScopeInfo.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ public int getLoopCount() {
9898
return loopCount;
9999
}
100100

101+
public void resetLoopCount() {
102+
this.loopCount = 0;
103+
}
104+
101105
public Set<ScopeInfo> getChildScopes() {
102106
return childScopes;
103107
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/ScopeTranslator.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,7 @@ public T visitArgument(Python3Parser.ArgumentContext ctx) {
354354
@Override
355355
public T visitComp_for(Python3Parser.Comp_forContext ctx) {
356356
declareNames(ctx.exprlist());
357-
if (trackCells) {
358-
environment.incCurrentScopeLoopCount();
359-
}
357+
environment.incCurrentScopeLoopCount();
360358
return super.visitComp_for(ctx);
361359
}
362360

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/TranslationEnvironment.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ public TranslationEnvironment(PythonLanguage language) {
7575
public TranslationEnvironment reset() {
7676
scopeLevel = 0;
7777
listComprehensionSlotCounter = 0;
78+
scopesStack.clear();
79+
for (ScopeInfo scope : scopeInfos.values()) {
80+
scope.resetLoopCount();
81+
}
7882
return this;
7983
}
8084

@@ -114,13 +118,11 @@ public ScopeInfo pushCurentScope() {
114118
return null;
115119
}
116120

117-
public ScopeInfo popCurrentScope() {
121+
public void popCurrentScope() {
118122
if (!scopesStack.isEmpty()) {
119123
scopeLevel++;
120124
currentScope = scopesStack.pop();
121-
return currentScope;
122125
}
123-
return null;
124126
}
125127

126128
public boolean atModuleLevel() {

0 commit comments

Comments
 (0)