Skip to content

Commit ea53c76

Browse files
ayaka14732 authored and Google-ML-Automation committed
Fix thread safety of JAX error checking
Fix thread safety of JAX error checking by making the global state thread-local.

PiperOrigin-RevId: 733164878
1 parent 00d9f45 commit ea53c76

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

jax/_src/error_check.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -41,16 +41,24 @@ class JaxValueError(ValueError):
4141
"""
4242

4343

44-
_error_code_ref: core.MutableArray | None = None
4544
_error_list_lock = threading.Lock()
4645
_error_list: list[tuple[str, Traceback]] = [] # (error_message, traceback) pair
4746

4847

48+
class _ErrorStorage(threading.local):
49+
50+
def __init__(self):
51+
self.ref: core.MutableArray | None = None
52+
53+
54+
_error_storage = _ErrorStorage()
55+
56+
4957
def _initialize_error_code_ref() -> None:
58+
"""Initialize error_code_ref in the current thread."""
5059
with core.eval_context():
51-
global _error_code_ref
5260
error_code = jnp.uint32(_NO_ERROR)
53-
_error_code_ref = core.mutable_array(error_code)
61+
_error_storage.ref = core.mutable_array(error_code)
5462

5563

5664
def set_error_if(pred: jax.Array, msg: str) -> None:
@@ -59,9 +67,9 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
5967
If the error is already set, the new error will be ignored. It will not
6068
override the existing error.
6169
"""
62-
if _error_code_ref is None:
70+
if _error_storage.ref is None:
6371
_initialize_error_code_ref()
64-
assert _error_code_ref is not None
72+
assert _error_storage.ref is not None
6573

6674
traceback = source_info_util.current().traceback
6775
assert traceback is not None
@@ -70,19 +78,19 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
7078
_error_list.append((msg, traceback))
7179

7280
pred = pred.any()
73-
error_code = _error_code_ref[...]
81+
error_code = _error_storage.ref[...]
7482
should_update = jnp.logical_and(pred, error_code == jnp.uint32(_NO_ERROR))
7583
error_code = jnp.where(should_update, new_error_code, error_code)
7684
# TODO(ayx): support vmap and shard_map.
77-
_error_code_ref[...] = error_code # pytype: disable=unsupported-operands
85+
_error_storage.ref[...] = error_code
7886

7987

8088
def raise_if_error() -> None:
8189
"""Raise error if an error is set."""
82-
if _error_code_ref is None: # if not initialized, do nothing
83-
return
90+
if _error_storage.ref is None:
91+
return # if not initialized, do nothing
8492

85-
error_code = _error_code_ref[...]
93+
error_code = _error_storage.ref[...]
8694
if error_code == jnp.uint32(_NO_ERROR):
8795
return
8896
try:
@@ -92,4 +100,4 @@ def raise_if_error() -> None:
92100
filtered_traceback = traceback_util.filter_traceback(traceback)
93101
raise exc.with_traceback(filtered_traceback)
94102
finally:
95-
_error_code_ref[...] = jnp.uint32(_NO_ERROR)
103+
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)

tests/error_check_test.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
config.parse_flags_with_absl()
2929

3030

31-
@jtu.thread_unsafe_test_class() # TODO(b/398772684): fix thread safety
3231
@jtu.with_config(jax_check_tracer_leaks=True)
3332
class ErrorCheckTests(jtu.JaxTestCase):
3433

0 commit comments

Comments
 (0)