Skip to content

Commit e20523c

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name. * change `@jtu.with_config` to default to setting thread-local configurations. * add a `@jtu.with_global_config` for those things that truly need to be set globally. * add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options. * change the pretty printer color option to be a State so it can be set locally. * tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output. PiperOrigin-RevId: 713411171
1 parent c4ac0dd commit e20523c

File tree

9 files changed

+75
-28
lines changed

9 files changed

+75
-28
lines changed

jax/_src/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ def trace_context():
233233
class NoDefault: pass
234234
no_default = NoDefault()
235235

236+
config_states = {}
237+
236238
class State(config_ext.Config[_T]):
237239

238240
__slots__ = (
@@ -265,6 +267,7 @@ def __init__(
265267
self._validator(default)
266268
if self._update_global_hook:
267269
self._update_global_hook(default)
270+
config_states[name] = self
268271

269272
def __bool__(self) -> NoReturn:
270273
raise TypeError(

jax/_src/pretty_printer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
colorama = None
4343

4444

45-
_PPRINT_USE_COLOR = config.bool_flag(
45+
_PPRINT_USE_COLOR = config.bool_state(
4646
'jax_pprint_use_color',
47-
config.bool_env('JAX_PPRINT_USE_COLOR', True),
47+
True,
4848
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
4949
)
5050

jax/_src/test_util.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,10 @@ def _run_one_test(test: unittest.TestCase, result: ThreadSafeTestResult):
10231023
@contextmanager
10241024
def thread_hostile_test():
10251025
"Decorator for tests that are not thread-safe."
1026+
if TEST_NUM_THREADS.value <= 0:
1027+
yield
1028+
return
1029+
10261030
_test_rwlock.assert_reader_held()
10271031
_test_rwlock.reader_unlock()
10281032
_test_rwlock.writer_lock()
@@ -1157,10 +1161,21 @@ def with_config(**kwds):
11571161
"""Test case decorator for subclasses of JaxTestCase"""
11581162
def decorator(cls):
11591163
assert inspect.isclass(cls) and issubclass(cls, JaxTestCase), "@with_config can only wrap JaxTestCase class definitions."
1160-
cls._default_config = {}
1164+
cls._default_thread_local_config = {}
11611165
for b in cls.__bases__:
1162-
cls._default_config.update(b._default_config)
1163-
cls._default_config.update(kwds)
1166+
cls._default_thread_local_config.update(b._default_thread_local_config)
1167+
cls._default_thread_local_config.update(kwds)
1168+
return cls
1169+
return decorator
1170+
1171+
def with_global_config(**kwds):
1172+
"""Test case decorator for subclasses of JaxTestCase"""
1173+
def decorator(cls):
1174+
assert inspect.isclass(cls) and issubclass(cls, JaxTestCase), "@with_config can only wrap JaxTestCase class definitions."
1175+
cls._default_global_config = {}
1176+
for b in cls.__bases__:
1177+
cls._default_global_config.update(b._default_global_config)
1178+
cls._default_global_config.update(kwds)
11641179
return cls
11651180
return decorator
11661181

@@ -1191,6 +1206,15 @@ def global_config_context(**kwds):
11911206
for key, value in original_config.items():
11921207
config.update(key, value)
11931208

1209+
@contextmanager
1210+
def thread_local_config_context(**kwds):
1211+
stack = ExitStack()
1212+
for config_name, value in kwds.items():
1213+
stack.enter_context(config.config_states[config_name](value))
1214+
try:
1215+
yield
1216+
finally:
1217+
stack.close()
11941218

11951219
class NotPresent:
11961220
def __repr__(self):
@@ -1214,7 +1238,8 @@ def assert_global_configs_unchanged():
12141238

12151239
class JaxTestCase(parameterized.TestCase):
12161240
"""Base class for JAX tests including numerical checks and boilerplate."""
1217-
_default_config = {
1241+
_default_global_config: dict[str, Any] = {}
1242+
_default_thread_local_config = {
12181243
'jax_enable_checks': True,
12191244
'jax_numpy_dtype_promotion': 'strict',
12201245
'jax_numpy_rank_promotion': 'raise',
@@ -1239,7 +1264,9 @@ def setUp(self):
12391264
self._context_stack = ExitStack()
12401265
self.addCleanup(self._context_stack.close)
12411266
stack = self._context_stack
1242-
stack.enter_context(global_config_context(**self._default_config))
1267+
stack.enter_context(global_config_context(**self._default_global_config))
1268+
for config_name, value in self._default_thread_local_config.items():
1269+
stack.enter_context(jax._src.config.config_states[config_name](value))
12431270

12441271
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
12451272
assert TEST_NUM_THREADS.value <= 1, "Persistent compilation cache is not thread-safe."
@@ -1365,7 +1392,7 @@ def wrapped_fun(*args):
13651392

13661393
cache_misses = dispatch.xla_primitive_callable.cache_info().misses
13671394
python_ans = fun(*args)
1368-
if check_cache_misses:
1395+
if check_cache_misses and TEST_NUM_THREADS.value <= 1:
13691396
self.assertEqual(
13701397
cache_misses, dispatch.xla_primitive_callable.cache_info().misses,
13711398
"Compilation detected during second call of {} in op-by-op "

jax/experimental/jax2tf/tests/shape_poly_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2670,7 +2670,7 @@ def test_harness(self, harness: PolyHarness):
26702670
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
26712671
raise unittest.SkipTest("JAX implements eig only on CPU.")
26722672

2673-
with jtu.global_config_context(**harness.override_jax_config_flags):
2673+
with jtu.thread_local_config_context(**harness.override_jax_config_flags):
26742674
harness.run_test(self)
26752675

26762676

tests/api_test.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ def f(x):
632632
python_should_be_executing = False
633633
jit(f)(3)
634634

635+
@jtu.thread_hostile_test()
635636
def test_jit_cache_clear(self):
636637
@jit
637638
def f(x, y):
@@ -2591,6 +2592,7 @@ def test_block_until_ready_mixed(self):
25912592
self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False)
25922593
self.assertEqual(pytree[3], 4)
25932594

2595+
@jtu.thread_hostile_test()
25942596
def test_devicearray_weakref_friendly(self):
25952597
x = device_put(1.)
25962598
y = weakref.ref(x)
@@ -2739,6 +2741,7 @@ def f(x):
27392741

27402742
self.assertEqual(count(), 1)
27412743

2744+
@jtu.thread_hostile_test()
27422745
def test_jit_infer_params_cache(self):
27432746
def f(x):
27442747
return x
@@ -3329,6 +3332,7 @@ def test_grad_object_array_error(self):
33293332
with self.assertRaisesRegex(TypeError, ".*is not a valid JAX type"):
33303333
jax.grad(lambda x: x)(x)
33313334

3335+
@jtu.thread_hostile_test()
33323336
def test_jit_compilation_time_logging(self):
33333337
@api.jit
33343338
def f(x):
@@ -3417,6 +3421,7 @@ def test_trivial_computations(self):
34173421
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
34183422
self.assertEqual(z2, 1)
34193423

3424+
@jtu.thread_hostile_test()
34203425
def test_nested_jit_hoisting(self):
34213426
@api.jit
34223427
def f(x, y):
@@ -3454,6 +3459,7 @@ def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
34543459
self.assertEqual(inner_jaxpr.eqns[-2].primitive.name, 'mul')
34553460
self.assertEqual(inner_jaxpr.eqns[-1].primitive.name, 'add')
34563461

3462+
@jtu.thread_hostile_test()
34573463
def test_primitive_compilation_cache(self):
34583464
with jtu.count_primitive_compiles() as count:
34593465
lax.add(1, 2)
@@ -4013,13 +4019,17 @@ def __jax_array__(self):
40134019
a2 = jnp.array(((x, x), [x, x]))
40144020
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
40154021

4022+
@jtu.thread_hostile_test()
40164023
def test_eval_shape_weak_type(self):
40174024
# https://github.com/jax-ml/jax/issues/23302
40184025
arr = jax.numpy.array(1)
40194026

4027+
def f(x):
4028+
return jax.numpy.array(x)
4029+
40204030
with jtu.count_jit_tracing_cache_miss() as count:
4021-
jax.eval_shape(jax.numpy.array, 1)
4022-
out = jax.eval_shape(jax.numpy.array, 1)
4031+
jax.eval_shape(f, 1)
4032+
out = jax.eval_shape(f, 1)
40234033

40244034
self.assertEqual(count(), 1)
40254035
self.assertTrue(out.weak_type)
@@ -4138,6 +4148,7 @@ def test_dot_precision_flag(self):
41384148
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
41394149
self.assertIn('Precision.HIGH', str(jaxpr))
41404150

4151+
@jtu.thread_hostile_test()
41414152
def test_dot_precision_forces_retrace(self):
41424153
num_traces = 0
41434154

@@ -4310,6 +4321,7 @@ def test_jnp_array_doesnt_device_put(self):
43104321
api.make_jaxpr(lambda: jnp.array(3))()
43114322
self.assertEqual(count(), 0)
43124323

4324+
@jtu.thread_hostile_test()
43134325
def test_rank_promotion_forces_retrace(self):
43144326
num_traces = 0
43154327

@@ -4328,7 +4340,7 @@ def f_jit(x):
43284340

43294341
for f in [f_jit, f_cond]:
43304342
# Use _read() to read the flag value rather than threadlocal value.
4331-
allow_promotion = config._read("jax_numpy_rank_promotion")
4343+
allow_promotion = jax.numpy_rank_promotion.get_global()
43324344
try:
43334345
config.update("jax_numpy_rank_promotion", "allow")
43344346
num_traces = 0
@@ -4350,9 +4362,9 @@ def f(x):
43504362
self.assertGreaterEqual(num_traces, 2)
43514363
nt = num_traces
43524364
f(x)
4353-
self.assertEqual(num_traces, nt + 1)
4365+
self.assertEqual(num_traces, nt)
43544366
f(x)
4355-
self.assertEqual(num_traces, nt + 1)
4367+
self.assertEqual(num_traces, nt)
43564368
finally:
43574369
config.update("jax_numpy_rank_promotion", allow_promotion)
43584370

@@ -4450,6 +4462,7 @@ def foo(x, y, z):
44504462
self.assertEqual(jfoo.__qualname__, f"make_jaxpr({foo.__qualname__})")
44514463
self.assertEqual(jfoo.__module__, "jax")
44524464

4465+
@jtu.thread_hostile_test()
44534466
def test_inner_jit_function_retracing(self):
44544467
# https://github.com/jax-ml/jax/issues/7155
44554468
inner_count = outer_count = 0
@@ -4691,6 +4704,7 @@ def test_mesh_creation_error_message(self):
46914704
with self.assertRaisesRegex(ValueError, "ndim of its first argument"):
46924705
jax.sharding.Mesh(jax.devices(), ("x", "y"))
46934706

4707+
@jtu.thread_hostile_test()
46944708
def test_jit_boundmethod_reference_cycle(self):
46954709
class A:
46964710
def __init__(self):
@@ -4829,6 +4843,7 @@ class RematTest(jtu.JaxTestCase):
48294843
('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
48304844
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
48314845
])
4846+
@jtu.thread_hostile_test()
48324847
def test_remat_basic(self, remat):
48334848
@remat
48344849
def g(x):
@@ -5166,6 +5181,7 @@ def f_yesremat(x):
51665181
('_policy', partial(jax.remat, policy=lambda *_, **__: False)),
51675182
('_new', partial(new_checkpoint, policy=lambda *_, **__: False)),
51685183
])
5184+
@jtu.thread_hostile_test()
51695185
def test_remat_no_redundant_flops(self, remat):
51705186
# see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584
51715187

@@ -6409,6 +6425,7 @@ def f(x):
64096425
self.assertIn(' sin ', str(jaxpr))
64106426
self.assertIn(' cos ', str(jaxpr))
64116427

6428+
@jtu.thread_hostile_test()
64126429
def test_remat_residual_logging(self):
64136430
def f(x):
64146431
x = jnp.sin(x)
@@ -9626,11 +9643,8 @@ def foo_bwd(_, g):
96269643

96279644
foo.defvjp(foo_fwd, foo_bwd)
96289645

9629-
try:
9630-
jax.config.update('jax_custom_vjp_disable_shape_check', True)
9646+
with config.custom_vjp_disable_shape_check(True):
96319647
jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4))
9632-
finally:
9633-
jax.config.update('jax_custom_vjp_disable_shape_check', False)
96349648

96359649
def test_bwd_rule_can_produce_list_or_tuple(self):
96369650
@jax.custom_vjp
@@ -11114,6 +11128,8 @@ def test_autodidax_smoketest(self):
1111411128
spec.loader.exec_module(autodidax_module)
1111511129

1111611130
class GarbageCollectionTest(jtu.JaxTestCase):
11131+
11132+
@jtu.thread_hostile_test()
1111711133
def test_xla_gc_callback(self):
1111811134
# https://github.com/jax-ml/jax/issues/14882
1111911135
x_np = np.arange(10, dtype='int32')

tests/mock_gpu_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
NUM_SHARDS = 4
2929

3030

31-
@jtu.with_config(mock_num_gpu_processes=NUM_SHARDS)
31+
@jtu.with_global_config(mock_num_gpu_processes=NUM_SHARDS)
3232
class MockGPUTest(jtu.JaxTestCase):
3333

3434
def setUp(self):

tests/mock_gpu_topology_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
NUM_HOSTS_PER_SLICE = 4
2626

2727

28-
@jtu.with_config(
28+
@jtu.with_global_config(
2929
jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1",
3030
jax_cuda_visible_devices="0")
3131
class MockGPUTopologyTest(jtu.JaxTestCase):

tests/pmap_test.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,21 +2185,22 @@ def test_cache_hits_across_threads(self):
21852185
f = lambda x: x+1
21862186
inputs = np.zeros([jax.device_count()], dtype=np.float32)
21872187
pmaped_f = self.pmap(f)
2188-
pmaped_f(inputs)
2189-
self.assertEqual(pmaped_f._cache_size, 1)
2188+
self.assertEqual(pmaped_f._cache_size, 0)
21902189

2191-
# Note: We do not call jax.pmap in the other thread but we reuse the same
2192-
# object.
2190+
# We only call pmaped_f in the thread pool to make sure that any
2191+
# thread-local config settings are identical.
21932192
futures = []
2194-
with ThreadPoolExecutor(max_workers=1) as executor:
2195-
futures.append(executor.submit(lambda: pmaped_f(inputs)))
2193+
with ThreadPoolExecutor(max_workers=2) as executor:
2194+
for _ in range(8):
2195+
futures.append(executor.submit(lambda: pmaped_f(inputs)))
21962196
outputs = [f.result() for f in futures]
21972197

2198-
np.testing.assert_array_equal(pmaped_f(inputs), outputs[0])
21992198
if pmaped_f._cache_size != 1:
22002199
print(pmaped_f._debug_cache_keys())
22012200
self.assertEqual(pmaped_f._cache_size, 1)
22022201

2202+
np.testing.assert_array_equal(pmaped_f(inputs), outputs[0])
2203+
22032204
def test_cache_uses_jax_key(self):
22042205
f = lambda x: x+1
22052206
inputs = np.zeros([jax.device_count()], dtype=np.float32)

tests/shape_poly_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3683,7 +3683,7 @@ def test_harness(self, harness: PolyHarness):
36833683
if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]):
36843684
harness.tol = 5e-5
36853685

3686-
with jtu.global_config_context(**config_flags):
3686+
with jtu.thread_local_config_context(**config_flags):
36873687
harness.run_test(self)
36883688

36893689

0 commit comments

Comments
 (0)