I think this will take a bit more work to tackle in general in JAX. The checkify check was presumably there for a good reason: simply disabling it might mean that, e.g., that branch ends up in an infinite loop.

So in general I think you probably want to do something like:

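import jax.numpy as jnp
from jax import lax

def f_true(x, _):
    # Branch for pred == True; ignores its second argument.
    ...

def f_false(_, x):
    # Branch for pred == False; ignores its first argument.
    ...

pred = x == 0
# Each branch only sees the real x where it would actually be taken;
# everywhere else it gets a value known to be harmless for that branch.
safe_true_x = jnp.where(pred, x, some_safe_value)
safe_false_x = jnp.where(pred, some_other_safe_value, x)
lax.cond(pred, f_true, f_false, safe_true_x, safe_false_x)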

where some_safe_value and some_other_safe_value are values you know will pass any checkify checks in f_true and f_false respectively, won't get caught in an infinite loop, etc.
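A minimal runnable sketch of the pattern above, under stated assumptions: the branch bodies are hypothetical (f_false computes 1 / x and guards it with checkify.check), the safe fillers 0.0 and 1.0 are illustrative choices, and the computation is batched with jax.vmap. Under vmap with a batched predicate, lax.cond evaluates both branches on every element, which is one common way an untaken branch ends up seeing a "bad" input.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import checkify


def f_true(x, _):
    # Hypothetical branch for x == 0: valid for any input.
    return jnp.zeros_like(x)


def f_false(_, x):
    # Hypothetical branch that is only valid for x != 0.
    checkify.check(x != 0, "f_false received x == 0")
    return 1.0 / x


def safe_cond(x):
    pred = x == 0
    # Assumption: f_true accepts anything, so 0.0 is a safe filler for it.
    safe_true_x = jnp.where(pred, x, 0.0)
    # Assumption: 1.0 passes f_false's check, so it is a safe filler for it.
    safe_false_x = jnp.where(pred, 1.0, x)
    return lax.cond(pred, f_true, f_false, safe_true_x, safe_false_x)


# With a batched predicate, vmap turns lax.cond into a select over both
# branches, so f_false also runs at the position where x == 0 (on 1.0 here).
checked = checkify.checkify(jax.vmap(safe_cond))
err, out = checked(jnp.array([0.0, 2.0, -4.0]))
err.throw()  # raises if any checkify check failed
print(out)
```

If the raw x were passed to both branches instead, f_false would be evaluated on 0 at the first position, which is the kind of spurious check failure the safe values are meant to avoid.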

Side note: you may like error_if as an …

Replies: 1 comment, 3 replies (from @SobhanMP, @patrick-kidger, @SobhanMP)

Answer selected by SobhanMP
Category: Q&A
2 participants