
jitted, chexified chex.assert_equal fails for value assertions #387

@emergenz

Description

import jax
import chex

@chex.chexify
@jax.jit
def f(x):
    chex.assert_equal(x, 0)
    return

f(0)

This minimal example leads to the following error:

RuntimeError: Chex assertion detected `ConcretizationTypeError`: it is very likely that it tried to access tensors' values during tracing. Make sure that you defined a jittable version of this chex assertion; if that does not help, please file a bug.

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/franzsrambical/Documents/pdoom/Stoix/temp_chex_test.py", line 10, in <module>
    f(0)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_chexify.py", line 210, in _chexified_fn
    err, out = checkified_fn(*args, **kwargs)
  File "/home/franzsrambical/Documents/pdoom/Stoix/temp_chex_test.py", line 7, in f
    chex.assert_equal(x, 0)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_internal.py", line 290, in _chex_assert_fn
    raise exc from RuntimeError(msg)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_internal.py", line 278, in _chex_assert_fn
    host_assertion_fn(
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts_internal.py", line 168, in _assert_on_host
    assert_fn(*args, **kwargs)
  File "/home/franzsrambical/Documents/pdoom/Stoix/.venv/lib/python3.10/site-packages/chex/_src/asserts.py", line 232, in assert_equal
    unittest.TestCase().assertEqual(first, second)
  File "/home/franzsrambical/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/unittest/case.py", line 845, in assertEqual
    assertion_func(first, second, msg=msg)
  File "/home/franzsrambical/.local/share/uv/python/cpython-3.10.16-linux-x86_64-gnu/lib/python3.10/unittest/case.py", line 835, in _baseAssertEqual
    if not first == second:
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /home/franzsrambical/Documents/pdoom/Stoix/temp_chex_test.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
WARNING:absl:[Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
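The traceback bottoms out in `unittest.TestCase().assertEqual`, whose `if not first == second` needs a concrete Python bool. Under `jax.jit`, `x` is a tracer, so that conversion is rejected at trace time. Below is a minimal pure-JAX sketch of the same failure, next to `jax.experimental.checkify.check`, which records the predicate inside the traced computation instead. The helpers `host_style_equal`, `staged_equal`, `fails_under_jit` and `checked_equal` are hypothetical names for illustration; the sketch does not reproduce chex's internals, only the underlying JAX behaviour visible in the traceback (chexify itself goes through a `checkified_fn`).

```python
import jax
from jax.experimental import checkify


def host_style_equal(x):
    # Hypothetical helper mirroring unittest's `if not first == second`:
    # `not` forces a Python bool out of the traced comparison.
    if not x == 0:
        raise AssertionError(f"{x} != 0")
    return x


def staged_equal(x):
    # Hypothetical helper: checkify.check records the predicate as part of the
    # traced computation instead of converting it to a Python bool.
    checkify.check(x == 0, "x must equal 0")
    return x


def fails_under_jit():
    """Returns True if jitting host_style_equal hits TracerBoolConversionError."""
    try:
        jax.jit(host_style_equal)(0)
    except jax.errors.TracerBoolConversionError:
        return True
    return False


# checkify turns the staged check into an explicit error value returned by f.
checked_equal = jax.jit(checkify.checkify(staged_equal))

print(fails_under_jit())      # True
err, _ = checked_equal(0)
print(err.get())              # None
err, _ = checked_equal(1)
print(err.get() is not None)  # True
```

The host-style version can only fail during tracing, because there is no concrete value yet; the staged version defers the comparison to run time and reports it as data.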

chex version: 0.1.88
