Skip to content

Commit 2d50dd2

Browse files
authored
gh-137422: Fix race condition in PyImport_AddModuleRef (gh-141822)
1 parent 019c315 commit 2d50dd2

File tree

3 files changed

+68
-6
lines changed

3 files changed

+68
-6
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import ctypes
2+
import sys
3+
import unittest
4+
5+
from test.support import threading_helper
6+
from test.support.threading_helper import run_concurrently
7+
8+
9+
_PyImport_AddModuleRef = ctypes.pythonapi.PyImport_AddModuleRef
10+
_PyImport_AddModuleRef.argtypes = (ctypes.c_char_p,)
11+
_PyImport_AddModuleRef.restype = ctypes.py_object
12+
13+
14+
@threading_helper.requires_working_threading()
15+
class TestImportCAPI(unittest.TestCase):
16+
def test_pyimport_addmoduleref_thread_safe(self):
17+
# gh-137422: Concurrent calls to PyImport_AddModuleRef with the same
18+
# module name must return the same module object.
19+
20+
NUM_ITERS = 10
21+
NTHREADS = 4
22+
23+
module_name = f"test_free_threading_addmoduleref_{id(self)}"
24+
module_name_bytes = module_name.encode()
25+
sys.modules.pop(module_name, None)
26+
results = []
27+
28+
def worker():
29+
module = _PyImport_AddModuleRef(module_name_bytes)
30+
results.append(module)
31+
32+
for _ in range(NUM_ITERS):
33+
try:
34+
run_concurrently(worker_func=worker, nthreads=NTHREADS)
35+
self.assertEqual(len(results), NTHREADS)
36+
reference = results[0]
37+
for module in results[1:]:
38+
self.assertIs(module, reference)
39+
self.assertIn(module_name, sys.modules)
40+
self.assertIs(sys.modules[module_name], reference)
41+
finally:
42+
results.clear()
43+
sys.modules.pop(module_name, None)
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix :term:`free threading` race condition in
2+
:c:func:`PyImport_AddModuleRef`. It was previously possible for two calls to
3+
the function return two different objects, only one of which was stored in
4+
:data:`sys.modules`.

Python/import.c

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "Python.h"
44
#include "pycore_audit.h" // _PySys_Audit()
55
#include "pycore_ceval.h"
6+
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
67
#include "pycore_hashtable.h" // _Py_hashtable_new_full()
78
#include "pycore_import.h" // _PyImport_BootstrapImp()
89
#include "pycore_initconfig.h" // _PyStatus_OK()
@@ -309,13 +310,8 @@ PyImport_GetModule(PyObject *name)
309310
if not, create a new one and insert it in the modules dictionary. */
310311

311312
static PyObject *
312-
import_add_module(PyThreadState *tstate, PyObject *name)
313+
import_add_module_lock_held(PyObject *modules, PyObject *name)
313314
{
314-
PyObject *modules = get_modules_dict(tstate, false);
315-
if (modules == NULL) {
316-
return NULL;
317-
}
318-
319315
PyObject *m;
320316
if (PyMapping_GetOptionalItem(modules, name, &m) < 0) {
321317
return NULL;
@@ -335,6 +331,21 @@ import_add_module(PyThreadState *tstate, PyObject *name)
335331
return m;
336332
}
337333

334+
static PyObject *
335+
import_add_module(PyThreadState *tstate, PyObject *name)
336+
{
337+
PyObject *modules = get_modules_dict(tstate, false);
338+
if (modules == NULL) {
339+
return NULL;
340+
}
341+
342+
PyObject *m;
343+
Py_BEGIN_CRITICAL_SECTION(modules);
344+
m = import_add_module_lock_held(modules, name);
345+
Py_END_CRITICAL_SECTION();
346+
return m;
347+
}
348+
338349
PyObject *
339350
PyImport_AddModuleRef(const char *name)
340351
{

0 commit comments

Comments
 (0)