@@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
731731class PyTypesDeclareVisitor (PickleVisitor ):
732732
733733 def visitProduct (self , prod , name ):
734- self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name , 0 )
734+ self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name , 0 )
735735 if prod .attributes :
736736 self .emit ("static const char * const %s_attributes[] = {" % name , 0 )
737737 for a in prod .attributes :
@@ -752,7 +752,7 @@ def visitSum(self, sum, name):
752752 ptype = "void*"
753753 if is_simple (sum ):
754754 ptype = get_c_type (name )
755- self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name , ptype ), 0 )
755+ self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name , ptype ), 0 )
756756 for t in sum .types :
757757 self .visitConstructor (t , name )
758758
@@ -984,15 +984,16 @@ def visitModule(self, mod):
984984
985985/* Conversion AST -> Python */
986986
987- static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
987+ static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
988+ PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
988989{
989990 Py_ssize_t i, n = asdl_seq_LEN(seq);
990991 PyObject *result = PyList_New(n);
991992 PyObject *value;
992993 if (!result)
993994 return NULL;
994995 for (i = 0; i < n; i++) {
995- value = func(state, asdl_seq_GET_UNTYPED(seq, i));
996+ value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
996997 if (!value) {
997998 Py_DECREF(result);
998999 return NULL;
@@ -1002,7 +1003,7 @@ def visitModule(self, mod):
10021003 return result;
10031004}
10041005
1005- static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
1006+ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
10061007{
10071008 PyObject *op = (PyObject*)o;
10081009 if (!op) {
@@ -1014,7 +1015,7 @@ def visitModule(self, mod):
10141015#define ast2obj_identifier ast2obj_object
10151016#define ast2obj_string ast2obj_object
10161017
1017- static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
1018+ static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
10181019{
10191020 return PyLong_FromLong(b);
10201021}
@@ -1116,8 +1117,6 @@ def visitModule(self, mod):
11161117 for dfn in mod .dfns :
11171118 self .visit (dfn )
11181119 self .file .write (textwrap .dedent ('''
1119- state->recursion_depth = 0;
1120- state->recursion_limit = 0;
11211120 return 0;
11221121 }
11231122 ''' ))
@@ -1260,25 +1259,25 @@ class ObjVisitor(PickleVisitor):
12601259 def func_begin (self , name ):
12611260 ctype = get_c_type (name )
12621261 self .emit ("PyObject*" , 0 )
1263- self .emit ("ast2obj_%s(struct ast_state *state, void* _o)" % (name ), 0 )
1262+ self .emit ("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name ), 0 )
12641263 self .emit ("{" , 0 )
12651264 self .emit ("%s o = (%s)_o;" % (ctype , ctype ), 1 )
12661265 self .emit ("PyObject *result = NULL, *value = NULL;" , 1 )
12671266 self .emit ("PyTypeObject *tp;" , 1 )
12681267 self .emit ('if (!o) {' , 1 )
12691268 self .emit ("Py_RETURN_NONE;" , 2 )
12701269 self .emit ("}" , 1 )
1271- self .emit ("if (++state ->recursion_depth > state ->recursion_limit) {" , 1 )
1270+ self .emit ("if (++vstate ->recursion_depth > vstate ->recursion_limit) {" , 1 )
12721271 self .emit ("PyErr_SetString(PyExc_RecursionError," , 2 )
12731272 self .emit ('"maximum recursion depth exceeded during ast construction");' , 3 )
12741273 self .emit ("return NULL;" , 2 )
12751274 self .emit ("}" , 1 )
12761275
12771276 def func_end (self ):
1278- self .emit ("state ->recursion_depth--;" , 1 )
1277+ self .emit ("vstate ->recursion_depth--;" , 1 )
12791278 self .emit ("return result;" , 1 )
12801279 self .emit ("failed:" , 0 )
1281- self .emit ("state ->recursion_depth--;" , 1 )
1280+ self .emit ("vstate ->recursion_depth--;" , 1 )
12821281 self .emit ("Py_XDECREF(value);" , 1 )
12831282 self .emit ("Py_XDECREF(result);" , 1 )
12841283 self .emit ("return NULL;" , 1 )
@@ -1296,15 +1295,15 @@ def visitSum(self, sum, name):
12961295 self .visitConstructor (t , i + 1 , name )
12971296 self .emit ("}" , 1 )
12981297 for a in sum .attributes :
1299- self .emit ("value = ast2obj_%s(state, o->%s);" % (a .type , a .name ), 1 )
1298+ self .emit ("value = ast2obj_%s(state, vstate, o->%s);" % (a .type , a .name ), 1 )
13001299 self .emit ("if (!value) goto failed;" , 1 )
13011300 self .emit ('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a .name , 1 )
13021301 self .emit ('goto failed;' , 2 )
13031302 self .emit ('Py_DECREF(value);' , 1 )
13041303 self .func_end ()
13051304
13061305 def simpleSum (self , sum , name ):
1307- self .emit ("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name , name ), 0 )
1306+ self .emit ("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name , name ), 0 )
13081307 self .emit ("{" , 0 )
13091308 self .emit ("switch(o) {" , 1 )
13101309 for t in sum .types :
@@ -1322,7 +1321,7 @@ def visitProduct(self, prod, name):
13221321 for field in prod .fields :
13231322 self .visitField (field , name , 1 , True )
13241323 for a in prod .attributes :
1325- self .emit ("value = ast2obj_%s(state, o->%s);" % (a .type , a .name ), 1 )
1324+ self .emit ("value = ast2obj_%s(state, vstate, o->%s);" % (a .type , a .name ), 1 )
13261325 self .emit ("if (!value) goto failed;" , 1 )
13271326 self .emit ("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a .name , 1 )
13281327 self .emit ('goto failed;' , 2 )
@@ -1363,7 +1362,7 @@ def set(self, field, value, depth):
13631362 self .emit ("for(i = 0; i < n; i++)" , depth + 1 )
13641363 # This cannot fail, so no need for error handling
13651364 self .emit (
1366- "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));" .format (
1365+ "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));" .format (
13671366 field .type ,
13681367 value
13691368 ),
@@ -1372,9 +1371,9 @@ def set(self, field, value, depth):
13721371 )
13731372 self .emit ("}" , depth )
13741373 else :
1375- self .emit ("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value , field .type ), depth )
1374+ self .emit ("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value , field .type ), depth )
13761375 else :
1377- self .emit ("value = ast2obj_%s(state, %s);" % (field .type , value ), depth , reflow = False )
1376+ self .emit ("value = ast2obj_%s(state, vstate, %s);" % (field .type , value ), depth , reflow = False )
13781377
13791378
13801379class PartingShots (StaticVisitor ):
@@ -1394,18 +1393,19 @@ class PartingShots(StaticVisitor):
13941393 if (!tstate) {
13951394 return NULL;
13961395 }
1397- state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1396+ struct validator vstate;
1397+ vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
13981398 int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
13991399 starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1400- state-> recursion_depth = starting_recursion_depth;
1400+ vstate. recursion_depth = starting_recursion_depth;
14011401
1402- PyObject *result = ast2obj_mod(state, t);
1402+ PyObject *result = ast2obj_mod(state, &vstate, t);
14031403
14041404 /* Check that the recursion depth counting balanced correctly */
1405- if (result && state-> recursion_depth != starting_recursion_depth) {
1405+ if (result && vstate. recursion_depth != starting_recursion_depth) {
14061406 PyErr_Format(PyExc_SystemError,
14071407 "AST constructor recursion depth mismatch (before=%d, after=%d)",
1408- starting_recursion_depth, state-> recursion_depth);
1408+ starting_recursion_depth, vstate. recursion_depth);
14091409 return NULL;
14101410 }
14111411 return result;
@@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
14751475 f .write ('struct ast_state {\n ' )
14761476 f .write (' _PyOnceFlag once;\n ' )
14771477 f .write (' int finalized;\n ' )
1478- f .write (' int recursion_depth;\n ' )
1479- f .write (' int recursion_limit;\n ' )
14801478 for s in module_state :
14811479 f .write (' PyObject *' + s + ';\n ' )
14821480 f .write ('};' )
@@ -1538,6 +1536,11 @@ def generate_module_def(mod, metadata, f, internal_h):
15381536 #include "pycore_interp.h" // _PyInterpreterState.ast
15391537 #include "pycore_pystate.h" // _PyInterpreterState_GET()
15401538 #include <stddef.h>
1539+
1540+ struct validator {
1541+ int recursion_depth; /* current recursion depth */
1542+ int recursion_limit; /* recursion limit */
1543+ };
15411544
15421545 // Forward declaration
15431546 static int init_types(struct ast_state *state);
0 commit comments