
Allow for nested chex.chexify #306

@nicow-elia


Hello, I have a dilemma with `chex.chexify`. Consider the following code:

```python
import chex
import jax
import jax.numpy as jnp
import pytest


# Keeping this @chex.chexify makes test_combo_safe fail (nested chexify).
# Removing it makes test_log_safe fail (value assertion outside chexify).
@chex.chexify
@jax.jit
def log_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)


@chex.chexify
@jax.jit
def combo_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)


def test_log_safe() -> None:
    # -1.0 violates x > 0, so the deferred check should raise.
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()


def test_combo_safe() -> None:
    # 1.0 violates x != 1, so the deferred check should raise.
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()

    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()
```

If I comment out the first `@chex.chexify`, `test_log_safe` fails with `RuntimeError: Value assertions can only be called from functions wrapped with @chex.chexify. See the docs.`, which makes sense to me. However, once I add the decorator back, `test_combo_safe` fails with `RuntimeError: Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level.`

A hack in this simple scenario would be to keep two versions of the function: an undecorated `log_safe` and a `log_safe_test = chex.chexify(log_safe)`, and to call only `log_safe_test` in my tests. However, that solution is clumsy, especially when there are many such functions: in a codebase that is JAX end to end, every function except the outermost one would need this hack.

Would it be possible to allow nested `chex.chexify`, where inner applications of the decorator simply do nothing or just emit a warning?
