Skip to content

Commit 988a120

Browse files
ayaka14732Google-ML-Automation
authored andcommitted
Better error message when raise_if_error() is called within a traced context
PiperOrigin-RevId: 735557928
1 parent aceae84 commit 988a120

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

jax/_src/error_check.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,17 @@ def raise_if_error() -> None:
8888
"""Raise error if an error is set.
8989
9090
This function should be called after the computation is finished. It should
91-
be used outside jit.
91+
not be called within a traced context, such as within a jitted function."
9292
"""
9393
if _error_storage.ref is None: # if not initialized, do nothing
9494
return
9595

9696
error_code = _error_storage.ref[...]
97+
if isinstance(error_code, core.Tracer):
98+
raise ValueError(
99+
"raise_if_error() should not be called within a traced context, such as"
100+
" within a jitted function."
101+
)
97102
if error_code == jnp.uint32(_NO_ERROR):
98103
return
99104
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)

tests/error_check_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,26 @@ def body(init, xs):
170170
_ = body(init, xs)
171171
error_check.raise_if_error() # should not raise error
172172

173+
@parameterized.product(jit=[True, False])
174+
def test_raise_if_error_fails_in_traced_context(self, jit):
175+
def f(x):
176+
error_check.set_error_if(x <= 0, "x must be greater than 0")
177+
return x + 1
178+
179+
if jit:
180+
f = jax.jit(f)
181+
182+
x = jnp.full((4,), 1, dtype=jnp.int32)
183+
f(x)
184+
with self.assertRaises(
185+
ValueError,
186+
msg=(
187+
"raise_if_error() should not be called within a traced context,"
188+
" such as within a jitted function."
189+
),
190+
):
191+
jax.jit(error_check.raise_if_error)()
192+
173193

174194
if __name__ == "__main__":
175195
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)