Skip to content
24 changes: 24 additions & 0 deletions Lib/test/support/threading_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,27 @@ def requires_working_threading(*, module=False):
raise unittest.SkipTest(msg)
else:
return unittest.skipUnless(can_start_thread, msg)


def run_concurrently(worker_func, args, nthreads):
"""
Run the worker function concurrently in multiple threads.
"""
barrier = threading.Barrier(nthreads)

def wrapper_func(*args):
# Wait for all threads to reach this point before proceeding.
barrier.wait()
worker_func(*args)

with catch_threading_exception() as cm:
workers = (
threading.Thread(target=wrapper_func, args=args)
for _ in range(nthreads)
)
with start_threads(workers):
pass

# If a worker thread raises an exception, re-raise it.
if cm.exc_value is not None:
raise cm.exc_value
36 changes: 36 additions & 0 deletions Lib/test/test_free_threading/test_grp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

from test.support import import_helper, threading_helper
from test.support.threading_helper import run_concurrently

grp = import_helper.import_module("grp")

from test import test_grp


NTHREADS = 10


@threading_helper.requires_working_threading()
class TestGrp(unittest.TestCase):
def setUp(self):
self.test_grp = test_grp.GroupDatabaseTestCase()

def test_racing_test_values(self):
# test_grp.test_values() calls grp.getgrall() and checks the entries
run_concurrently(
worker_func=self.test_grp.test_values, args=(), nthreads=NTHREADS
)

def test_racing_test_values_extended(self):
# test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
# grp.getgrnam() and checks the entries
run_concurrently(
worker_func=self.test_grp.test_values_extended,
args=(),
nthreads=NTHREADS,
)


if __name__ == "__main__":
unittest.main()
46 changes: 13 additions & 33 deletions Lib/test/test_free_threading/test_heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import heapq

from enum import Enum
from threading import Thread, Barrier, Lock
from threading import Barrier, Lock
from random import shuffle, randint

from test.support import threading_helper
from test.support.threading_helper import run_concurrently
from test import test_heapq


Expand All @@ -28,7 +29,7 @@ def test_racing_heapify(self):
heap = list(range(OBJECT_COUNT))
shuffle(heap)

self.run_concurrently(
run_concurrently(
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
)
self.test_heapq.check_invariant(heap)
Expand All @@ -40,7 +41,7 @@ def heappush_func(heap):
for item in reversed(range(OBJECT_COUNT)):
heapq.heappush(heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
)
self.test_heapq.check_invariant(heap)
Expand All @@ -61,7 +62,7 @@ def heappop_func(heap, pop_count):
# Each local list should be sorted
self.assertTrue(self.is_sorted_ascending(local_list))

self.run_concurrently(
run_concurrently(
worker_func=heappop_func,
args=(heap, per_thread_pop_count),
nthreads=NTHREADS,
Expand All @@ -77,7 +78,7 @@ def heappushpop_func(heap, pushpop_items):
popped_item = heapq.heappushpop(heap, item)
self.assertTrue(popped_item <= item)

self.run_concurrently(
run_concurrently(
worker_func=heappushpop_func,
args=(heap, pushpop_items),
nthreads=NTHREADS,
Expand All @@ -93,7 +94,7 @@ def heapreplace_func(heap, replace_items):
for item in replace_items:
heapq.heapreplace(heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heapreplace_func,
args=(heap, replace_items),
nthreads=NTHREADS,
Expand All @@ -105,7 +106,7 @@ def test_racing_heapify_max(self):
max_heap = list(range(OBJECT_COUNT))
shuffle(max_heap)

self.run_concurrently(
run_concurrently(
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
)
self.test_heapq.check_max_invariant(max_heap)
Expand All @@ -117,7 +118,7 @@ def heappush_max_func(max_heap):
for item in range(OBJECT_COUNT):
heapq.heappush_max(max_heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
)
self.test_heapq.check_max_invariant(max_heap)
Expand All @@ -138,7 +139,7 @@ def heappop_max_func(max_heap, pop_count):
# Each local list should be sorted
self.assertTrue(self.is_sorted_descending(local_list))

self.run_concurrently(
run_concurrently(
worker_func=heappop_max_func,
args=(max_heap, per_thread_pop_count),
nthreads=NTHREADS,
Expand All @@ -154,7 +155,7 @@ def heappushpop_max_func(max_heap, pushpop_items):
popped_item = heapq.heappushpop_max(max_heap, item)
self.assertTrue(popped_item >= item)

self.run_concurrently(
run_concurrently(
worker_func=heappushpop_max_func,
args=(max_heap, pushpop_items),
nthreads=NTHREADS,
Expand All @@ -170,7 +171,7 @@ def heapreplace_max_func(max_heap, replace_items):
for item in replace_items:
heapq.heapreplace_max(max_heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heapreplace_max_func,
args=(max_heap, replace_items),
nthreads=NTHREADS,
Expand Down Expand Up @@ -203,7 +204,7 @@ def worker():
except IndexError:
pass

self.run_concurrently(worker, (), n_threads * 2)
run_concurrently(worker, (), n_threads * 2)

@staticmethod
def is_sorted_ascending(lst):
Expand Down Expand Up @@ -241,27 +242,6 @@ def create_random_list(a, b, size):
"""
return [randint(-a, b) for _ in range(size)]

def run_concurrently(self, worker_func, args, nthreads):
"""
Run the worker function concurrently in multiple threads.
"""
barrier = Barrier(nthreads)

def wrapper_func(*args):
# Wait for all threads to reach this point before proceeding.
barrier.wait()
worker_func(*args)

with threading_helper.catch_threading_exception() as cm:
workers = (
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
)
with threading_helper.start_threads(workers):
pass

# Worker threads should not raise any exceptions
self.assertIsNone(cm.exc_value)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make methods in :mod:`grp` thread-safe on the :term:`free threaded <free threading>` build.
19 changes: 19 additions & 0 deletions Modules/grpmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,16 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)

Py_END_ALLOW_THREADS
#else
static PyMutex getgrgid_mutex = {0};
PyMutex_Lock(&getgrgid_mutex);
// The getgrgid() function need not be thread-safe.
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrgid.html
p = getgrgid(gid);
#endif
if (p == NULL) {
#ifndef HAVE_GETGRGID_R
PyMutex_Unlock(&getgrgid_mutex);
#endif
PyMem_RawFree(buf);
if (nomem == 1) {
return PyErr_NoMemory();
Expand All @@ -185,6 +192,8 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
retval = mkgrent(module, p);
#ifdef HAVE_GETGRGID_R
PyMem_RawFree(buf);
#else
PyMutex_Unlock(&getgrgid_mutex);
#endif
return retval;
}
Expand Down Expand Up @@ -249,9 +258,16 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)

Py_END_ALLOW_THREADS
#else
static PyMutex getgrnam_mutex = {0};
PyMutex_Lock(&getgrnam_mutex);
// The getgrnam() function need not be thread-safe.
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrnam.html
p = getgrnam(name_chars);
#endif
if (p == NULL) {
#ifndef HAVE_GETGRNAM_R
PyMutex_Unlock(&getgrnam_mutex);
#endif
if (nomem == 1) {
PyErr_NoMemory();
}
Expand All @@ -261,6 +277,9 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
goto out;
}
retval = mkgrent(module, p);
#ifndef HAVE_GETGRNAM_R
PyMutex_Unlock(&getgrnam_mutex);
#endif
out:
PyMem_RawFree(buf);
Py_DECREF(bytes);
Expand Down
Loading