Skip to content

Commit 7d8f0af

Browse files
author
Adam Hrbac
committed
Use assumption on the language rather than threadState
Since threadState is not constant for a given code object, truffle cannot compile with the assumption in mind and it degenerates into a simple boolean flag with extra virtual calls. The language is constant and so provides a static path to the assumption
1 parent a82af13 commit 7d8f0af

File tree

4 files changed

+82
-75
lines changed

4 files changed

+82
-75
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/PythonLanguage.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,12 @@ public final class PythonLanguage extends TruffleLanguage<PythonContext> {
216216

217217
private static final LanguageReference<PythonLanguage> REFERENCE = LanguageReference.create(PythonLanguage.class);
218218

219+
/**
220+
* This assumption will be valid if no context set a trace function at any point. Calling
221+
* sys.settrace(None) will not invalidate it
222+
*/
223+
public final Assumption noTracingAssumption = Assumption.create("No tracing function was set");
224+
219225
@CompilationFinal private boolean singleContext = true;
220226

221227
public boolean isSingleContext() {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SysModuleBuiltins.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -978,11 +978,12 @@ Object settrace(Object function) {
978978
if (!ctx.getOption(PythonOptions.EnableBytecodeInterpreter)) {
979979
throw raise(NotImplementedError, ErrorMessages.SETTRACE_NOT_IMPLEMENTED);
980980
}
981-
PythonContext.PythonThreadState state = ctx.getThreadState(getLanguage());
981+
PythonLanguage language = getLanguage();
982+
PythonContext.PythonThreadState state = ctx.getThreadState(language);
982983
if (function == PNone.NONE) {
983-
state.setTraceFun(null);
984+
state.setTraceFun(null, language);
984985
} else {
985-
state.setTraceFun(function);
986+
state.setTraceFun(function, language);
986987
}
987988
return PNone.NONE;
988989
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode/PBytecodeRootNode.java

Lines changed: 69 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@
194194
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
195195
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
196196
import com.oracle.truffle.api.CompilerDirectives.ValueType;
197+
import com.oracle.truffle.api.HostCompilerDirectives;
197198
import com.oracle.truffle.api.HostCompilerDirectives.BytecodeInterpreterSwitch;
198199
import com.oracle.truffle.api.Truffle;
199200
import com.oracle.truffle.api.TruffleLanguage;
@@ -496,9 +497,9 @@ public final class PBytecodeRootNode extends PRootNode implements BytecodeOSRNod
496497
@Child private CalleeContext calleeContext = CalleeContext.create();
497498
@Child private PythonObjectFactory factory = PythonObjectFactory.create();
498499
@Child private ExceptionStateNodes.GetCaughtExceptionNode getCaughtExceptionNode;
499-
@Child private GetExceptionTracebackNode traceGetExceptionTracebackNode = null;
500-
@Child private MaterializeFrameNode traceMaterializeFrameNode = null;
501-
@Child private CallTernaryMethodNode traceCallTernaryMethodNode = null;
500+
// private GetExceptionTracebackNode traceGetExceptionTracebackNode = null;
501+
private MaterializeFrameNode traceMaterializeFrameNode = null;
502+
// private CallTernaryMethodNode traceCallTernaryMethodNode = null;
502503

503504
private final LoopConditionProfile exceptionChainProfile1 = LoopConditionProfile.createCountingProfile();
504505
private final LoopConditionProfile exceptionChainProfile2 = LoopConditionProfile.createCountingProfile();
@@ -1011,6 +1012,14 @@ private InterpreterContinuation(int bci, int stackTop) {
10111012

10121013
@ValueType
10131014
private static final class MutableLoopData {
1015+
/*
1016+
* data for tracing
1017+
*/
1018+
int pastBci;
1019+
int pastLine;
1020+
int returnLine;
1021+
PFrame pyFrame;
1022+
10141023
int loopCount;
10151024
/*
10161025
* This separate tracking of local exception is necessary to make exception state saving
@@ -1054,6 +1063,8 @@ private Object executeUncached(VirtualFrame virtualFrame, Frame localFrame, Byte
10541063
return bytecodeLoop(virtualFrame, localFrame, osrNode, initialBci, initialStackTop, false, false);
10551064
}
10561065

1066+
@CompilationFinal private PythonLanguage cachedLanguage = null;
1067+
10571068
@ExplodeLoop(kind = ExplodeLoop.LoopExplosionKind.MERGE_EXPLODE)
10581069
@SuppressWarnings("fallthrough")
10591070
@BytecodeInterpreterSwitch
@@ -1062,14 +1073,16 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
10621073
Object[] arguments = virtualFrame.getArguments();
10631074
Object globals = PArguments.getGlobals(arguments);
10641075
Object locals = PArguments.getSpecialArgument(arguments);
1065-
1066-
boolean isGeneratorOrCoroutine = co.isGeneratorOrCoroutine();
1067-
if (inCompiledCode && !isGeneratorOrCoroutine) {
1068-
unboxVariables(localFrame);
1076+
final PythonLanguage language;
1077+
if (cachedLanguage == null) {
1078+
CompilerDirectives.transferToInterpreterAndInvalidate();
1079+
language = cachedLanguage = PythonLanguage.get(this);
1080+
} else {
1081+
language = cachedLanguage;
10691082
}
1070-
1071-
final PythonContext context = PythonContext.get(this);
1072-
final PythonContext.PythonThreadState threadState = context.getThreadState(PythonLanguage.get(this));
1083+
final Assumption noTrace = language.noTracingAssumption;
1084+
final PythonContext pythonContext = PythonContext.get(this);
1085+
final PythonContext.PythonThreadState threadState = pythonContext.getThreadState(language);
10731086

10741087
/*
10751088
* We use an object as a workaround for not being able to specify which local variables are
@@ -1078,6 +1091,7 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
10781091
MutableLoopData mutableData = new MutableLoopData();
10791092
int stackTop = initialStackTop;
10801093
int bci = initialBci;
1094+
mutableData.pastLine = -1;
10811095

10821096
byte[] localBC = bytecode;
10831097
Object[] localConsts = consts;
@@ -1093,34 +1107,29 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
10931107
CompilerAsserts.partialEvaluationConstant(bci);
10941108
CompilerAsserts.partialEvaluationConstant(stackTop);
10951109

1096-
PFrame pyFrame;
1097-
1098-
int pastLine = -1;
1099-
if (threadState.getTraceFun() != null && !threadState.isTracing()) {
1100-
pyFrame = ensurePyFrame(virtualFrame, null);
1110+
if (!noTrace.isValid() && threadState.getTraceFun() != null && !threadState.isTracing()) {
1111+
mutableData.pyFrame = ensurePyFrame(virtualFrame, null);
11011112
// if we are simply continuing to run an OSR loop after the replacememnt, tracing an
11021113
// extra CALL event would be incorrect
11031114
if (!fromOSR) {
1104-
pyFrame.setLocalTraceFun(invokeTraceFunction(threadState.getTraceFun(), null, threadState, virtualFrame, pyFrame, PythonContext.TraceEvent.CALL,
1105-
initialBci == 0 ? getFirstLineno() : (pastLine = bciToLine(initialBci))));
1115+
mutableData.pyFrame.setLocalTraceFun(invokeTraceFunction(threadState.getTraceFun(), null, threadState, virtualFrame, mutableData.pyFrame, PythonContext.TraceEvent.CALL,
1116+
initialBci == 0 ? getFirstLineno() : (mutableData.pastLine = bciToLine(initialBci))));
11061117
}
1107-
} else {
1108-
pyFrame = null;
11091118
}
11101119

1111-
int returnLine = pastLine;
1112-
int pastBci = initialBci;
1120+
mutableData.returnLine = mutableData.pastLine;
1121+
mutableData.pastBci = initialBci;
11131122

11141123
int oparg = 0;
11151124
while (true) {
11161125
final byte bc = localBC[bci];
11171126
final int beginBci = bci;
1118-
if (threadState.getTraceFun() != null && !threadState.isTracing()) {
1119-
pyFrame = ensurePyFrame(virtualFrame, pyFrame);
1127+
if (!noTrace.isValid() && threadState.getTraceFun() != null && !threadState.isTracing()) {
1128+
mutableData.pyFrame = ensurePyFrame(virtualFrame, mutableData.pyFrame);
11201129
int thisLine = bciToLine(bci);
1121-
boolean onANewLine = thisLine != pastLine;
1122-
pastLine = thisLine;
1123-
OpCodes c = OpCodes.fromOpCode(localBC[pastBci]);
1130+
boolean onANewLine = thisLine != mutableData.pastLine;
1131+
mutableData.pastLine = thisLine;
1132+
OpCodes c = OpCodes.fromOpCode(localBC[mutableData.pastBci]);
11241133
/*
11251134
* normally, we trace a line every time the previous bytecode instruction was on a
11261135
* different line than the current one. There are a number of exceptions to this,
@@ -1138,26 +1147,26 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
11381147
* https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt#L210-L215
11391148
* for more details
11401149
*/
1141-
boolean shouldTrace = pyFrame.getLocalTraceFun() != null && pyFrame.getTraceLine();
1150+
boolean shouldTrace = mutableData.pyFrame.getLocalTraceFun() != null && mutableData.pyFrame.getTraceLine();
11421151
if (shouldTrace) {
1143-
shouldTrace = pastBci > bci; // is a backward jump
1152+
shouldTrace = mutableData.pastBci > bci; // is a backward jump
11441153
if (!shouldTrace) {
11451154
shouldTrace = onANewLine &&
11461155
// is not a forward jump
1147-
(pastBci + c.length() >= bci ||
1156+
(mutableData.pastBci + c.length() >= bci ||
11481157
// is a forward jump to the start of line
11491158
bciToLine(bci - 1) != thisLine);
11501159
}
11511160
}
11521161
if (shouldTrace) {
1153-
returnLine = pastLine;
1154-
pyFrame.setLocalTraceFun(
1155-
invokeTraceFunction(pyFrame.getLocalTraceFun(), null, threadState, virtualFrame, pyFrame, PythonContext.TraceEvent.LINE,
1156-
pastLine));
1162+
mutableData.returnLine = mutableData.pastLine;
1163+
mutableData.pyFrame.setLocalTraceFun(
1164+
invokeTraceFunction(mutableData.pyFrame.getLocalTraceFun(), null, threadState, virtualFrame, mutableData.pyFrame, PythonContext.TraceEvent.LINE,
1165+
mutableData.pastLine));
11571166
}
11581167
}
1159-
if (threadState.getTraceFun() != null) {
1160-
pastBci = bci;
1168+
if (!noTrace.isValid() && threadState.getTraceFun() != null) {
1169+
mutableData.pastBci = bci;
11611170
}
11621171

11631172
CompilerAsserts.partialEvaluationConstant(bc);
@@ -1613,11 +1622,11 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
16131622
LoopNode.reportLoopCount(this, mutableData.loopCount);
16141623
}
16151624
Object value = virtualFrame.getObject(stackTop);
1616-
if (threadState.getTraceFun() != null && !threadState.isTracing()) {
1617-
pyFrame = ensurePyFrame(virtualFrame, pyFrame);
1618-
if (pyFrame.getLocalTraceFun() != null) {
1619-
invokeTraceFunction(pyFrame.getLocalTraceFun(), value, threadState, virtualFrame, pyFrame, PythonContext.TraceEvent.RETURN,
1620-
pyFrame.getTraceLine() ? returnLine : bciToLine(bci));
1625+
if (!noTrace.isValid() && threadState.getTraceFun() != null && !threadState.isTracing()) {
1626+
mutableData.pyFrame = ensurePyFrame(virtualFrame, mutableData.pyFrame);
1627+
if (mutableData.pyFrame.getLocalTraceFun() != null) {
1628+
invokeTraceFunction(mutableData.pyFrame.getLocalTraceFun(), value, threadState, virtualFrame, mutableData.pyFrame, PythonContext.TraceEvent.RETURN,
1629+
mutableData.pyFrame.getTraceLine() ? mutableData.returnLine : bciToLine(bci));
16211630
}
16221631
}
16231632
if (isGeneratorOrCoroutine) {
@@ -2058,11 +2067,11 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
20582067
// Clear slots that were popped (if any)
20592068
clearFrameSlots(localFrame, stackTop + 1, initialStackTop);
20602069
}
2061-
if (threadState.getTraceFun() != null && !threadState.isTracing()) {
2062-
pyFrame = ensurePyFrame(virtualFrame, pyFrame);
2063-
if (pyFrame.getLocalTraceFun() != null) {
2064-
invokeTraceFunction(pyFrame.getLocalTraceFun(), value, threadState, virtualFrame, pyFrame, PythonContext.TraceEvent.RETURN,
2065-
pyFrame.getTraceLine() ? returnLine : bciToLine(bci));
2070+
if (!noTrace.isValid() && threadState.getTraceFun() != null && !threadState.isTracing()) {
2071+
mutableData.pyFrame = ensurePyFrame(virtualFrame, mutableData.pyFrame);
2072+
if (mutableData.pyFrame.getLocalTraceFun() != null) {
2073+
invokeTraceFunction(mutableData.pyFrame.getLocalTraceFun(), value, threadState, virtualFrame, mutableData.pyFrame, PythonContext.TraceEvent.RETURN,
2074+
mutableData.pyFrame.getTraceLine() ? mutableData.returnLine : bciToLine(bci));
20662075
}
20672076
}
20682077
return new GeneratorYieldResult(bci + 1, stackTop, value);
@@ -2152,16 +2161,13 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
21522161
}
21532162
}
21542163

2155-
if (threadState.getTraceFun() != null && !threadState.isTracing() && pe != null) {
2156-
pyFrame = ensurePyFrame(virtualFrame, pyFrame);
2157-
if (pyFrame.getLocalTraceFun() != null) {
2158-
if (traceGetExceptionTracebackNode == null) {
2159-
CompilerDirectives.transferToInterpreterAndInvalidate();
2160-
traceGetExceptionTracebackNode = insert(GetExceptionTracebackNode.create());
2161-
}
2162-
Object traceback = traceGetExceptionTracebackNode.execute(pe);
2163-
pyFrame.setLocalTraceFun(invokeTraceFunction(pyFrame.getLocalTraceFun(),
2164-
factory.createTuple(new Object[]{pe.getClass(), pe.setCatchingFrameAndGetEscapedException(virtualFrame, this), traceback}), threadState, virtualFrame, pyFrame,
2164+
if (!noTrace.isValid() && threadState.getTraceFun() != null && !threadState.isTracing() && pe != null) {
2165+
mutableData.pyFrame = ensurePyFrame(virtualFrame, mutableData.pyFrame);
2166+
if (mutableData.pyFrame.getLocalTraceFun() != null) {
2167+
Object traceback = GetExceptionTracebackNode.getUncached().execute(pe);
2168+
mutableData.pyFrame.setLocalTraceFun(invokeTraceFunction(mutableData.pyFrame.getLocalTraceFun(),
2169+
factory.createTuple(new Object[]{pe.getClass(), pe.setCatchingFrameAndGetEscapedException(virtualFrame, this), traceback}), threadState, virtualFrame,
2170+
mutableData.pyFrame,
21652171
PythonContext.TraceEvent.EXCEPTION, bciToLine(bci)));
21662172
}
21672173
}
@@ -2194,11 +2200,11 @@ private Object bytecodeLoop(VirtualFrame virtualFrame, Frame localFrame, Bytecod
21942200
if (CompilerDirectives.hasNextTier() && mutableData.loopCount > 0) {
21952201
LoopNode.reportLoopCount(this, mutableData.loopCount);
21962202
}
2197-
if (threadState.getTraceFun() != null) {
2198-
pyFrame = ensurePyFrame(virtualFrame, pyFrame);
2199-
if (pyFrame.getLocalTraceFun() != null) {
2200-
invokeTraceFunction(pyFrame.getLocalTraceFun(), PNone.NONE, threadState, virtualFrame, pyFrame, PythonContext.TraceEvent.RETURN,
2201-
pyFrame.getTraceLine() ? returnLine : bciToLine(bci));
2203+
if (!noTrace.isValid() && threadState.getTraceFun() != null) {
2204+
mutableData.pyFrame = ensurePyFrame(virtualFrame, mutableData.pyFrame);
2205+
if (mutableData.pyFrame.getLocalTraceFun() != null) {
2206+
invokeTraceFunction(mutableData.pyFrame.getLocalTraceFun(), PNone.NONE, threadState, virtualFrame, mutableData.pyFrame, PythonContext.TraceEvent.RETURN,
2207+
mutableData.pyFrame.getTraceLine() ? mutableData.returnLine : bciToLine(bci));
22022208
}
22032209
}
22042210
if (e == pe) {
@@ -2236,22 +2242,19 @@ private PFrame ensurePyFrame(VirtualFrame virtualFrame, PFrame pyFrame) {
22362242
return pyFrame;
22372243
}
22382244

2245+
@HostCompilerDirectives.InliningCutoff
22392246
private Object invokeTraceFunction(Object traceFn, Object arg, PythonContext.PythonThreadState threadState, VirtualFrame virtualFrame, PFrame tracing,
22402247
PythonContext.TraceEvent event, int line) {
22412248
threadState.tracingStart(event);
22422249
Object nonNullArg = arg == null ? PNone.NONE : arg;
2243-
if (traceCallTernaryMethodNode == null) {
2244-
CompilerDirectives.transferToInterpreterAndInvalidate();
2245-
traceCallTernaryMethodNode = insert(CallTernaryMethodNode.create());
2246-
}
22472250
try {
22482251
if (line != -1) {
22492252
tracing.setLineLock(line);
22502253
}
2251-
Object result = traceCallTernaryMethodNode.execute(virtualFrame, traceFn, tracing, event.pythonName, nonNullArg);
2254+
Object result = CallTernaryMethodNode.getUncached().execute(null, traceFn, tracing, event.pythonName, nonNullArg);
22522255
return result == PNone.NONE ? null : result;
22532256
} catch (Throwable e) {
2254-
threadState.setTraceFun(null);
2257+
threadState.setTraceFun(null, cachedLanguage);
22552258
throw e;
22562259
} finally {
22572260
if (line != -1) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/PythonContext.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,6 @@ public static final class PythonThreadState {
236236
*/
237237
PThreadState nativeWrapper;
238238

239-
/* Assume that no trace function was ever set. */
240-
public final Assumption noTracingInThread = Assumption.create("noTracingInThread");
241-
242239
/* The global tracing function, set by sys.settrace and returned by sys.gettrace. */
243240
Object traceFun;
244241

@@ -345,12 +342,12 @@ public void dispose() {
345342
}
346343

347344
public Object getTraceFun() {
348-
return noTracingInThread.isValid() ? null : traceFun;
345+
return traceFun;
349346
}
350347

351-
public void setTraceFun(Object traceFun) {
348+
public void setTraceFun(Object traceFun, PythonLanguage language) {
352349
if (this.traceFun != traceFun) {
353-
noTracingInThread.invalidate();
350+
language.noTracingAssumption.invalidate();
354351
this.traceFun = traceFun;
355352
}
356353
}

0 commit comments

Comments
 (0)