Skip to content

Commit 243a47c

Browse files
committed
ScopeTranslator: execute the generator iterator outside of the current generator scope
- TranslationUtils: added human friently antlr context rendering
1 parent b5af4a7 commit 243a47c

File tree

7 files changed

+159
-58
lines changed

7 files changed

+159
-58
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/function/GeneratorExpressionNode.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import com.oracle.graal.python.builtins.objects.cell.PCell;
2929
import com.oracle.graal.python.builtins.objects.function.PArguments;
30+
import com.oracle.graal.python.nodes.PNode;
3031
import com.oracle.graal.python.parser.DefinitionCellSlots;
3132
import com.oracle.graal.python.parser.ExecutionCellSlots;
3233
import com.oracle.truffle.api.CompilerAsserts;
@@ -50,12 +51,15 @@ public final class GeneratorExpressionNode extends ExpressionDefinitionNode {
5051
@CompilationFinal private FrameDescriptor enclosingFrameDescriptor;
5152
@CompilationFinal private boolean isEnclosingFrameGenerator;
5253
@CompilationFinal private boolean isOptimized;
54+
@Child private PNode getIterator;
5355

54-
public GeneratorExpressionNode(String name, RootCallTarget callTarget, FrameDescriptor descriptor, DefinitionCellSlots definitionCellSlots, ExecutionCellSlots executionCellSlots,
56+
public GeneratorExpressionNode(String name, RootCallTarget callTarget, PNode getIterator, FrameDescriptor descriptor, DefinitionCellSlots definitionCellSlots,
57+
ExecutionCellSlots executionCellSlots,
5558
int numOfActiveFlags, int numOfGeneratorBlockNode, int numOfGeneratorForNode) {
5659
super(definitionCellSlots, executionCellSlots);
5760
this.name = name;
5861
this.callTarget = callTarget;
62+
this.getIterator = getIterator;
5963
this.frameDescriptor = descriptor;
6064
this.numOfActiveFlags = numOfActiveFlags;
6165
this.numOfGeneratorBlockNode = numOfGeneratorBlockNode;
@@ -117,7 +121,8 @@ public RootNode getFunctionRootNode() {
117121

118122
@Override
119123
public Object execute(VirtualFrame frame) {
120-
Object[] arguments = PArguments.create();
124+
Object[] arguments = PArguments.create(1);
125+
PArguments.setArgument(arguments, 0, getIterator.execute(frame));
121126
PArguments.setGlobals(arguments, PArguments.getGlobals(frame));
122127

123128
PCell[] closure;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/GeneratorTranslator.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@
3131
import com.oracle.graal.python.nodes.EmptyNode;
3232
import com.oracle.graal.python.nodes.PNode;
3333
import com.oracle.graal.python.nodes.PNodeUtil;
34+
import com.oracle.graal.python.nodes.argument.ReadIndexedArgumentNode;
3435
import com.oracle.graal.python.nodes.control.BlockNode;
3536
import com.oracle.graal.python.nodes.control.BreakNode;
3637
import com.oracle.graal.python.nodes.control.BreakTargetNode;
3738
import com.oracle.graal.python.nodes.control.ContinueNode;
3839
import com.oracle.graal.python.nodes.control.ContinueTargetNode;
3940
import com.oracle.graal.python.nodes.control.ForNode;
40-
import com.oracle.graal.python.nodes.control.GetIteratorNode;
4141
import com.oracle.graal.python.nodes.control.IfNode;
4242
import com.oracle.graal.python.nodes.control.LoopNode;
4343
import com.oracle.graal.python.nodes.control.ReturnTargetNode;
@@ -79,6 +79,7 @@ public class GeneratorTranslator {
7979
private int numOfGeneratorBlockNode;
8080
private int numOfGeneratorForNode;
8181
private boolean needToHandleComplicatedYieldExpression;
82+
private boolean replacedOuterMostLoopIterator = false;
8283

8384
public GeneratorTranslator(FunctionRootNode root) {
8485
this.root = root;
@@ -331,7 +332,11 @@ private void replaceControl(PNode node, YieldNode yield) {
331332
} else if (node instanceof ForNode) {
332333
ForNode forNode = (ForNode) node;
333334
WriteNode target = (WriteNode) forNode.getTarget();
334-
GetIteratorNode getIter = (GetIteratorNode) forNode.getIterator();
335+
PNode getIter = forNode.getIterator();
336+
if (!replacedOuterMostLoopIterator) {
337+
replacedOuterMostLoopIterator = true;
338+
getIter = ReadIndexedArgumentNode.create(0);
339+
}
335340
replace(node, GeneratorForNode.create(target, getIter, forNode.getBody(), nextGeneratorForNodeSlot()));
336341
} else if (node instanceof BlockNode) {
337342
BlockNode block = (BlockNode) node;
@@ -419,5 +424,4 @@ private int nextGeneratorForNodeSlot() {
419424
public int getNumOfGeneratorForNode() {
420425
return numOfGeneratorForNode;
421426
}
422-
423427
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/PythonBaseTreeTranslator.java

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
import com.oracle.graal.python.nodes.control.BlockNode;
7171
import com.oracle.graal.python.nodes.control.ForNode;
7272
import com.oracle.graal.python.nodes.control.GetIteratorNode;
73-
import com.oracle.graal.python.nodes.control.LoopNode;
7473
import com.oracle.graal.python.nodes.control.ReturnTargetNode;
7574
import com.oracle.graal.python.nodes.expression.AndNode;
7675
import com.oracle.graal.python.nodes.expression.CastToBooleanNode;
@@ -441,12 +440,13 @@ private PNode createComprehensionExpression(ParserRuleContext ctx, Function<Pars
441440
PNode block = getBlock.apply(ctx);
442441
PNode yield = factory.createYield(block, environment.getReturnSlot());
443442
yield.assignSourceSection(block.getSourceSection());
444-
PNode body = createGeneratorExpression(ctx.getChild(Python3Parser.Comp_forContext.class, 0), yield);
445-
SourceSection srcSection = body.getSourceSection();
446-
body = new ReturnTargetNode(body, factory.createReadLocal(environment.getReturnSlot()));
443+
ForNode genFor = createGeneratorExpression(ctx.getChild(Python3Parser.Comp_forContext.class, 0), yield);
444+
SourceSection srcSection = genFor.getSourceSection();
445+
PNode getIterator = genFor.getIterator();
446+
PNode body = new ReturnTargetNode(genFor, factory.createReadLocal(environment.getReturnSlot()));
447447
body.assignSourceSection(srcSection);
448448
int lineNum = ctx.getStart().getLine();
449-
GeneratorExpressionNode genExprDef = createGeneratorExpressionDefinition(body, lineNum);
449+
GeneratorExpressionNode genExprDef = createGeneratorExpressionDefinition(body, getIterator, lineNum);
450450
genExprDef.setEnclosingFrameDescriptor(environment.getEnclosingFrame());
451451
genExprDef.assignSourceSection(srcSection);
452452
return genExprDef;
@@ -455,12 +455,13 @@ private PNode createComprehensionExpression(ParserRuleContext ctx, Function<Pars
455455
}
456456
}
457457

458-
private GeneratorExpressionNode createGeneratorExpressionDefinition(PNode body, int lineNum) {
458+
private GeneratorExpressionNode createGeneratorExpressionDefinition(PNode body, PNode getIterator, int lineNum) {
459459
FrameDescriptor fd = environment.getCurrentFrame();
460460
String generatorName = "generator_exp:" + lineNum;
461461
FunctionRootNode funcRoot = factory.createFunctionRoot(body.getSourceSection(), generatorName, true, fd, body, environment.getExecutionCellSlots());
462462
GeneratorTranslator gtran = new GeneratorTranslator(funcRoot);
463-
return new GeneratorExpressionNode(generatorName, gtran.translate(), fd, environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
463+
RootCallTarget callTarget = gtran.translate();
464+
return new GeneratorExpressionNode(generatorName, callTarget, getIterator, fd, environment.getDefinitionCellSlots(), environment.getExecutionCellSlots(),
464465
gtran.getNumOfActiveFlags(),
465466
gtran.getNumOfGeneratorBlockNode(),
466467
gtran.getNumOfGeneratorForNode());
@@ -1200,7 +1201,7 @@ private PNode createForNode(PNode target, PNode iter, PNode body, PNode orelse,
12001201
}
12011202
}
12021203

1203-
private LoopNode createForInScope(PNode target, PNode iterator, PNode body) {
1204+
private ForNode createForInScope(PNode target, PNode iterator, PNode body) {
12041205
GetIteratorNode getIterator = factory.createGetIterator(iterator);
12051206
getIterator.assignSourceSection(iterator.getSourceSection());
12061207
return new ForNode(body, target, getIterator);
@@ -1739,11 +1740,14 @@ public Object visitTestlist_comp(Python3Parser.Testlist_compContext ctx) {
17391740
}
17401741
}
17411742

1742-
private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield) {
1743+
private ForNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for, PNode yield) {
17431744
// TODO: async
1745+
environment.pushCurentScope();
1746+
PNode iterator = asBlockOrPNode(comp_for.or_test().accept(this));
1747+
environment.popCurrentScope();
1748+
17441749
PNode targets = assigns.translate(comp_for.exprlist());
17451750
PNode myBody = yield;
1746-
PNode iterator = asBlockOrPNode(comp_for.or_test().accept(this));
17471751
PNode condition = null;
17481752
Python3Parser.Comp_iterContext comp_iter = comp_for.comp_iter();
17491753
while (comp_iter != null && comp_iter.comp_if() != null) {
@@ -1761,7 +1765,7 @@ private PNode createGeneratorExpression(Python3Parser.Comp_forContext comp_for,
17611765
if (condition != null) {
17621766
myBody = factory.createIf(factory.createYesNode(condition), myBody, EmptyNode.create());
17631767
}
1764-
LoopNode loop = createForInScope(targets, iterator, myBody);
1768+
ForNode loop = createForInScope(targets, iterator, myBody);
17651769
deriveSourceSection(comp_for, loop);
17661770
return loop;
17671771
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/ScopeInfo.java

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
*/
2626
package com.oracle.graal.python.parser;
2727

28+
import java.util.Collection;
2829
import java.util.HashSet;
2930
import java.util.LinkedHashSet;
3031
import java.util.List;
@@ -76,6 +77,8 @@ public enum ScopeKind {
7677
private List<PNode> defaultArgumentNodes;
7778
private ReadDefaultArgumentNode[] defaultArgumentReads;
7879

80+
private int loopCount = 0;
81+
7982
public ScopeInfo(String scopeId, ScopeKind kind, FrameDescriptor frameDescriptor, ScopeInfo parent) {
8083
this.scopeId = scopeId;
8184
this.scopeKind = kind;
@@ -87,6 +90,14 @@ public ScopeInfo(String scopeId, ScopeKind kind, FrameDescriptor frameDescriptor
8790
}
8891
}
8992

93+
public void incLoopCount() {
94+
loopCount++;
95+
}
96+
97+
public int getLoopCount() {
98+
return loopCount;
99+
}
100+
90101
public Set<ScopeInfo> getChildScopes() {
91102
return childScopes;
92103
}
@@ -143,7 +154,7 @@ public void addCellVar(String identifier) {
143154
addCellVar(identifier, false);
144155
}
145156

146-
protected void addCellVar(String identifier, boolean createFrameSlot) {
157+
public void addCellVar(String identifier, boolean createFrameSlot) {
147158
this.cellVars.add(identifier);
148159
if (createFrameSlot) {
149160
this.createSlotIfNotPresent(identifier);
@@ -169,32 +180,27 @@ public boolean isFreeVar(String identifier) {
169180
return this.freeVars.contains(identifier);
170181
}
171182

172-
public FrameSlot[] getCellVarSlots() {
173-
FrameSlot[] cellVarSlots = new FrameSlot[this.cellVars.size()];
183+
private FrameSlot[] getFrameSlots(Collection<String> identifiers, ScopeInfo scope) {
184+
assert scope != null : "getting frame slots: scope cannot be null!";
185+
FrameSlot[] slots = new FrameSlot[identifiers.size()];
174186
int i = 0;
175-
for (String identifier : this.cellVars) {
176-
cellVarSlots[i++] = findFrameSlot(identifier);
187+
for (String identifier : identifiers) {
188+
slots[i++] = scope.findFrameSlot(identifier);
177189
}
178-
return cellVarSlots;
190+
return slots;
191+
}
192+
193+
public FrameSlot[] getCellVarSlots() {
194+
return getFrameSlots(cellVars, this);
179195
}
180196

181197
public FrameSlot[] getFreeVarSlots() {
182-
FrameSlot[] freeVarSlots = new FrameSlot[this.freeVars.size()];
183-
int i = 0;
184-
for (String identifier : this.freeVars) {
185-
freeVarSlots[i++] = findFrameSlot(identifier);
186-
}
187-
return freeVarSlots;
198+
return getFrameSlots(freeVars, this);
188199
}
189200

190201
public FrameSlot[] getFreeVarSlotsInParentScope() {
191202
assert parent != null : "cannot get current freeVars in parent scope, parent scope cannot be null!";
192-
FrameSlot[] freeVarSlots = new FrameSlot[this.freeVars.size()];
193-
int i = 0;
194-
for (String identifier : this.freeVars) {
195-
freeVarSlots[i++] = parent.findFrameSlot(identifier);
196-
}
197-
return freeVarSlots;
203+
return getFrameSlots(freeVars, parent);
198204
}
199205

200206
public void setDefaultArgumentNodes(List<PNode> defaultArgumentNodes) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/parser/ScopeTranslator.java

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ public class ScopeTranslator<T> extends Python3BaseVisitor<T> {
4747
private final PythonCore core;
4848
private final boolean interactive;
4949
private final boolean trackCells;
50+
private int comprehensionOrTestDepth = 0;
5051

5152
public ScopeTranslator(PythonCore core, TranslationEnvironment environment, boolean interactive, boolean trackCells) {
5253
this.core = core;
@@ -83,30 +84,35 @@ public static void accept(ParserRuleContext input, TranslationEnvironment enviro
8384
@Override
8485
public T visitFile_input(Python3Parser.File_inputContext ctx) {
8586
environment.beginScope(ctx, ScopeInfo.ScopeKind.Module);
86-
T node = super.visitFile_input(ctx);
87-
environment.endScope(ctx);
88-
return node;
87+
try {
88+
return super.visitFile_input(ctx);
89+
} finally {
90+
environment.endScope(ctx);
91+
}
8992
}
9093

9194
@Override
9295
public T visitSingle_input(Single_inputContext ctx) {
9396
if (interactive) {
9497
environment.beginScope(ctx, ScopeInfo.ScopeKind.Module);
9598
}
96-
T node = super.visitSingle_input(ctx);
97-
if (interactive) {
98-
environment.endScope(ctx);
99+
try {
100+
return super.visitSingle_input(ctx);
101+
} finally {
102+
if (interactive) {
103+
environment.endScope(ctx);
104+
}
99105
}
100-
101-
return node;
102106
}
103107

104108
@Override
105109
public T visitEval_input(Python3Parser.Eval_inputContext ctx) {
106110
environment.beginScope(ctx, ScopeInfo.ScopeKind.Module);
107-
T node = super.visitEval_input(ctx);
108-
environment.endScope(ctx);
109-
return node;
111+
try {
112+
return super.visitEval_input(ctx);
113+
} finally {
114+
environment.endScope(ctx);
115+
}
110116
}
111117

112118
@Override
@@ -348,11 +354,32 @@ public T visitArgument(Python3Parser.ArgumentContext ctx) {
348354
@Override
349355
public T visitComp_for(Python3Parser.Comp_forContext ctx) {
350356
declareNames(ctx.exprlist());
351-
T or_test = ctx.or_test().accept(this);
352-
if (ctx.comp_iter() != null) {
353-
return aggregateResult(or_test, ctx.comp_iter().accept(this));
354-
} else {
355-
return or_test;
357+
if (trackCells) {
358+
environment.incCurrentScopeLoopCount();
359+
}
360+
return super.visitComp_for(ctx);
361+
}
362+
363+
@Override
364+
public T visitOr_test(Python3Parser.Or_testContext ctx) {
365+
ScopeInfo generatorScope = null;
366+
if (ctx.getParent() instanceof Python3Parser.Comp_forContext) {
367+
if (comprehensionOrTestDepth == 0 && environment.getCurrentScopeLoopCount() == 1) {
368+
// the generator iterator needs to be early evaluated in the parent scope
369+
generatorScope = environment.pushCurentScope();
370+
}
371+
comprehensionOrTestDepth++;
372+
}
373+
try {
374+
return super.visitOr_test(ctx);
375+
} finally {
376+
if (ctx.getParent() instanceof Python3Parser.Comp_forContext) {
377+
comprehensionOrTestDepth--;
378+
if (comprehensionOrTestDepth == 0 && generatorScope != null && generatorScope.getLoopCount() == 1) {
379+
// restore the current scope
380+
environment.popCurrentScope();
381+
}
382+
}
356383
}
357384
}
358385

@@ -370,7 +397,7 @@ public T visitDefparameter(Python3Parser.DefparameterContext ctx) {
370397
if (trackCells) {
371398
if (ctx.test() != null) {
372399
String identifier = ctx.test().getText();
373-
environment.registerCellVariable(identifier);
400+
environment.registerCell(identifier);
374401
}
375402
}
376403
return super.visitDefparameter(ctx);
@@ -381,7 +408,8 @@ public T visitAtom(Python3Parser.AtomContext ctx) {
381408
if (trackCells) {
382409
TerminalNode name = ctx.NAME();
383410
if (name != null) {
384-
environment.registerCellVariable(name.getText());
411+
String identifier = name.getText();
412+
environment.registerCell(identifier);
385413
}
386414
}
387415
return super.visitAtom(ctx);

0 commit comments

Comments
 (0)