|
17 | 17 |
|
18 | 18 | import collections |
19 | 19 | from collections.abc import Callable, Generator, Iterable, Sequence |
| 20 | +from concurrent.futures import ThreadPoolExecutor |
20 | 21 | from contextlib import ExitStack, contextmanager |
21 | 22 | import datetime |
22 | 23 | import functools |
|
31 | 32 | import tempfile |
32 | 33 | import textwrap |
33 | 34 | import threading |
| 35 | +import time |
34 | 36 | from typing import Any, TextIO |
35 | 37 | import unittest |
36 | 38 | import warnings |
|
115 | 117 | 'deterministic, interactive'), |
116 | 118 | ) |
117 | 119 |
|
# Flag selecting how many threads the test runner uses; it defaults to 0.
TEST_NUM_THREADS = config.int_flag(
  'jax_test_num_threads',
  0,
  help=(
    'Number of threads to use for running tests. '
    '0 means run everything in the main thread. '
    'Using > 1 thread is experimental.'
  ),
)
| 125 | + |
118 | 126 | # We sanitize test names to ensure they work with "unittest -k" and |
119 | 127 | # "pytest -k" test filtering. pytest accepts '[' and ']' but unittest -k |
120 | 128 | # does not. We replace sequences of problematic characters with a single '_'. |
@@ -998,8 +1006,140 @@ def sample_product(*args, **kw): |
998 | 1006 | """ |
999 | 1007 | return parameterized.parameters(*sample_product_testcases(*args, **kw)) |
1000 | 1008 |
|
| 1009 | +# We use a reader-writer lock to protect test execution. Tests that may run in |
| 1010 | +# parallel acquire a read lock; tests that are not thread-safe acquire a write |
| 1011 | +# lock. |
if hasattr(util, 'Mutex'):
  # Reader-writer lock shared by every test executed through _run_one_test.
  _test_rwlock = util.Mutex()

  def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
    """Runs `test`, reporting into `result`, while holding the reader lock."""
    _test_rwlock.reader_lock()
    try:
      test(result)  # type: ignore
    finally:
      _test_rwlock.reader_unlock()


  @contextmanager
  def thread_hostile_test():
    """Decorator for tests that are not thread-safe.

    Must be entered while the calling thread holds the reader lock (this is
    asserted). The reader lock is swapped for the writer lock for the
    duration of the body, and the reader lock is re-acquired on exit.
    """
    _test_rwlock.assert_reader_held()
    # Drop the read lock first, then take the write lock.
    _test_rwlock.reader_unlock()
    _test_rwlock.writer_lock()
    try:
      yield
    finally:
      # Restore the state the caller had: reader lock held, writer released.
      _test_rwlock.writer_unlock()
      _test_rwlock.reader_lock()
else:
  # TODO(phawkins): remove this branch when jaxlib 0.5.0 is the minimum.
  _test_rwlock = threading.Lock()

  def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
    """Runs `test`, reporting into `result`, under the plain mutex."""
    with _test_rwlock:
      test(result)  # type: ignore


  @contextmanager
  def thread_hostile_test():
    """No-op stand-in used when no reader-writer lock is available."""
    yield  # No reader-writer lock, so we get no parallelism.
| 1049 | + |
class ThreadSafeTestResult:
  """
  Wraps a TestResult to make it thread safe.

  We do this by accumulating API calls and applying them in a batch under a
  lock at the conclusion of each test case.

  We duck type instead of inheriting from TestResult because we aren't actually
  a perfect implementation of TestResult, and would rather get a loud error
  for things we haven't implemented.

  Args:
    lock: lock held while the buffered calls are replayed on `result`.
    result: the underlying result object that receives the replayed calls.
  """
  def __init__(self, lock: threading.Lock, result: unittest.TestResult):
    self.lock = lock
    self.test_result = result
    # Deferred calls against `test_result`, replayed in order by stopTest().
    self.actions: list[Callable] = []

  def startTest(self, test: unittest.TestCase):
    # Only record the wall-clock start; the wrapped result's startTest() is
    # invoked later, from stopTest(), under the lock.
    del test
    self.start_time = time.time()

  def stopTest(self, test: unittest.TestCase):
    """Replays the buffered start/outcome/stop calls on the wrapped result."""
    stop_time = time.time()
    with self.lock:
      # We assume test_result is an ABSL _TextAndXMLTestResult, so we can
      # override how it gets the time.
      time_getter = self.test_result.time_getter
      try:
        self.test_result.time_getter = lambda: self.start_time
        self.test_result.startTest(test)
        for callback in self.actions:
          callback()
        self.test_result.time_getter = lambda: stop_time
        self.test_result.stopTest(test)
      finally:
        # Always put back the original clock, even if a replayed call raised.
        self.test_result.time_getter = time_getter

  def addSuccess(self, test: unittest.TestCase):
    self.actions.append(lambda: self.test_result.addSuccess(test))

  def addSkip(self, test: unittest.TestCase, reason: str):
    self.actions.append(lambda: self.test_result.addSkip(test, reason))

  def addError(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addError(test, err))

  def addFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addFailure(test, err))

  def addExpectedFailure(self, test: unittest.TestCase, err):
    self.actions.append(lambda: self.test_result.addExpectedFailure(test, err))

  def addUnexpectedSuccess(self, test: unittest.TestCase):
    # Without this method unittest does not fail loudly: it emits a
    # RuntimeWarning and reports the test through addFailure() instead.
    self.actions.append(lambda: self.test_result.addUnexpectedSuccess(test))

  def addDuration(self, test: unittest.TestCase, elapsed):
    self.actions.append(lambda: self.test_result.addDuration(test, elapsed))
| 1103 | + |
| 1104 | + |
class JaxTestSuite(unittest.TestSuite):
  """Runs tests on a thread pool if TEST_NUM_THREADS is > 0.

  With a non-positive TEST_NUM_THREADS the suite defers to the standard
  unittest.TestSuite.run. Any positive value, including 1, dispatches every
  test case to a ThreadPoolExecutor with that many workers.

  Caution: this test suite does not run setUpClass or setUpModule methods if
  the thread pool is used.
  """

  def __init__(self, suite: unittest.TestSuite):
    super().__init__(list(suite))

  def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.TestResult:
    """Runs every test in the suite, reporting outcomes into `result`.

    Args:
      result: the result object that receives test outcomes.
      debug: passed through to unittest.TestSuite.run on the main-thread path;
        not used on the thread pool path.

    Returns:
      The `result` object that was passed in.
    """
    if TEST_NUM_THREADS.value <= 0:
      # Preserve the caller's debug setting rather than silently dropping it.
      return super().run(result, debug=debug)

    executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
    # One lock shared by all per-test wrappers guards writes to `result`.
    lock = threading.Lock()
    futures = []

    def run_test(test):
      "Recursively runs tests in a test suite or test case."
      if isinstance(test, unittest.TestSuite):
        for subtest in test:
          run_test(subtest)
      else:
        # Each test case gets its own buffering wrapper around `result`.
        test_result = ThreadSafeTestResult(lock, result)
        futures.append(executor.submit(_run_one_test, test, test_result))

    with executor:
      run_test(self)
      # Surface any exception that escaped a worker.
      for future in futures:
        future.result()

    return result
| 1138 | + |
1001 | 1139 |
|
1002 | 1140 | class JaxTestLoader(absltest.TestLoader): |
| 1141 | + suiteClass = JaxTestSuite |
| 1142 | + |
1003 | 1143 | def getTestCaseNames(self, testCaseClass): |
1004 | 1144 | names = super().getTestCaseNames(testCaseClass) |
1005 | 1145 | if _TEST_TARGETS.value: |
@@ -1102,11 +1242,11 @@ def setUp(self): |
1102 | 1242 | stack.enter_context(global_config_context(**self._default_config)) |
1103 | 1243 |
|
1104 | 1244 | if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: |
| 1245 | + assert TEST_NUM_THREADS.value <= 1, "Persistent compilation cache is not thread-safe." |
1105 | 1246 | stack.enter_context(config.enable_compilation_cache(True)) |
1106 | 1247 | stack.enter_context(config.raise_persistent_cache_errors(True)) |
1107 | 1248 | stack.enter_context(config.persistent_cache_min_compile_time_secs(0)) |
1108 | 1249 | stack.enter_context(config.persistent_cache_min_entry_size_bytes(0)) |
1109 | | - |
1110 | 1250 | tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) |
1111 | 1251 | stack.enter_context(config.compilation_cache_dir(tmp_dir)) |
1112 | 1252 | stack.callback(compilation_cache.reset_cache) |
|
0 commit comments