Skip to content

Commit 0a1c37c

Browse files
committed
Make cloudpickle's hack with swapping cellvars and freevars work
1 parent 4f40531 commit 0a1c37c

File tree

5 files changed

+45
-17
lines changed

5 files changed

+45
-17
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinConstructors.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
import static com.oracle.graal.python.runtime.exception.PythonErrorType.RuntimeError;
8787
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
8888
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
89+
import static com.oracle.graal.python.util.PythonUtils.objectArrayToTruffleStringArray;
8990
import static com.oracle.graal.python.util.PythonUtils.tsLiteral;
9091

9192
import java.math.BigInteger;
@@ -2486,15 +2487,16 @@ PCode call(VirtualFrame frame, @SuppressWarnings("unused") Object cls, int argco
24862487
PTuple freevars, PTuple cellvars,
24872488
@CachedLibrary(limit = "1") PythonBufferAccessLibrary bufferLib,
24882489
@Cached CodeNodes.CreateCodeNode createCodeNode,
2489-
@Cached GetObjectArrayNode getObjectArrayNode) {
2490+
@Cached GetObjectArrayNode getObjectArrayNode,
2491+
@Cached CastToTruffleStringNode castToTruffleStringNode) {
24902492
byte[] codeBytes = bufferLib.getCopiedByteArray(codestring);
24912493
byte[] lnotabBytes = bufferLib.getCopiedByteArray(lnotab);
24922494

24932495
Object[] constantsArr = getObjectArrayNode.execute(constants);
2494-
Object[] namesArr = getObjectArrayNode.execute(names);
2495-
Object[] varnamesArr = getObjectArrayNode.execute(varnames);
2496-
Object[] freevarsArr = getObjectArrayNode.execute(freevars);
2497-
Object[] cellcarsArr = getObjectArrayNode.execute(cellvars);
2496+
TruffleString[] namesArr = objectArrayToTruffleStringArray(getObjectArrayNode.execute(names), castToTruffleStringNode);
2497+
TruffleString[] varnamesArr = objectArrayToTruffleStringArray(getObjectArrayNode.execute(varnames), castToTruffleStringNode);
2498+
TruffleString[] freevarsArr = objectArrayToTruffleStringArray(getObjectArrayNode.execute(freevars), castToTruffleStringNode);
2499+
TruffleString[] cellcarsArr = objectArrayToTruffleStringArray(getObjectArrayNode.execute(cellvars), castToTruffleStringNode);
24982500

24992501
return createCodeNode.execute(frame, argcount, posonlyargcount, kwonlyargcount,
25002502
nlocals, stacksize, flags,

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextCodeBuiltins.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import static com.oracle.graal.python.util.PythonUtils.EMPTY_BYTE_ARRAY;
4444
import static com.oracle.graal.python.util.PythonUtils.EMPTY_OBJECT_ARRAY;
45+
import static com.oracle.graal.python.util.PythonUtils.EMPTY_TRUFFLESTRING_ARRAY;
4546

4647
import java.util.List;
4748

@@ -126,8 +127,8 @@ abstract static class PyCodeNewEmpty extends PythonTernaryBuiltinNode {
126127
static PCode newEmpty(TruffleString filename, TruffleString funcname, int lineno,
127128
@Cached CodeNodes.CreateCodeNode createCodeNode) {
128129
return createCodeNode.execute(null, 0, 0, 0, 0, 0, 0,
129-
EMPTY_BYTE_ARRAY, EMPTY_OBJECT_ARRAY, EMPTY_OBJECT_ARRAY, EMPTY_OBJECT_ARRAY, EMPTY_OBJECT_ARRAY, EMPTY_OBJECT_ARRAY,
130-
filename, funcname, lineno, EMPTY_BYTE_ARRAY);
130+
EMPTY_BYTE_ARRAY, EMPTY_OBJECT_ARRAY, EMPTY_TRUFFLESTRING_ARRAY, EMPTY_TRUFFLESTRING_ARRAY, EMPTY_TRUFFLESTRING_ARRAY,
131+
EMPTY_TRUFFLESTRING_ARRAY, filename, funcname, lineno, EMPTY_BYTE_ARRAY);
131132
}
132133
}
133134
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/code/CodeBuiltins.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___REPR__;
3434
import static com.oracle.graal.python.nodes.StringLiterals.T_NONE;
3535
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
36+
import static com.oracle.graal.python.util.PythonUtils.objectArrayToTruffleStringArray;
3637

3738
import java.util.Arrays;
3839
import java.util.List;
@@ -54,6 +55,7 @@
5455
import com.oracle.graal.python.nodes.function.builtins.PythonClinicBuiltinNode;
5556
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
5657
import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
58+
import com.oracle.graal.python.nodes.util.CastToTruffleStringNode;
5759
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
5860
import com.oracle.graal.python.util.PythonUtils;
5961
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -409,6 +411,7 @@ PCode create(VirtualFrame frame, PCode self, int coArgcount,
409411
Object[] coCellvars, TruffleString coFilename,
410412
TruffleString coName, Object coLnotab,
411413
@Cached CodeNodes.CreateCodeNode createCodeNode,
414+
@Cached CastToTruffleStringNode castToTruffleStringNode,
412415
@CachedLibrary(limit = "2") PythonBufferAccessLibrary bufferLib) {
413416
try {
414417
return createCodeNode.execute(frame,
@@ -420,10 +423,10 @@ PCode create(VirtualFrame frame, PCode self, int coArgcount,
420423
coFlags == -1 ? self.co_flags() : coFlags,
421424
PGuards.isNone(coCode) ? self.getCodestring() : bufferLib.getInternalOrCopiedByteArray(coCode),
422425
coConsts.length == 0 ? null : coConsts,
423-
coNames.length == 0 ? self.getNames() : coNames,
424-
coVarnames.length == 0 ? self.getVarnames() : coVarnames,
425-
coFreevars.length == 0 ? self.getFreeVars() : coFreevars,
426-
coCellvars.length == 0 ? self.getCellVars() : coCellvars,
426+
coNames.length == 0 ? null : objectArrayToTruffleStringArray(coNames, castToTruffleStringNode),
427+
coVarnames.length == 0 ? null : objectArrayToTruffleStringArray(coVarnames, castToTruffleStringNode),
428+
coFreevars.length == 0 ? null : objectArrayToTruffleStringArray(coFreevars, castToTruffleStringNode),
429+
coCellvars.length == 0 ? null : objectArrayToTruffleStringArray(coCellvars, castToTruffleStringNode),
427430
coFilename.isEmpty() ? self.co_filename() : coFilename,
428431
coName.isEmpty() ? self.co_name() : coName,
429432
coFirstlineno == -1 ? self.co_firstlineno() : coFirstlineno,

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/code/CodeNodes.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import static com.oracle.graal.python.nodes.truffle.TruffleStringMigrationPythonTypes.assertNoJavaString;
4444

4545
import java.util.ArrayList;
46+
import java.util.Arrays;
4647
import java.util.List;
4748

4849
import org.graalvm.polyglot.io.ByteSequence;
@@ -101,8 +102,8 @@ public Assumption needNotPassExceptionAssumption() {
101102
public PCode execute(VirtualFrame frame, int argcount,
102103
int posonlyargcount, int kwonlyargcount,
103104
int nlocals, int stacksize, int flags,
104-
byte[] codedata, Object[] constants, Object[] names,
105-
Object[] varnames, Object[] freevars, Object[] cellvars,
105+
byte[] codedata, Object[] constants, TruffleString[] names,
106+
TruffleString[] varnames, TruffleString[] freevars, TruffleString[] cellvars,
106107
TruffleString filename, TruffleString name, int firstlineno,
107108
byte[] lnotab) {
108109

@@ -122,8 +123,8 @@ public PCode execute(VirtualFrame frame, int argcount,
122123
private static PCode createCode(PythonLanguage language, PythonContext context, @SuppressWarnings("unused") int argcount,
123124
@SuppressWarnings("unused") int posonlyargcount, @SuppressWarnings("unused") int kwonlyargcount,
124125
int nlocals, int stacksize, int flags,
125-
byte[] codedata, Object[] constants, Object[] names,
126-
Object[] varnames, Object[] freevars, Object[] cellvars,
126+
byte[] codedata, Object[] constants, TruffleString[] names,
127+
TruffleString[] varnames, TruffleString[] freevars, TruffleString[] cellvars,
127128
TruffleString filename, TruffleString name, int firstlineno,
128129
byte[] lnotab) {
129130

@@ -132,7 +133,7 @@ private static PCode createCode(PythonLanguage language, PythonContext context,
132133
ct = language.createCachedCallTarget(l -> new BadOPCodeNode(l, name), BadOPCodeNode.class, filename, name);
133134
} else {
134135
if (context.getOption(PythonOptions.EnableBytecodeInterpreter)) {
135-
ct = create().deserializeForBytecodeInterpreter(language, codedata);
136+
ct = create().deserializeForBytecodeInterpreter(language, codedata, cellvars, freevars);
136137
} else {
137138
RootNode rootNode = context.getSerializer().deserialize(context, codedata, toStringArray(cellvars), toStringArray(freevars));
138139
ct = PythonUtils.getOrCreateCallTarget(rootNode);
@@ -146,8 +147,17 @@ private static PCode createCode(PythonLanguage language, PythonContext context,
146147
firstlineno, lnotab);
147148
}
148149

149-
private RootCallTarget deserializeForBytecodeInterpreter(PythonLanguage language, byte[] data) {
150+
private RootCallTarget deserializeForBytecodeInterpreter(PythonLanguage language, byte[] data, TruffleString[] cellvars, TruffleString[] freevars) {
150151
CodeUnit code = MarshalModuleBuiltins.deserializeCodeUnit(data);
152+
if (cellvars != null && !Arrays.equals(code.cellvars, cellvars) || freevars != null && !Arrays.equals(code.freevars, freevars)) {
153+
code = new CodeUnit(code.name, code.qualname, code.argCount, code.kwOnlyArgCount, code.positionalOnlyArgCount, code.stacksize, code.code,
154+
code.srcOffsetTable, code.flags, code.names, code.varnames,
155+
cellvars != null ? cellvars : code.cellvars, freevars != null ? freevars : code.freevars,
156+
code.cell2arg, code.constants, code.primitiveConstants, code.exceptionHandlerRanges, code.conditionProfileCount,
157+
code.startOffset, code.startLine,
158+
code.outputCanQuicken, code.variableShouldUnbox,
159+
code.generalizeInputsMap, code.generalizeVarsMap);
160+
}
151161
RootNode rootNode = new PBytecodeRootNode(language, code, null, null);
152162
if (code.isGeneratorOrCoroutine()) {
153163
rootNode = new PBytecodeGeneratorFunctionRootNode(language, rootNode.getFrameDescriptor(), (PBytecodeRootNode) rootNode, code.name);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/util/PythonUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
7373
import com.oracle.graal.python.nodes.function.BuiltinFunctionRootNode;
7474
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
75+
import com.oracle.graal.python.nodes.util.CastToTruffleStringNode;
7576
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
7677
import com.oracle.graal.python.runtime.object.PythonObjectSlowPathFactory;
7778
import com.oracle.truffle.api.CallTarget;
@@ -678,4 +679,15 @@ public static void copyFrameSlot(Frame frameToSync, MaterializedFrame target, in
678679
target.setByte(slot, frameToSync.getByte(slot));
679680
}
680681
}
682+
683+
public static TruffleString[] objectArrayToTruffleStringArray(Object[] array, CastToTruffleStringNode cast) {
684+
if (array.length == 0) {
685+
return EMPTY_TRUFFLESTRING_ARRAY;
686+
}
687+
TruffleString[] result = new TruffleString[array.length];
688+
for (int i = 0; i < array.length; i++) {
689+
result[i] = cast.execute(array[i]);
690+
}
691+
return result;
692+
}
681693
}

0 commit comments

Comments
 (0)