Skip to content

Commit e36675b

Browse files
committed
add tests for marshalling and assigning code with freevars
1 parent 7d13ddb commit e36675b

File tree

5 files changed

+118
-92
lines changed

5 files changed

+118
-92
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_functions.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,44 @@ def test_code_change():
196196
def foo():
197197
return "foo"
198198

199-
def bar():
200-
return "bar"
199+
def bar(a):
200+
return "bar" + str(a)
201201

202202
assert foo() == "foo"
203203
foo.__code__ = bar.__code__
204-
assert foo() == "bar"
204+
assert foo(1) == "bar1"
205+
assert_raises(TypeError, foo)
206+
207+
208+
def test_code_marshal_with_freevars():
209+
import marshal
210+
def foo():
211+
x,y = 1,2
212+
def bar():
213+
return x,y
214+
return bar
215+
216+
def baz():
217+
x,y = 2,3
218+
def bar():
219+
return y,x
220+
return bar
221+
222+
foobar_str = marshal.dumps(foo().__code__)
223+
print(foobar_str)
224+
foobar_code = marshal.loads(foobar_str)
225+
assert_raises(TypeError, exec, foobar_code)
226+
227+
bazbar = baz()
228+
assert bazbar() == (3,2)
229+
230+
def assign_code(x, y):
231+
if isinstance(y, type(assign_code)):
232+
x.__code__ = y.__code__
233+
else:
234+
x.__code__ = y
235+
236+
assert_raises(ValueError, assign_code, foo, bazbar)
237+
assert_raises(ValueError, assign_code, foo, foobar_code)
238+
bazbar.__code__ = foobar_code
239+
assert bazbar() == (2,3)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/BuiltinFunctions.java

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@
119119
import com.oracle.graal.python.builtins.objects.type.TypeNodes;
120120
import com.oracle.graal.python.nodes.BuiltinNames;
121121
import com.oracle.graal.python.nodes.GraalPythonTranslationErrorNode;
122-
import com.oracle.graal.python.nodes.PClosureRootNode;
123122
import com.oracle.graal.python.nodes.PGuards;
124123
import com.oracle.graal.python.nodes.SpecialMethodNames;
125124
import com.oracle.graal.python.nodes.argument.ReadArgumentNode;
@@ -597,6 +596,7 @@ private static BigInteger[] divideAndRemainder(PInt a, PInt b) {
597596
@GenerateNodeFactory
598597
public abstract static class EvalNode extends PythonBuiltinNode {
599598
protected final String funcname = "eval";
599+
private final BranchProfile hasFreeVarsBranch = BranchProfile.create();
600600
@Child protected CompileNode compileNode = CompileNode.create(false);
601601
@Child private IndirectCallNode indirectCallNode = IndirectCallNode.create();
602602
@Child private HasInheritedAttributeNode hasGetItemNode;
@@ -609,6 +609,14 @@ private HasInheritedAttributeNode getHasGetItemNode() {
609609
return hasGetItemNode;
610610
}
611611

612+
protected void assertNoFreeVars(PCode code) {
613+
Object[] freeVars = code.getFreeVars();
614+
if (freeVars.length > 0) {
615+
hasFreeVarsBranch.enter();
616+
throw raise(PythonBuiltinClassType.TypeError, "code object passed to eval/exec may not contain free variables");
617+
}
618+
}
619+
612620
protected boolean isMapping(Object object) {
613621
// tfel: it seems that CPython only checks that there is __getitem__
614622
if (object instanceof PDict) {
@@ -623,7 +631,9 @@ protected boolean isAnyNone(Object object) {
623631
}
624632

625633
protected PCode createAndCheckCode(Object source) {
626-
return compileNode.execute(source, "<string>", "eval", 0, false, -1);
634+
PCode code = compileNode.execute(source, "<string>", "eval", 0, false, -1);
635+
assertNoFreeVars(code);
636+
return code;
627637
}
628638

629639
private static void inheritGlobals(Frame callerFrame, Object[] args) {
@@ -724,23 +734,6 @@ PNone badLocals(@SuppressWarnings("unused") Object source, @SuppressWarnings("un
724734
@Builtin(name = EXEC, minNumOfPositionalArgs = 1, parameterNames = {"source", "globals", "locals"})
725735
@GenerateNodeFactory
726736
abstract static class ExecNode extends EvalNode {
727-
private final BranchProfile hasFreeVars = BranchProfile.create();
728-
729-
private void assertNoFreeVars(PCode code) {
730-
RootNode rootNode = code.getRootNode();
731-
if (rootNode instanceof PClosureRootNode && ((PClosureRootNode) rootNode).hasFreeVars()) {
732-
hasFreeVars.enter();
733-
throw raise(PythonBuiltinClassType.TypeError, "code object passed to exec() may not contain free variables");
734-
}
735-
}
736-
737-
@Override
738-
protected PCode createAndCheckCode(Object source) {
739-
PCode code = compileNode.execute(source, "<string>", "exec", 0, false, -1);
740-
assertNoFreeVars(code);
741-
return code;
742-
}
743-
744737
protected abstract Object executeInternal(VirtualFrame frame);
745738

746739
@Override

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/code/PCode.java

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@
7676
import com.oracle.truffle.api.source.SourceSection;
7777

7878
public final class PCode extends PythonBuiltinObject {
79-
private final long FLAG_POS_GENERATOR = 5;
80-
private final long FLAG_POS_VAR_ARGS = 2;
81-
private final long FLAG_POS_VAR_KW_ARGS = 3;
79+
private static final String[] EMPTY_STRINGS = new String[0];
80+
private final static long FLAG_POS_GENERATOR = 5;
81+
private final static long FLAG_POS_VAR_ARGS = 2;
82+
private final static long FLAG_POS_VAR_KW_ARGS = 3;
8283

8384
private final RootCallTarget callTarget;
8485
private final Arity arity;
@@ -149,8 +150,8 @@ public PCode(LazyPythonClass cls, int argcount, int kwonlyargcount,
149150
// Derive a new call target from the code string, if we can
150151
FrameDescriptor frameDescriptor = new FrameDescriptor();
151152
MaterializedFrame frame = Truffle.getRuntime().createMaterializedFrame(new Object[0], frameDescriptor);
152-
for (int i = 0; i < cellvars.length; i++) {
153-
Object ident = cellvars[i];
153+
for (int i = 0; i < freevars.length; i++) {
154+
Object ident = freevars[i];
154155
FrameSlot slot = frameDescriptor.addFrameSlot(ident);
155156
frameDescriptor.setFrameSlotKind(slot, FrameSlotKind.Object);
156157
frame.setObject(slot, new PCell());
@@ -185,8 +186,8 @@ public PCode(LazyPythonClass cls, int argcount, int kwonlyargcount,
185186
}
186187

187188
@TruffleBoundary
188-
private static Set<String> asSet(String[] values) {
189-
return (values != null) ? new HashSet<>(Arrays.asList(values)) : new HashSet<>();
189+
private static Set<Object> asSet(Object[] objects) {
190+
return (objects != null) ? new HashSet<>(Arrays.asList(objects)) : new HashSet<>();
190191
}
191192

192193
private static String[] extractFreeVars(RootNode rootNode) {
@@ -197,7 +198,7 @@ private static String[] extractFreeVars(RootNode rootNode) {
197198
} else if (rootNode instanceof ModuleRootNode) {
198199
return ((ModuleRootNode) rootNode).getFreeVars();
199200
} else {
200-
return null;
201+
return EMPTY_STRINGS;
201202
}
202203
}
203204

@@ -206,10 +207,8 @@ private static String[] extractCellVars(RootNode rootNode) {
206207
return ((FunctionRootNode) rootNode).getCellVars();
207208
} else if (rootNode instanceof GeneratorFunctionRootNode) {
208209
return ((GeneratorFunctionRootNode) rootNode).getCellVars();
209-
} else if (rootNode instanceof ModuleRootNode) {
210-
return new String[0];
211210
} else {
212-
return null;
211+
return EMPTY_STRINGS;
213212
}
214213
}
215214

@@ -221,7 +220,7 @@ private static String extractFileName(RootNode rootNode) {
221220
} else if (funcRootNode instanceof ModuleRootNode) {
222221
return funcRootNode.getName();
223222
} else {
224-
return null;
223+
return "<unknown source>";
225224
}
226225
}
227226

@@ -272,34 +271,15 @@ private static int extractStackSize(RootNode rootNode) {
272271
}
273272

274273
@TruffleBoundary
275-
private void extractArgStats() {
276-
// 0x20 - generator
277-
this.flags = 0;
278-
RootNode funcRootNode = getRootNode();
279-
if (funcRootNode instanceof GeneratorFunctionRootNode) {
280-
flags |= (1 << FLAG_POS_GENERATOR);
281-
funcRootNode = ((GeneratorFunctionRootNode) funcRootNode).getFunctionRootNode();
282-
}
283-
284-
// 0x04 - *arguments
285-
if (NodeUtil.findAllNodeInstances(funcRootNode, ReadVarArgsNode.class).size() == 1) {
286-
flags |= (1 << FLAG_POS_VAR_ARGS);
287-
}
288-
// 0x08 - **keywords
289-
if (NodeUtil.findAllNodeInstances(funcRootNode, ReadVarKeywordsNode.class).size() == 1) {
290-
flags |= (1 << FLAG_POS_VAR_KW_ARGS);
291-
}
292-
293-
this.freevars = extractFreeVars(getRootNode());
294-
this.cellvars = extractCellVars(getRootNode());
295-
Set<String> freeVarsSet = asSet((String[]) freevars);
296-
Set<String> cellVarsSet = asSet((String[]) cellvars);
274+
private static Object[] extractVarnames(RootNode rootNode, String[] parameterIds, String[] keywordNames, Object[] freeVars, Object[] cellVars) {
275+
Set<Object> freeVarsSet = asSet(freeVars);
276+
Set<Object> cellVarsSet = asSet(cellVars);
297277

298278
ArrayList<String> varNameList = new ArrayList<>(); // must be ordered!
299-
varNameList.addAll(Arrays.asList(arity.getParameterIds()));
300-
varNameList.addAll(Arrays.asList(arity.getKeywordNames()));
279+
varNameList.addAll(Arrays.asList(parameterIds));
280+
varNameList.addAll(Arrays.asList(keywordNames));
301281

302-
for (Object identifier : getRootNode().getFrameDescriptor().getIdentifiers()) {
282+
for (Object identifier : rootNode.getFrameDescriptor().getIdentifiers()) {
303283
if (identifier instanceof String) {
304284
String varName = (String) identifier;
305285

@@ -315,8 +295,29 @@ private void extractArgStats() {
315295
}
316296
}
317297

318-
this.varnames = varNameList.toArray();
319-
this.nlocals = varNameList.size();
298+
return varNameList.toArray();
299+
}
300+
301+
@TruffleBoundary
302+
private static int extractFlags(RootNode rootNode) {
303+
// 0x20 - generator
304+
int flags = 0;
305+
RootNode funcRootNode = rootNode;
306+
if (funcRootNode instanceof GeneratorFunctionRootNode) {
307+
flags |= (1 << FLAG_POS_GENERATOR);
308+
funcRootNode = ((GeneratorFunctionRootNode) funcRootNode).getFunctionRootNode();
309+
}
310+
311+
// 0x04 - *arguments
312+
if (NodeUtil.findFirstNodeInstance(funcRootNode, ReadVarArgsNode.class) != null) {
313+
flags |= (1 << FLAG_POS_VAR_ARGS);
314+
}
315+
// 0x08 - **keywords
316+
if (NodeUtil.findFirstNodeInstance(funcRootNode, ReadVarKeywordsNode.class) != null) {
317+
flags |= (1 << FLAG_POS_VAR_KW_ARGS);
318+
}
319+
320+
return flags;
320321
}
321322

322323
@TruffleBoundary
@@ -338,14 +339,14 @@ public RootNode getRootNode() {
338339

339340
public Object[] getFreeVars() {
340341
if (freevars == null) {
341-
extractArgStats();
342+
freevars = extractFreeVars(getRootNode());
342343
}
343344
return freevars;
344345
}
345346

346347
public Object[] getCellVars() {
347-
if (freevars == null) {
348-
extractArgStats();
348+
if (cellvars == null) {
349+
cellvars = extractCellVars(getRootNode());
349350
}
350351
return cellvars;
351352
}
@@ -385,7 +386,7 @@ public int getKwonlyargcount() {
385386

386387
public int getNlocals() {
387388
if (nlocals == -1) {
388-
extractArgStats();
389+
nlocals = getVarnames().length;
389390
}
390391
return nlocals;
391392
}
@@ -399,14 +400,14 @@ public int getStacksize() {
399400

400401
public int getFlags() {
401402
if (flags == -1) {
402-
extractArgStats();
403+
flags = extractFlags(getRootNode());
403404
}
404405
return flags;
405406
}
406407

407408
public Object[] getVarnames() {
408409
if (varnames == null) {
409-
extractArgStats();
410+
varnames = extractVarnames(getRootNode(), getArity().getParameterIds(), getArity().getKeywordNames(), getFreeVars(), getCellVars());
410411
}
411412
return varnames;
412413
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/function/AbstractFunctionBuiltins.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__ANNOTATIONS__;
3030
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__CLOSURE__;
31-
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__CODE__;
3231
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__DICT__;
3332
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__GLOBALS__;
3433
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__MODULE__;
@@ -45,7 +44,6 @@
4544
import com.oracle.graal.python.builtins.PythonBuiltins;
4645
import com.oracle.graal.python.builtins.objects.PNone;
4746
import com.oracle.graal.python.builtins.objects.cell.PCell;
48-
import com.oracle.graal.python.builtins.objects.code.PCode;
4947
import com.oracle.graal.python.builtins.objects.common.PHashingCollection;
5048
import com.oracle.graal.python.builtins.objects.method.PBuiltinMethod;
5149
import com.oracle.graal.python.builtins.objects.method.PMethod;
@@ -57,7 +55,6 @@
5755
import com.oracle.graal.python.nodes.call.CallDispatchNode;
5856
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5957
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
60-
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
6158
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
6259
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6360
import com.oracle.graal.python.nodes.subscript.GetItemNode;
@@ -230,28 +227,6 @@ Object getModule(PBuiltinFunction self, Object value) {
230227
}
231228
}
232229

233-
@Builtin(name = __CODE__, minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 2, isGetter = true, isSetter = true)
234-
@GenerateNodeFactory
235-
public abstract static class GetCodeNode extends PythonBinaryBuiltinNode {
236-
@Specialization(guards = {"!isBuiltinFunction(self)", "isNoValue(none)"})
237-
Object getCode(PFunction self, @SuppressWarnings("unused") PNone none) {
238-
return self.getCode();
239-
}
240-
241-
@SuppressWarnings("unused")
242-
@Specialization(guards = "!isBuiltinFunction(self)")
243-
Object setCode(PFunction self, PCode code) {
244-
self.setCode(code);
245-
return PNone.NONE;
246-
}
247-
248-
@SuppressWarnings("unused")
249-
@Specialization
250-
Object builtinCode(PBuiltinFunction self, Object none) {
251-
throw raise(AttributeError, "'builtin_function_or_method' object has no attribute '__code__'");
252-
}
253-
}
254-
255230
@Builtin(name = __DICT__, minNumOfPositionalArgs = 1, isGetter = true)
256231
@GenerateNodeFactory
257232
abstract static class DictNode extends PythonUnaryBuiltinNode {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/function/FunctionBuiltins.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
package com.oracle.graal.python.builtins.objects.function;
2828

29+
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__CODE__;
2930
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__DEFAULTS__;
3031
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__KWDEFAULTS__;
3132
import static com.oracle.graal.python.nodes.SpecialAttributeNames.__NAME__;
@@ -42,6 +43,7 @@
4243
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
4344
import com.oracle.graal.python.builtins.PythonBuiltins;
4445
import com.oracle.graal.python.builtins.objects.PNone;
46+
import com.oracle.graal.python.builtins.objects.code.PCode;
4547
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
4648
import com.oracle.graal.python.builtins.objects.dict.PDict;
4749
import com.oracle.graal.python.builtins.objects.function.FunctionBuiltinsFactory.GetFunctionDefaultsNodeFactory;
@@ -240,4 +242,24 @@ Object doGeneric(Object object) {
240242
}
241243
}
242244

245+
@Builtin(name = __CODE__, minNumOfPositionalArgs = 1, maxNumOfPositionalArgs = 2, isGetter = true, isSetter = true)
246+
@GenerateNodeFactory
247+
public abstract static class GetCodeNode extends PythonBinaryBuiltinNode {
248+
@Specialization(guards = {"isNoValue(none)"})
249+
Object getCode(PFunction self, @SuppressWarnings("unused") PNone none) {
250+
return self.getCode();
251+
}
252+
253+
@SuppressWarnings("unused")
254+
@Specialization
255+
Object setCode(PFunction self, PCode code) {
256+
int closureLength = self.getClosure().length;
257+
int freeVarsLength = code.getFreeVars().length;
258+
if (closureLength != freeVarsLength) {
259+
throw raise(PythonBuiltinClassType.ValueError, "%s() requires a code object with %d free vars, not %d", self.getName(), closureLength, freeVarsLength);
260+
}
261+
self.setCode(code);
262+
return PNone.NONE;
263+
}
264+
}
243265
}

0 commit comments

Comments
 (0)