```python
import jax
import chex

@chex.chexify
@jax.jit
def f(x):
    chex.assert_equal(x, 0)
    return

f(0)
```

This minimal example leads to the following error:
```
RuntimeError: Chex assertion detected `ConcretizationTypeError`: it is very likely that it tried to access tensors' values during tracing. Make sure that you defined a jittable version of this chex assertion; if that does not help, please file a bug.

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/franzsrambical/Documents/pdoom/Stoix/temp_chex_test.py", line 10, in <module>
    f(0)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_chexify.py", line 210, in _chexified_fn
    err, out = checkified_fn(*args, **kwargs)
  File "/home/franzsrambical/Documents/pdoom/Stoix/temp_chex_test.py", line 7, in f
    chex.assert_equal(x, 0)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_internal.py", line 290, in _chex_assert_fn
    raise exc from RuntimeError(msg)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_internal.py", line 278, in _chex_assert_fn
    host_assertion_fn(
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_internal.py", line 168, in _assert_on_host
    assert_fn(*args, **kwargs)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts.py", line 232, in assert_equal
    unittest.TestCase().assertEqual(first, second)
  File "/home/franzsrambical/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/unittest/case.py", line 845, in assertEqual
    assertion_func(first, second, msg=msg)
  File "/home/franzsrambical/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/unittest/case.py", line 835, in _baseAssertEqual
    if not first == second:
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /home/franzsrambical/Documents/pdoom/Stoix/temp_chex_test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
WARNING:absl:[Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
```
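The last frames point at the root cause: `chex.assert_equal` falls through to `unittest.TestCase().assertEqual`, which evaluates `if not first == second:`. Under `jax.jit`, `x == 0` is a traced `bool[]` array rather than a Python `bool`, so `not` cannot be evaluated during tracing. The sketch below reproduces just that step without chex; the function name `python_equality_branch` is hypothetical and only used for illustration.

```python
import jax


@jax.jit
def python_equality_branch(x):
    # Same pattern as unittest's `if not first == second:`. Under jit, `x == 0`
    # is a traced bool[] array, and `not` needs a concrete Python bool.
    if not x == 0:
        return x + 1
    return x


try:
    python_equality_branch(0)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

print(raised)  # True
```

This suggests that `assert_equal` is being executed as a host-side (non-jittable) assertion inside the jitted function, which matches the `_assert_on_host` frame in the traceback.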
chex version: 0.1.88
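For comparison, a minimal sketch of the same equality check written directly against `jax.experimental.checkify`, which the `err, out = checkified_fn(...)` frame suggests `chexify` builds on. This is an assumed workaround for checking the value inside `jit`, not a fix for `chex.assert_equal` itself; the error message text is illustrative.

```python
import jax
from jax.experimental import checkify


def f(x):
    # Trace-safe check: records a functional error instead of forcing a
    # Python bool out of the traced comparison.
    checkify.check(x == 0, "x must equal 0, got {x}", x=x)
    return x


# checkify.checkify makes f return (error, output); jit wraps the result.
checked_f = jax.jit(checkify.checkify(f))

err, out = checked_f(0)
err.throw()  # no-op because the check passed

err_bad, _ = checked_f(1)
print(err_bad.get())  # error message for the failed check (None if it passed)
```

Unlike the host-side path, the check here is staged into the compiled computation, and the error is only surfaced when `err.throw()` or `err.get()` is called on the returned value.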