Skip to content

Commit fc910a3

Browse files
committed
Hide that C recursion protection is implemented with a counter. There is an imbalance in the AST somewhere.
1 parent bff4bfe commit fc910a3

File tree

18 files changed

+618
-805
lines changed

18 files changed

+618
-805
lines changed

Include/ceval.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ PyAPI_FUNC(int) Py_GetRecursionLimit(void);
6060
PyAPI_FUNC(int) Py_EnterRecursiveCall(const char *where);
6161
PyAPI_FUNC(void) Py_LeaveRecursiveCall(void);
6262

63+
PyAPI_FUNC(int) Py_ReachedRecursionLimit(PyThreadState *tstate, int margin_count);
64+
PyAPI_FUNC(void) _Py_EnterRecursiveCallUnchecked(PyThreadState *tstate);
65+
PyAPI_FUNC(void) Py_LeaveRecursiveCallTstate(PyThreadState *tstate);
66+
6367
PyAPI_FUNC(const char *) PyEval_GetFuncName(PyObject *);
6468
PyAPI_FUNC(const char *) PyEval_GetFuncDesc(PyObject *);
6569

Include/cpython/object.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,15 +490,15 @@ PyAPI_FUNC(void) _PyTrash_thread_destroy_chain(PyThreadState *tstate);
490490
#define Py_TRASHCAN_BEGIN(op, dealloc) \
491491
do { \
492492
PyThreadState *tstate = PyThreadState_Get(); \
493-
if (tstate->c_recursion_remaining <= Py_TRASHCAN_HEADROOM && Py_TYPE(op)->tp_dealloc == (destructor)dealloc) { \
493+
if (Py_ReachedRecursionLimit(tstate, 1) && Py_TYPE(op)->tp_dealloc == (destructor)dealloc) { \
494494
_PyTrash_thread_deposit_object(tstate, (PyObject *)op); \
495495
break; \
496496
} \
497-
tstate->c_recursion_remaining--;
497+
_Py_EnterRecursiveCallUnchecked(tstate);
498498
/* The body of the deallocator is here. */
499499
#define Py_TRASHCAN_END \
500-
tstate->c_recursion_remaining++; \
501-
if (tstate->delete_later && tstate->c_recursion_remaining > (Py_TRASHCAN_HEADROOM*2)) { \
500+
Py_LeaveRecursiveCallTstate(tstate); \
501+
if (tstate->delete_later && !Py_ReachedRecursionLimit(tstate, 2)) { \
502502
_PyTrash_thread_destroy_chain(tstate); \
503503
} \
504504
} while (0);

Include/internal/pycore_ceval.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ static inline int _Py_EnterRecursiveCallTstate(PyThreadState *tstate,
221221
}
222222

223223
static inline void _Py_EnterRecursiveCallTstateUnchecked(PyThreadState *tstate) {
224-
assert(tstate->c_recursion_remaining > 0);
224+
assert(tstate->c_recursion_remaining >= -2); // Allow a bit of wiggle room
225225
tstate->c_recursion_remaining--;
226226
}
227227

@@ -234,6 +234,12 @@ static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) {
234234
tstate->c_recursion_remaining++;
235235
}
236236

237+
#define Py_RECURSION_LIMIT_MARGIN_MULTIPLIER 50
238+
239+
static inline int _Py_ReachedRecursionLimit(PyThreadState *tstate, int margin_count) {
240+
return tstate->c_recursion_remaining <= margin_count * Py_RECURSION_LIMIT_MARGIN_MULTIPLIER;
241+
}
242+
237243
static inline void _Py_LeaveRecursiveCall(void) {
238244
PyThreadState *tstate = _PyThreadState_GET();
239245
_Py_LeaveRecursiveCallTstate(tstate);

Include/internal/pycore_symtable.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ struct symtable {
8282
PyObject *st_private; /* name of current class or NULL */
8383
_PyFutureFeatures *st_future; /* module's future features that affect
8484
the symbol table */
85-
int recursion_depth; /* current recursion depth */
86-
int recursion_limit; /* recursion limit */
8785
};
8886

8987
typedef struct _symtable_entry {

Lib/test/test_ast/test_ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def next(self):
749749
def test_ast_recursion_limit(self):
750750
fail_depth = support.exceeds_recursion_limit()
751751
crash_depth = 100_000
752-
success_depth = int(support.get_c_recursion_limit() * 0.8)
752+
success_depth = int(support.get_c_recursion_limit() * 0.6)
753753
if _testinternalcapi is not None:
754754
remaining = _testinternalcapi.get_c_recursion_remaining()
755755
success_depth = min(success_depth, remaining)

Lib/test/test_capi/test_misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def test_trashcan_subclass(self):
408408
# activated when its tp_dealloc is being called by a subclass
409409
from _testcapi import MyList
410410
L = None
411-
for i in range(1000):
411+
for i in range(support.get_c_recursion_limit()):
412412
L = MyList((L,))
413413

414414
@support.requires_resource('cpu')

Objects/object.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,8 +2910,7 @@ _PyTrash_thread_destroy_chain(PyThreadState *tstate)
29102910
tups = [(tup,) for tup in tups]
29112911
del tups
29122912
*/
2913-
assert(tstate->c_recursion_remaining > Py_TRASHCAN_HEADROOM);
2914-
tstate->c_recursion_remaining--;
2913+
_Py_EnterRecursiveCallTstateUnchecked(tstate);
29152914
while (tstate->delete_later) {
29162915
PyObject *op = tstate->delete_later;
29172916
destructor dealloc = Py_TYPE(op)->tp_dealloc;
@@ -2933,7 +2932,7 @@ _PyTrash_thread_destroy_chain(PyThreadState *tstate)
29332932
_PyObject_ASSERT(op, Py_REFCNT(op) == 0);
29342933
(*dealloc)(op);
29352934
}
2936-
tstate->c_recursion_remaining++;
2935+
_Py_LeaveRecursiveCallTstate(tstate);
29372936
}
29382937

29392938
void _Py_NO_RETURN

Parser/asdl_c.py

Lines changed: 18 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def emit_sequence_constructor(self, name, type):
738738
class PyTypesDeclareVisitor(PickleVisitor):
739739

740740
def visitProduct(self, prod, name):
741-
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
741+
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
742742
if prod.attributes:
743743
self.emit("static const char * const %s_attributes[] = {" % name, 0)
744744
for a in prod.attributes:
@@ -759,7 +759,7 @@ def visitSum(self, sum, name):
759759
ptype = "void*"
760760
if is_simple(sum):
761761
ptype = get_c_type(name)
762-
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
762+
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
763763
for t in sum.types:
764764
self.visitConstructor(t, name)
765765

@@ -1734,16 +1734,16 @@ def visitModule(self, mod):
17341734
17351735
/* Conversion AST -> Python */
17361736
1737-
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
1738-
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
1737+
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq,
1738+
PyObject* (*func)(struct ast_state *state, void*))
17391739
{
17401740
Py_ssize_t i, n = asdl_seq_LEN(seq);
17411741
PyObject *result = PyList_New(n);
17421742
PyObject *value;
17431743
if (!result)
17441744
return NULL;
17451745
for (i = 0; i < n; i++) {
1746-
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
1746+
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
17471747
if (!value) {
17481748
Py_DECREF(result);
17491749
return NULL;
@@ -1753,7 +1753,7 @@ def visitModule(self, mod):
17531753
return result;
17541754
}
17551755
1756-
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
1756+
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
17571757
{
17581758
PyObject *op = (PyObject*)o;
17591759
if (!op) {
@@ -1765,7 +1765,7 @@ def visitModule(self, mod):
17651765
#define ast2obj_identifier ast2obj_object
17661766
#define ast2obj_string ast2obj_object
17671767
1768-
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
1768+
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
17691769
{
17701770
return PyLong_FromLong(b);
17711771
}
@@ -2014,25 +2014,23 @@ class ObjVisitor(PickleVisitor):
20142014
def func_begin(self, name):
20152015
ctype = get_c_type(name)
20162016
self.emit("PyObject*", 0)
2017-
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
2017+
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
20182018
self.emit("{", 0)
20192019
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
20202020
self.emit("PyObject *result = NULL, *value = NULL;", 1)
20212021
self.emit("PyTypeObject *tp;", 1)
20222022
self.emit('if (!o) {', 1)
20232023
self.emit("Py_RETURN_NONE;", 2)
20242024
self.emit("}", 1)
2025-
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
2026-
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
2027-
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
2025+
self.emit('if (Py_EnterRecursiveCall("during ast construction")) {', 1)
20282026
self.emit("return NULL;", 2)
20292027
self.emit("}", 1)
20302028

20312029
def func_end(self):
2032-
self.emit("vstate->recursion_depth--;", 1)
2030+
self.emit("Py_LeaveRecursiveCall();", 1)
20332031
self.emit("return result;", 1)
20342032
self.emit("failed:", 0)
2035-
self.emit("vstate->recursion_depth--;", 1)
2033+
self.emit("Py_LeaveRecursiveCall();", 1)
20362034
self.emit("Py_XDECREF(value);", 1)
20372035
self.emit("Py_XDECREF(result);", 1)
20382036
self.emit("return NULL;", 1)
@@ -2050,15 +2048,15 @@ def visitSum(self, sum, name):
20502048
self.visitConstructor(t, i + 1, name)
20512049
self.emit("}", 1)
20522050
for a in sum.attributes:
2053-
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
2051+
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
20542052
self.emit("if (!value) goto failed;", 1)
20552053
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
20562054
self.emit('goto failed;', 2)
20572055
self.emit('Py_DECREF(value);', 1)
20582056
self.func_end()
20592057

20602058
def simpleSum(self, sum, name):
2061-
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
2059+
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
20622060
self.emit("{", 0)
20632061
self.emit("switch(o) {", 1)
20642062
for t in sum.types:
@@ -2076,7 +2074,7 @@ def visitProduct(self, prod, name):
20762074
for field in prod.fields:
20772075
self.visitField(field, name, 1, True)
20782076
for a in prod.attributes:
2079-
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
2077+
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
20802078
self.emit("if (!value) goto failed;", 1)
20812079
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
20822080
self.emit('goto failed;', 2)
@@ -2117,7 +2115,7 @@ def set(self, field, value, depth):
21172115
self.emit("for(i = 0; i < n; i++)", depth+1)
21182116
# This cannot fail, so no need for error handling
21192117
self.emit(
2120-
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
2118+
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
21212119
field.type,
21222120
value
21232121
),
@@ -2126,9 +2124,9 @@ def set(self, field, value, depth):
21262124
)
21272125
self.emit("}", depth)
21282126
else:
2129-
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
2127+
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
21302128
else:
2131-
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
2129+
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
21322130

21332131

21342132
class PartingShots(StaticVisitor):
@@ -2140,28 +2138,8 @@ class PartingShots(StaticVisitor):
21402138
if (state == NULL) {
21412139
return NULL;
21422140
}
2141+
PyObject *result = ast2obj_mod(state, t);
21432142
2144-
int starting_recursion_depth;
2145-
/* Be careful here to prevent overflow. */
2146-
PyThreadState *tstate = _PyThreadState_GET();
2147-
if (!tstate) {
2148-
return NULL;
2149-
}
2150-
struct validator vstate;
2151-
vstate.recursion_limit = Py_C_RECURSION_LIMIT;
2152-
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
2153-
starting_recursion_depth = recursion_depth;
2154-
vstate.recursion_depth = starting_recursion_depth;
2155-
2156-
PyObject *result = ast2obj_mod(state, &vstate, t);
2157-
2158-
/* Check that the recursion depth counting balanced correctly */
2159-
if (result && vstate.recursion_depth != starting_recursion_depth) {
2160-
PyErr_Format(PyExc_SystemError,
2161-
"AST constructor recursion depth mismatch (before=%d, after=%d)",
2162-
starting_recursion_depth, vstate.recursion_depth);
2163-
return NULL;
2164-
}
21652143
return result;
21662144
}
21672145
@@ -2293,11 +2271,6 @@ def generate_module_def(mod, metadata, f, internal_h):
22932271
#include "structmember.h"
22942272
#include <stddef.h>
22952273
2296-
struct validator {
2297-
int recursion_depth; /* current recursion depth */
2298-
int recursion_limit; /* recursion limit */
2299-
};
2300-
23012274
// Forward declaration
23022275
static int init_types(void *arg);
23032276

0 commit comments

Comments
 (0)