Skip to content

Commit 238f616

Browse files
committed
Implement OSR for generators
1 parent fb79d6c commit 238f616

File tree

2 files changed

+65
-19
lines changed

2 files changed

+65
-19
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode/PBytecodeGeneratorRootNode.java

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,23 @@
6060
import com.oracle.truffle.api.frame.Frame;
6161
import com.oracle.truffle.api.frame.MaterializedFrame;
6262
import com.oracle.truffle.api.frame.VirtualFrame;
63+
import com.oracle.truffle.api.nodes.BytecodeOSRNode;
6364
import com.oracle.truffle.api.nodes.ExplodeLoop;
6465
import com.oracle.truffle.api.profiles.ConditionProfile;
6566
import com.oracle.truffle.api.source.SourceSection;
6667

67-
public class PBytecodeGeneratorRootNode extends PRootNode {
68+
public class PBytecodeGeneratorRootNode extends PRootNode implements BytecodeOSRNode {
6869
private final PBytecodeRootNode rootNode;
6970
private final int resumeBci;
7071
private final int resumeStackTop;
7172

7273
@Child private ExecutionContext.CalleeContext calleeContext = ExecutionContext.CalleeContext.create();
7374
@Child private IsBuiltinClassProfile errorProfile;
7475
@Child private PRaiseNode raise = PRaiseNode.create();
76+
private final ConditionProfile returnProfile = ConditionProfile.create();
77+
78+
@CompilationFinal private Object osrMetadata;
79+
@CompilationFinal(dimensions = 1) private FrameSlotType[] frameSlotTypes;
7580

7681
private enum FrameSlotType {
7782
Object,
@@ -81,17 +86,13 @@ private enum FrameSlotType {
8186
Boolean
8287
}
8388

84-
@CompilationFinal(dimensions = 1) private FrameSlotType[] frameSlotTypes;
85-
86-
private final ConditionProfile returnProfile = ConditionProfile.create();
87-
8889
@TruffleBoundary
8990
public PBytecodeGeneratorRootNode(PythonLanguage language, PBytecodeRootNode rootNode, int resumeBci, int resumeStackTop) {
9091
super(language, rootNode.getFrameDescriptor());
9192
this.rootNode = rootNode;
9293
this.resumeBci = resumeBci;
9394
this.resumeStackTop = resumeStackTop;
94-
frameSlotTypes = new FrameSlotType[resumeStackTop];
95+
frameSlotTypes = new FrameSlotType[resumeStackTop + 1];
9596
}
9697

9798
@ExplodeLoop
@@ -153,7 +154,9 @@ private void copyFrameSlotsIntoVirtualFrame(MaterializedFrame generatorFrame, Vi
153154

154155
@ExplodeLoop
155156
private void copyFrameSlotsToGeneratorFrame(VirtualFrame virtualFrame, MaterializedFrame generatorFrame) {
156-
for (int i = 0; i < frameSlotTypes.length; i++) {
157+
int stackTop = getFrameDescriptor().getNumberOfSlots();
158+
CompilerAsserts.partialEvaluationConstant(stackTop);
159+
for (int i = 0; i < stackTop; i++) {
157160
if (virtualFrame.isObject(i)) {
158161
generatorFrame.setObject(i, virtualFrame.getObject(i));
159162
} else if (virtualFrame.isInt(i)) {
@@ -168,9 +171,6 @@ private void copyFrameSlotsToGeneratorFrame(VirtualFrame virtualFrame, Materiali
168171
throw CompilerDirectives.shouldNotReachHere("unexpected frame slot type");
169172
}
170173
}
171-
generatorFrame.setInt(rootNode.bcioffset, virtualFrame.getInt(rootNode.bcioffset));
172-
generatorFrame.setInt(rootNode.generatorStackTopOffset, virtualFrame.getInt(rootNode.generatorStackTopOffset));
173-
generatorFrame.setObject(rootNode.generatorReturnOffset, virtualFrame.getObject(rootNode.generatorReturnOffset));
174174
}
175175

176176
private void profileFrameSlots(MaterializedFrame generatorFrame) {
@@ -192,6 +192,31 @@ private void profileFrameSlots(MaterializedFrame generatorFrame) {
192192
}
193193
}
194194

195+
@Override
196+
public Object executeOSR(VirtualFrame osrFrame, int target, Object interpreterState) {
197+
Integer osrStackTop = (Integer) interpreterState;
198+
MaterializedFrame generatorFrame = PArguments.getGeneratorFrame(osrFrame);
199+
copyFrameSlotsIntoVirtualFrame(generatorFrame, osrFrame);
200+
copyOSRStackRemainderIntoVirtualFrame(generatorFrame, osrFrame, osrStackTop);
201+
try {
202+
return rootNode.executeFromBci(osrFrame, osrFrame, this, target, osrStackTop);
203+
} finally {
204+
copyFrameSlotsToGeneratorFrame(osrFrame, generatorFrame);
205+
}
206+
}
207+
208+
@ExplodeLoop
209+
private void copyOSRStackRemainderIntoVirtualFrame(MaterializedFrame generatorFrame, VirtualFrame osrFrame, int stackTop) {
210+
/*
211+
* In addition to local variables and stack slots present at resume, OSR needs to also
212+
* revirtualize stack items that have been pushed since resume. Stack slots at a back edge
213+
* should never be primitives.
214+
*/
215+
for (int i = resumeStackTop; i <= stackTop; i++) {
216+
osrFrame.setObject(i, generatorFrame.getObject(i));
217+
}
218+
}
219+
195220
@Override
196221
public Object execute(VirtualFrame frame) {
197222
calleeContext.enter(frame);
@@ -206,22 +231,23 @@ public Object execute(VirtualFrame frame) {
206231
PArguments.setException(frame, localException == null ? outerException : localException);
207232
Object result;
208233
Frame localFrame;
209-
if (CompilerDirectives.inInterpreter()) {
234+
boolean usingMaterializedFrame = CompilerDirectives.inInterpreter();
235+
if (usingMaterializedFrame) {
210236
profileFrameSlots(generatorFrame);
211237
localFrame = generatorFrame;
212238
} else {
213239
copyFrameSlotsIntoVirtualFrame(generatorFrame, frame);
214240
localFrame = frame;
215241
}
216242
try {
217-
result = rootNode.executeFromBci(frame, localFrame, resumeBci, resumeStackTop);
243+
result = rootNode.executeFromBci(frame, localFrame, this, resumeBci, resumeStackTop);
218244
} catch (PException pe) {
219245
// PEP 479 - StopIteration raised from generator body needs to be wrapped in
220246
// RuntimeError
221247
pe.expectStopIteration(getErrorProfile());
222248
throw raise.raise(RuntimeError, pe.setCatchingFrameAndGetEscapedException(frame, this), ErrorMessages.GENERATOR_RAISED_STOPITER);
223249
} finally {
224-
if (CompilerDirectives.inCompiledCode()) {
250+
if (!usingMaterializedFrame) {
225251
copyFrameSlotsToGeneratorFrame(frame, generatorFrame);
226252
}
227253
calleeContext.exit(frame, this);
@@ -243,6 +269,26 @@ public Object execute(VirtualFrame frame) {
243269
return result;
244270
}
245271

272+
@Override
273+
public Object getOSRMetadata() {
274+
return osrMetadata;
275+
}
276+
277+
@Override
278+
public void setOSRMetadata(Object osrMetadata) {
279+
this.osrMetadata = osrMetadata;
280+
}
281+
282+
@Override
283+
public Object[] storeParentFrameInArguments(VirtualFrame parentFrame) {
284+
return rootNode.storeParentFrameInArguments(parentFrame);
285+
}
286+
287+
@Override
288+
public Frame restoreParentFrameFromArguments(Object[] arguments) {
289+
return rootNode.restoreParentFrameFromArguments(arguments);
290+
}
291+
246292
@Override
247293
public String getName() {
248294
return rootNode.getName();
@@ -251,7 +297,7 @@ public String getName() {
251297
@Override
252298
public String toString() {
253299
CompilerAsserts.neverPartOfCompilation();
254-
return "<bytecode " + rootNode.getName() + ">";
300+
return "<bytecode " + rootNode.getName() + " (generator resume bci=" + resumeBci + ")>";
255301
}
256302

257303
@Override

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode/PBytecodeRootNode.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ public Object execute(VirtualFrame virtualFrame) {
720720
copyArgsAndCells(virtualFrame, virtualFrame.getArguments());
721721
}
722722

723-
return executeFromBci(virtualFrame, virtualFrame, 0, getInitialStackTop());
723+
return executeFromBci(virtualFrame, virtualFrame, this, 0, getInitialStackTop());
724724
} finally {
725725
calleeContext.exit(virtualFrame, this);
726726
}
@@ -740,13 +740,13 @@ public Frame restoreParentFrameFromArguments(Object[] arguments) {
740740

741741
@Override
742742
public Object executeOSR(VirtualFrame osrFrame, int target, Object interpreterState) {
743-
return executeFromBci(osrFrame, osrFrame, target, (Integer) interpreterState);
743+
return executeFromBci(osrFrame, osrFrame, this, target, (Integer) interpreterState);
744744
}
745745

746746
@BytecodeInterpreterSwitch
747747
@ExplodeLoop(kind = ExplodeLoop.LoopExplosionKind.MERGE_EXPLODE)
748748
@SuppressWarnings("fallthrough")
749-
Object executeFromBci(VirtualFrame virtualFrame, Frame localFrame, int initialBci, int initialStackTop) {
749+
Object executeFromBci(VirtualFrame virtualFrame, Frame localFrame, BytecodeOSRNode osrNode, int initialBci, int initialStackTop) {
750750
Object globals = PArguments.getGlobals(virtualFrame);
751751
Object locals = PArguments.getSpecialArgument(virtualFrame);
752752

@@ -1158,7 +1158,7 @@ Object executeFromBci(VirtualFrame virtualFrame, Frame localFrame, int initialBc
11581158
if (CompilerDirectives.inInterpreter()) {
11591159
loopCount++;
11601160
}
1161-
if (CompilerDirectives.inInterpreter() && BytecodeOSRNode.pollOSRBackEdge(this)) {
1161+
if (CompilerDirectives.inInterpreter() && BytecodeOSRNode.pollOSRBackEdge(osrNode)) {
11621162
/*
11631163
* Beware of race conditions when adding more things to the
11641164
* interpreterState argument. It gets stored already at this point, but
@@ -1169,7 +1169,7 @@ Object executeFromBci(VirtualFrame virtualFrame, Frame localFrame, int initialBc
11691169
* will get mixed up. To retain such state, put it into the frame
11701170
* instead.
11711171
*/
1172-
Object osrResult = BytecodeOSRNode.tryOSR(this, bci, stackTop, null, virtualFrame);
1172+
Object osrResult = BytecodeOSRNode.tryOSR(osrNode, bci, stackTop, null, virtualFrame);
11731173
if (osrResult != null) {
11741174
if (CompilerDirectives.inInterpreter() && loopCount > 0) {
11751175
LoopNode.reportLoopCount(this, loopCount);

0 commit comments

Comments
 (0)