@@ -41,16 +41,24 @@ class JaxValueError(ValueError):
4141"""
4242
4343
44- _error_code_ref : core .MutableArray | None = None
4544_error_list_lock = threading .Lock ()
4645_error_list : list [tuple [str , Traceback ]] = [] # (error_message, traceback) pair
4746
4847
48+ class _ErrorStorage (threading .local ):
49+
50+ def __init__ (self ):
51+ self .ref : core .MutableArray | None = None
52+
53+
54+ _error_storage = _ErrorStorage ()
55+
56+
4957def _initialize_error_code_ref () -> None :
58+ """Initialize error_code_ref in the current thread."""
5059 with core .eval_context ():
51- global _error_code_ref
5260 error_code = jnp .uint32 (_NO_ERROR )
53- _error_code_ref = core .mutable_array (error_code )
61+ _error_storage . ref = core .mutable_array (error_code )
5462
5563
5664def set_error_if (pred : jax .Array , msg : str ) -> None :
@@ -59,9 +67,9 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
5967 If the error is already set, the new error will be ignored. It will not
6068 override the existing error.
6169 """
62- if _error_code_ref is None :
70+ if _error_storage . ref is None :
6371 _initialize_error_code_ref ()
64- assert _error_code_ref is not None
72+ assert _error_storage . ref is not None
6573
6674 traceback = source_info_util .current ().traceback
6775 assert traceback is not None
@@ -70,19 +78,19 @@ def set_error_if(pred: jax.Array, msg: str) -> None:
7078 _error_list .append ((msg , traceback ))
7179
7280 pred = pred .any ()
73- error_code = _error_code_ref [...]
81+ error_code = _error_storage . ref [...]
7482 should_update = jnp .logical_and (pred , error_code == jnp .uint32 (_NO_ERROR ))
7583 error_code = jnp .where (should_update , new_error_code , error_code )
7684 # TODO(ayx): support vmap and shard_map.
77- _error_code_ref [...] = error_code # pytype: disable=unsupported-operands
85+ _error_storage . ref [...] = error_code
7886
7987
8088def raise_if_error () -> None :
8189 """Raise error if an error is set."""
82- if _error_code_ref is None : # if not initialized, do nothing
83- return
90+ if _error_storage . ref is None :
91+ return # if not initialized, do nothing
8492
85- error_code = _error_code_ref [...]
93+ error_code = _error_storage . ref [...]
8694 if error_code == jnp .uint32 (_NO_ERROR ):
8795 return
8896 try :
@@ -92,4 +100,4 @@ def raise_if_error() -> None:
92100 filtered_traceback = traceback_util .filter_traceback (traceback )
93101 raise exc .with_traceback (filtered_traceback )
94102 finally :
95- _error_code_ref [...] = jnp .uint32 (_NO_ERROR )
103+ _error_storage . ref [...] = jnp .uint32 (_NO_ERROR )
0 commit comments