Skip to content

Commit 0389d61

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Add a unittest test extension that runs test cases in parallel using threads.
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes. This approach is broadly inspired by https://github.com/testing-cabal/testtools/blob/a6d205dd4cac51f3cf9267978d39fc877103aacb/testtools/testsuite.py#L113 and by unittest-ft. We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool. We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib. PiperOrigin-RevId: 713312937
1 parent 3fa5572 commit 0389d61

File tree

4 files changed

+161
-1
lines changed

4 files changed

+161
-1
lines changed

jax/_src/test_util.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import collections
1919
from collections.abc import Callable, Generator, Iterable, Sequence
20+
from concurrent.futures import ThreadPoolExecutor
2021
from contextlib import ExitStack, contextmanager
2122
import datetime
2223
import functools
@@ -31,6 +32,7 @@
3132
import tempfile
3233
import textwrap
3334
import threading
35+
import time
3436
from typing import Any, TextIO
3537
import unittest
3638
import warnings
@@ -115,6 +117,12 @@
115117
'deterministic, interactive'),
116118
)
117119

120+
TEST_NUM_THREADS = config.int_flag(
121+
'jax_test_num_threads', 0,
122+
help='Number of threads to use for running tests. 0 means run everything '
123+
'in the main thread. Using > 1 thread is experimental.'
124+
)
125+
118126
# We sanitize test names to ensure they work with "unitttest -k" and
119127
# "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k
120128
# does not. We replace sequences of problematic characters with a single '_'.
@@ -998,8 +1006,140 @@ def sample_product(*args, **kw):
9981006
"""
9991007
return parameterized.parameters(*sample_product_testcases(*args, **kw))
10001008

1009+
# We use a reader-writer lock to protect test execution. Tests that may run in
1010+
# parallel acquire a read lock; tests that are not thread-safe acquire a write
1011+
# lock.
1012+
if hasattr(util, 'Mutex'):
1013+
_test_rwlock = util.Mutex()
1014+
1015+
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
1016+
_test_rwlock.reader_lock()
1017+
try:
1018+
test(result) # type: ignore
1019+
finally:
1020+
_test_rwlock.reader_unlock()
1021+
1022+
1023+
@contextmanager
1024+
def thread_hostile_test():
1025+
"Decorator for tests that are not thread-safe."
1026+
_test_rwlock.assert_reader_held()
1027+
_test_rwlock.reader_unlock()
1028+
_test_rwlock.writer_lock()
1029+
try:
1030+
yield
1031+
finally:
1032+
_test_rwlock.writer_unlock()
1033+
_test_rwlock.reader_lock()
1034+
else:
1035+
# TODO(phawkins): remove this branch when jaxlib 0.5.0 is the minimum.
1036+
_test_rwlock = threading.Lock()
1037+
1038+
def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
1039+
_test_rwlock.acquire()
1040+
try:
1041+
test(result) # type: ignore
1042+
finally:
1043+
_test_rwlock.release()
1044+
1045+
1046+
@contextmanager
1047+
def thread_hostile_test():
1048+
yield # No reader-writer lock, so we get no parallelism.
1049+
1050+
class ThreadSafeTestResult:
1051+
"""
1052+
Wraps a TestResult to make it thread safe.
1053+
1054+
We do this by accumulating API calls and applying them in a batch under a
1055+
lock at the conclusion of each test case.
1056+
1057+
We duck type instead of inheriting from TestResult because we aren't actually
1058+
a perfect implementation of TestResult, and would rather get a loud error
1059+
for things we haven't implemented.
1060+
"""
1061+
def __init__(self, lock: threading.Lock, result: unittest.TestResult):
1062+
self.lock = lock
1063+
self.test_result = result
1064+
self.actions: list[Callable] = []
1065+
1066+
def startTest(self, test: unittest.TestCase):
1067+
del test
1068+
self.start_time = time.time()
1069+
1070+
def stopTest(self, test: unittest.TestCase):
1071+
stop_time = time.time()
1072+
with self.lock:
1073+
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
1074+
# override how it gets the time.
1075+
time_getter = self.test_result.time_getter
1076+
try:
1077+
self.test_result.time_getter = lambda: self.start_time
1078+
self.test_result.startTest(test)
1079+
for callback in self.actions:
1080+
callback()
1081+
self.test_result.time_getter = lambda: stop_time
1082+
self.test_result.stopTest(test)
1083+
finally:
1084+
self.test_result.time_getter = time_getter
1085+
1086+
def addSuccess(self, test: unittest.TestCase):
1087+
self.actions.append(lambda: self.test_result.addSuccess(test))
1088+
1089+
def addSkip(self, test: unittest.TestCase, reason: str):
1090+
self.actions.append(lambda: self.test_result.addSkip(test, reason))
1091+
1092+
def addError(self, test: unittest.TestCase, err):
1093+
self.actions.append(lambda: self.test_result.addError(test, err))
1094+
1095+
def addFailure(self, test: unittest.TestCase, err):
1096+
self.actions.append(lambda: self.test_result.addFailure(test, err))
1097+
1098+
def addExpectedFailure(self, test: unittest.TestCase, err):
1099+
self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))
1100+
1101+
def addDuration(self, test: unittest.TestCase, elapsed):
1102+
self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
1103+
1104+
1105+
class JaxTestSuite(unittest.TestSuite):
1106+
"""Runs tests in parallel using threads if TEST_NUM_THREADS is > 1.
1107+
1108+
Caution: this test suite does not run setUpClass or setUpModule methods if
1109+
thread parallelism is enabled.
1110+
"""
1111+
1112+
def __init__(self, suite: unittest.TestSuite):
1113+
super().__init__(list(suite))
1114+
1115+
def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
1116+
if TEST_NUM_THREADS.value <= 0:
1117+
return super().run(result)
1118+
1119+
executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
1120+
lock = threading.Lock()
1121+
futures = []
1122+
1123+
def run_test(test):
1124+
"Recursively runs tests in a test suite or test case."
1125+
if isinstance(test, unittest.TestSuite):
1126+
for subtest in test:
1127+
run_test(subtest)
1128+
else:
1129+
test_result = ThreadSafeTestResult(lock, result)
1130+
futures.append(executor.submit(_run_one_test, test, test_result))
1131+
1132+
with executor:
1133+
run_test(self)
1134+
for future in futures:
1135+
future.result()
1136+
1137+
return result
1138+
10011139

10021140
class JaxTestLoader(absltest.TestLoader):
1141+
suiteClass = JaxTestSuite
1142+
10031143
def getTestCaseNames(self, testCaseClass):
10041144
names = super().getTestCaseNames(testCaseClass)
10051145
if _TEST_TARGETS.value:
@@ -1102,11 +1242,11 @@ def setUp(self):
11021242
stack.enter_context(global_config_context(**self._default_config))
11031243

11041244
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
1245+
assert TEST_NUM_THREADS.value <= 1, "Persistent compilation cache is not thread-safe."
11051246
stack.enter_context(config.enable_compilation_cache(True))
11061247
stack.enter_context(config.raise_persistent_cache_errors(True))
11071248
stack.enter_context(config.persistent_cache_min_compile_time_secs(0))
11081249
stack.enter_context(config.persistent_cache_min_entry_size_bytes(0))
1109-
11101250
tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
11111251
stack.enter_context(config.compilation_cache_dir(tmp_dir))
11121252
stack.callback(compilation_cache.reset_cache)

jax/_src/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,3 +685,6 @@ def test_event(name: str, *args) -> None:
685685
if not test_event_listener:
686686
return
687687
test_event_listener(name, *args)
688+
689+
if hasattr(jaxlib_utils, "Mutex"):
690+
Mutex = jaxlib_utils.Mutex

jaxlib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ nanobind_extension(
214214
deps = [
215215
"@com_google_absl//absl/cleanup",
216216
"@com_google_absl//absl/container:inlined_vector",
217+
"@com_google_absl//absl/synchronization",
217218
"@nanobind",
218219
"@xla//third_party/python_runtime:headers",
219220
],

jaxlib/utils.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include "nanobind/nanobind.h"
1919
#include "absl/cleanup/cleanup.h"
2020
#include "absl/container/inlined_vector.h"
21+
#include "absl/synchronization/mutex.h"
2122

2223
namespace nb = nanobind;
2324

@@ -225,4 +226,19 @@ NB_MODULE(utils, m) {
225226
PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr()));
226227
m.attr("safe_zip") = nb::steal<nb::object>(
227228
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
229+
230+
// Python has no reader-writer lock in its standard library, so we expose
231+
// bindings around absl::Mutex.
232+
nb::class_<absl::Mutex>(m, "Mutex")
233+
.def(nb::init<>())
234+
.def("lock", &absl::Mutex::Lock, nb::call_guard<nb::gil_scoped_release>())
235+
.def("unlock", &absl::Mutex::Unlock)
236+
.def("assert_held", &absl::Mutex::AssertHeld)
237+
.def("reader_lock", &absl::Mutex::ReaderLock,
238+
nb::call_guard<nb::gil_scoped_release>())
239+
.def("reader_unlock", &absl::Mutex::ReaderUnlock)
240+
.def("assert_reader_held", &absl::Mutex::AssertReaderHeld)
241+
.def("writer_lock", &absl::Mutex::WriterLock,
242+
nb::call_guard<nb::gil_scoped_release>())
243+
.def("writer_unlock", &absl::Mutex::WriterUnlock);
228244
}

0 commit comments

Comments
 (0)