@@ -15,6 +15,19 @@ typedef struct {
1515 int recursion_limit ; /* recursion limit */
1616} _PyASTOptimizeState ;
1717
18+ #define ENTER_RECURSIVE (ST ) \
19+ do { \
20+ if (++(ST)->recursion_depth > (ST)->recursion_limit) { \
21+ PyErr_SetString(PyExc_RecursionError, \
22+ "maximum recursion depth exceeded during compilation"); \
23+ return 0; \
24+ } \
25+ } while(0)
26+
27+ #define LEAVE_RECURSIVE (ST ) \
28+ do { \
29+ --(ST)->recursion_depth; \
30+ } while(0)
1831
1932static int
2033make_const (expr_ty node , PyObject * val , PyArena * arena )
@@ -708,11 +721,7 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
708721static int
709722astfold_expr (expr_ty node_ , PyArena * ctx_ , _PyASTOptimizeState * state )
710723{
711- if (++ state -> recursion_depth > state -> recursion_limit ) {
712- PyErr_SetString (PyExc_RecursionError ,
713- "maximum recursion depth exceeded during compilation" );
714- return 0 ;
715- }
724+ ENTER_RECURSIVE (state );
716725 switch (node_ -> kind ) {
717726 case BoolOp_kind :
718727 CALL_SEQ (astfold_expr , expr , node_ -> v .BoolOp .values );
@@ -811,7 +820,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
811820 case Name_kind :
812821 if (node_ -> v .Name .ctx == Load &&
813822 _PyUnicode_EqualToASCIIString (node_ -> v .Name .id , "__debug__" )) {
814- state -> recursion_depth -- ;
823+ LEAVE_RECURSIVE ( state ) ;
815824 return make_const (node_ , PyBool_FromLong (!state -> optimize ), ctx_ );
816825 }
817826 break ;
@@ -824,7 +833,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
824833 // No default case, so the compiler will emit a warning if new expression
825834 // kinds are added without being handled here
826835 }
827- state -> recursion_depth -- ;
836+ LEAVE_RECURSIVE ( state ); ;
828837 return 1 ;
829838}
830839
@@ -871,11 +880,7 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
871880static int
872881astfold_stmt (stmt_ty node_ , PyArena * ctx_ , _PyASTOptimizeState * state )
873882{
874- if (++ state -> recursion_depth > state -> recursion_limit ) {
875- PyErr_SetString (PyExc_RecursionError ,
876- "maximum recursion depth exceeded during compilation" );
877- return 0 ;
878- }
883+ ENTER_RECURSIVE (state );
879884 switch (node_ -> kind ) {
880885 case FunctionDef_kind :
881886 CALL_SEQ (astfold_type_param , type_param , node_ -> v .FunctionDef .type_params );
@@ -999,7 +1004,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
9991004 // No default case, so the compiler will emit a warning if new statement
10001005 // kinds are added without being handled here
10011006 }
1002- state -> recursion_depth -- ;
1007+ LEAVE_RECURSIVE ( state ) ;
10031008 return 1 ;
10041009}
10051010
@@ -1031,11 +1036,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
10311036 // Currently, this is really only used to form complex/negative numeric
10321037 // constants in MatchValue and MatchMapping nodes
10331038 // We still recurse into all subexpressions and subpatterns anyway
1034- if (++ state -> recursion_depth > state -> recursion_limit ) {
1035- PyErr_SetString (PyExc_RecursionError ,
1036- "maximum recursion depth exceeded during compilation" );
1037- return 0 ;
1038- }
1039+ ENTER_RECURSIVE (state );
10391040 switch (node_ -> kind ) {
10401041 case MatchValue_kind :
10411042 CALL (astfold_expr , expr_ty , node_ -> v .MatchValue .value );
@@ -1067,7 +1068,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
10671068 // No default case, so the compiler will emit a warning if new pattern
10681069 // kinds are added without being handled here
10691070 }
1070- state -> recursion_depth -- ;
1071+ LEAVE_RECURSIVE ( state ) ;
10711072 return 1 ;
10721073}
10731074
0 commit comments