@@ -41,6 +41,7 @@ PyUnstable_AtExit(PyInterpreterState *interp,
4141    callback -> next  =  NULL ;
4242
4343    struct  atexit_state  * state  =  & interp -> atexit ;
44+     _PyAtExit_LockCallbacks (state );
4445    atexit_callback  * top  =  state -> ll_callbacks ;
4546    if  (top  ==  NULL ) {
4647        state -> ll_callbacks  =  callback ;
@@ -49,36 +50,16 @@ PyUnstable_AtExit(PyInterpreterState *interp,
4950        callback -> next  =  top ;
5051        state -> ll_callbacks  =  callback ;
5152    }
53+     _PyAtExit_UnlockCallbacks (state );
5254    return  0 ;
5355}
5456
5557
56- static  void 
57- atexit_delete_cb (struct  atexit_state  * state , int  i )
58- {
59-     atexit_py_callback  * cb  =  state -> callbacks [i ];
60-     state -> callbacks [i ] =  NULL ;
61- 
62-     Py_DECREF (cb -> func );
63-     Py_DECREF (cb -> args );
64-     Py_XDECREF (cb -> kwargs );
65-     PyMem_Free (cb );
66- }
67- 
68- 
6958/* Clear all callbacks without calling them */ 
7059static  void 
7160atexit_cleanup (struct  atexit_state  * state )
7261{
73-     atexit_py_callback  * cb ;
74-     for  (int  i  =  0 ; i  <  state -> ncallbacks ; i ++ ) {
75-         cb  =  state -> callbacks [i ];
76-         if  (cb  ==  NULL )
77-             continue ;
78- 
79-         atexit_delete_cb (state , i );
80-     }
81-     state -> ncallbacks  =  0 ;
62+     PyList_Clear (state -> callbacks );
8263}
8364
8465
@@ -89,23 +70,21 @@ _PyAtExit_Init(PyInterpreterState *interp)
8970    // _PyAtExit_Init() must only be called once 
9071    assert (state -> callbacks  ==  NULL );
9172
92-     state -> callback_len  =  32 ;
93-     state -> ncallbacks  =  0 ;
94-     state -> callbacks  =  PyMem_New (atexit_py_callback * , state -> callback_len );
73+     state -> callbacks  =  PyList_New (0 );
9574    if  (state -> callbacks  ==  NULL ) {
9675        return  _PyStatus_NO_MEMORY ();
9776    }
9877    return  _PyStatus_OK ();
9978}
10079
101- 
10280void 
10381_PyAtExit_Fini (PyInterpreterState  * interp )
10482{
83+     // In theory, there shouldn't be any threads left by now, so we 
84+     // won't lock this. 
10585    struct  atexit_state  * state  =  & interp -> atexit ;
10686    atexit_cleanup (state );
107-     PyMem_Free (state -> callbacks );
108-     state -> callbacks  =  NULL ;
87+     Py_CLEAR (state -> callbacks );
10988
11089    atexit_callback  * next  =  state -> ll_callbacks ;
11190    state -> ll_callbacks  =  NULL ;
@@ -120,35 +99,44 @@ _PyAtExit_Fini(PyInterpreterState *interp)
12099    }
121100}
122101
123- 
124102static  void 
125103atexit_callfuncs (struct  atexit_state  * state )
126104{
127105    assert (!PyErr_Occurred ());
106+     assert (state -> callbacks  !=  NULL );
107+     assert (PyList_CheckExact (state -> callbacks ));
128108
129-     if  (state -> ncallbacks  ==  0 ) {
109+     // Create a copy of the list for thread safety 
110+     PyObject  * copy  =  PyList_GetSlice (state -> callbacks , 0 , PyList_GET_SIZE (state -> callbacks ));
111+     if  (copy  ==  NULL )
112+     {
113+         PyErr_WriteUnraisable (NULL );
130114        return ;
131115    }
132116
133-     for  (int  i  =  state -> ncallbacks  -  1 ; i  >= 0 ; i -- ) {
134-         atexit_py_callback  * cb  =  state -> callbacks [i ];
135-         if  (cb  ==  NULL ) {
136-             continue ;
137-         }
117+     for  (Py_ssize_t  i  =  0 ; i  <  PyList_GET_SIZE (copy ); ++ i ) {
118+         // We don't have to worry about evil borrowed references, because 
119+         // no other threads can access this list. 
120+         PyObject  * tuple  =  PyList_GET_ITEM (copy , i );
121+         assert (PyTuple_CheckExact (tuple ));
122+ 
123+         PyObject  * func  =  PyTuple_GET_ITEM (tuple , 0 );
124+         PyObject  * args  =  PyTuple_GET_ITEM (tuple , 1 );
125+         PyObject  * kwargs  =  PyTuple_GET_ITEM (tuple , 2 );
138126
139-         // bpo-46025: Increment the refcount of cb-> func as the call itself may unregister it 
140-         PyObject *   the_func   =   Py_NewRef ( cb -> func ); 
141-         PyObject   * res   =   PyObject_Call ( cb -> func ,  cb -> args ,  cb -> kwargs );
127+         PyObject   * res   =   PyObject_Call ( func , 
128+                                        args , 
129+                                        kwargs   ==   Py_None  ?  NULL  :  kwargs );
142130        if  (res  ==  NULL ) {
143131            PyErr_FormatUnraisable (
144-                 "Exception ignored in atexit callback %R" , the_func );
132+                 "Exception ignored in atexit callback %R" , func );
145133        }
146134        else  {
147135            Py_DECREF (res );
148136        }
149-         Py_DECREF (the_func );
150137    }
151138
139+     Py_DECREF (copy );
152140    atexit_cleanup (state );
153141
154142    assert (!PyErr_Occurred ());
@@ -194,33 +182,27 @@ atexit_register(PyObject *module, PyObject *args, PyObject *kwargs)
194182                "the first argument must be callable" );
195183        return  NULL ;
196184    }
185+     PyObject  * func_args  =  PyTuple_GetSlice (args , 1 , PyTuple_GET_SIZE (args ));
186+     PyObject  * func_kwargs  =  kwargs ;
197187
198-     struct  atexit_state  * state  =  get_atexit_state ();
199-     if  (state -> ncallbacks  >= state -> callback_len ) {
200-         atexit_py_callback  * * r ;
201-         state -> callback_len  +=  16 ;
202-         size_t  size  =  sizeof (atexit_py_callback * ) *  (size_t )state -> callback_len ;
203-         r  =  (atexit_py_callback * * )PyMem_Realloc (state -> callbacks , size );
204-         if  (r  ==  NULL ) {
205-             return  PyErr_NoMemory ();
206-         }
207-         state -> callbacks  =  r ;
188+     if  (func_kwargs  ==  NULL )
189+     {
190+         func_kwargs  =  Py_None ;
208191    }
209- 
210-     atexit_py_callback   * callback  =   PyMem_Malloc ( sizeof ( atexit_py_callback )); 
211-     if  ( callback   ==   NULL )  {
212-         return  PyErr_NoMemory () ;
192+      PyObject   * callback   =   PyTuple_Pack ( 3 ,  func ,  func_args ,  func_kwargs ); 
193+     if  ( callback  ==    NULL ) 
194+     {
195+         return  NULL ;
213196    }
214197
215-     callback -> args  =  PyTuple_GetSlice (args , 1 , PyTuple_GET_SIZE (args ));
216-     if  (callback -> args  ==  NULL ) {
217-         PyMem_Free (callback );
198+     struct  atexit_state  * state  =  get_atexit_state ();
199+     // atexit callbacks go in a LIFO order 
200+     if  (PyList_Insert (state -> callbacks , 0 , callback ) <  0 )
201+     {
202+         Py_DECREF (callback );
218203        return  NULL ;
219204    }
220-     callback -> func  =  Py_NewRef (func );
221-     callback -> kwargs  =  Py_XNewRef (kwargs );
222- 
223-     state -> callbacks [state -> ncallbacks ++ ] =  callback ;
205+     Py_DECREF (callback );
224206
225207    return  Py_NewRef (func );
226208}
@@ -264,7 +246,33 @@ static PyObject *
264246atexit_ncallbacks (PyObject  * module , PyObject  * unused )
265247{
266248    struct  atexit_state  * state  =  get_atexit_state ();
267-     return  PyLong_FromSsize_t (state -> ncallbacks );
249+     assert (state -> callbacks  !=  NULL );
250+     assert (PyList_CheckExact (state -> callbacks ));
251+     return  PyLong_FromSsize_t (PyList_GET_SIZE (state -> callbacks ));
252+ }
253+ 
254+ static  int 
255+ atexit_unregister_locked (PyObject  * callbacks , PyObject  * func )
256+ {
257+     for  (Py_ssize_t  i  =  0 ; i  <  PyList_GET_SIZE (callbacks ); ++ i ) {
258+         PyObject  * tuple  =  PyList_GET_ITEM (callbacks , i );
259+         assert (PyTuple_CheckExact (tuple ));
260+         PyObject  * to_compare  =  PyTuple_GET_ITEM (tuple , 0 );
261+         int  cmp  =  PyObject_RichCompareBool (func , to_compare , Py_EQ );
262+         if  (cmp  <  0 )
263+         {
264+             return  -1 ;
265+         }
266+         if  (cmp  ==  1 ) {
267+             // We found a callback! 
268+             if  (PyList_SetSlice (callbacks , i , i  +  1 , NULL ) <  0 ) {
269+                 return  -1 ;
270+             }
271+             -- i ;
272+         }
273+     }
274+ 
275+     return  0 ;
268276}
269277
270278PyDoc_STRVAR (atexit_unregister__doc__ ,
@@ -280,22 +288,11 @@ static PyObject *
280288atexit_unregister (PyObject  * module , PyObject  * func )
281289{
282290    struct  atexit_state  * state  =  get_atexit_state ();
283-     for  (int  i  =  0 ; i  <  state -> ncallbacks ; i ++ )
284-     {
285-         atexit_py_callback  * cb  =  state -> callbacks [i ];
286-         if  (cb  ==  NULL ) {
287-             continue ;
288-         }
289- 
290-         int  eq  =  PyObject_RichCompareBool (cb -> func , func , Py_EQ );
291-         if  (eq  <  0 ) {
292-             return  NULL ;
293-         }
294-         if  (eq ) {
295-             atexit_delete_cb (state , i );
296-         }
297-     }
298-     Py_RETURN_NONE ;
291+     int  result ;
292+     Py_BEGIN_CRITICAL_SECTION (state -> callbacks );
293+     result  =  atexit_unregister_locked (state -> callbacks , func );
294+     Py_END_CRITICAL_SECTION ();
295+     return  result  <  0  ? NULL  : Py_None ;
299296}
300297
301298
0 commit comments