Checkify in vmapped lax.cond/lax.select #19476
I've been trying to debug some code with Checkify, and it is great, but what would be the "idiomatic" way of dealing with code like this?

```python
import os
from typing import Any
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # run on CPU only
import jax
import jax.numpy as jnp
from jax import vmap, jit, lax
from jax.experimental import checkify

def f(x, y):
    def f_true(x):
        checkify.check(y == 0, 'y is {} not 0', y)
        return x * 2

    def f_false(x):
        checkify.check(y != 0, 'y is not zero but {}', y)
        return x * 3

    # y is captured by closure, so only x is passed as a branch operand.
    return lax.cond(x == 0, f_true, f_false, x)

f = checkify.checkify(jit(vmap(f)), errors=checkify.all_checks)
err, x = f(jnp.array([0]), jnp.array([0]))
err.throw()
```

Since the …

Is there a way of fixing this inside JAX? Or is the solution just to write the checks like `checkify.check(jnp.logical_or(x != 0, y == 0), 'y is {} not 0', y)`?
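Below is a minimal runnable sketch of the `jnp.logical_or` workaround proposed at the end of the question (it is not code from the reply). It assumes JAX's documented behavior that when `lax.cond` is `vmap`ped with a batched predicate it is converted to a select that evaluates both branches, so the check in the untaken branch runs too. Guarding each check with its branch condition makes it pass trivially for elements that do not take that branch. The names `f` and `checked_f` are only for illustration.

```python
import jax.numpy as jnp
from jax import jit, lax, vmap
from jax.experimental import checkify


def f(x, y):
    def f_true(x):
        # Only enforce the check for elements that actually take this branch
        # (x == 0); elsewhere the predicate is vacuously true.
        checkify.check(jnp.logical_or(x != 0, y == 0), 'y is {} not 0', y)
        return x * 2

    def f_false(x):
        # Mirror image: vacuously true wherever x == 0 (the other branch).
        checkify.check(jnp.logical_or(x == 0, y != 0), 'y is not zero but {}', y)
        return x * 3

    return lax.cond(x == 0, f_true, f_false, x)


checked_f = checkify.checkify(jit(vmap(f)), errors=checkify.all_checks)

# Consistent inputs: no check should fail.
err, out = checked_f(jnp.array([0, 1]), jnp.array([0, 5]))
print(err.get(), out)

# x == 0 but y != 0: the true-branch check should fail.
err, out = checked_f(jnp.array([0]), jnp.array([7]))
print(err.get())
```

The guard keeps the checks correct, but both branches still execute on every element, so the untaken branch still sees the real (possibly "bad") inputs.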
I think this will take a bit more work to tackle in general in JAX. The checkify was presumably there for good reason: simply disabling it might mean that, e.g., that branch results in an infinite loop.

So in general I think you probably want to do something like

```python
def f_true(x, _):
    ...  # uses only the first operand

def f_false(_, x):
    ...  # uses only the second operand

pred = x == 0
safe_true_x = jnp.where(pred, x, some_safe_value)
safe_false_x = jnp.where(pred, some_other_safe_value, x)
lax.cond(pred, f_true, f_false, safe_true_x, safe_false_x)
```

where `some_safe_value` and `some_other_safe_value` are something you know will pass any checkifys in `f_true` and `f_false`, not get caught in any infinite loops, etc.

Side note, you may like `error_if` as an alternative to checkify.
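To make the reply's sketch concrete, here is one way it might be applied to the question's example. This is an illustrative adaptation, not code from the thread: since the branch checks depend on `y` rather than `x`, `y` is the operand that gets a per-branch "safe" copy, and the safe values `0` (for `f_true`) and `1` (for `f_false`) are assumptions chosen only because they satisfy the respective checks. Each branch reads only its own copy, so for any element the branch that is not selected sees a value that cannot trip its check, even when `vmap` makes both branches run.

```python
import jax.numpy as jnp
from jax import jit, lax, vmap
from jax.experimental import checkify


def f(x, y):
    # Each branch reads only its own copy of y; the other copy is ignored.
    def f_true(x, y_true, _):
        checkify.check(y_true == 0, 'y is {} not 0', y_true)
        return x * 2

    def f_false(x, _, y_false):
        checkify.check(y_false != 0, 'y is not zero but {}', y_false)
        return x * 3

    pred = x == 0
    # Assumed safe values: 0 always passes f_true's check and 1 always
    # passes f_false's check. The real y only reaches the branch that is
    # actually selected for each element.
    y_true = jnp.where(pred, y, 0)
    y_false = jnp.where(pred, 1, y)
    return lax.cond(pred, f_true, f_false, x, y_true, y_false)


checked_f = checkify.checkify(jit(vmap(f)), errors=checkify.all_checks)

# Consistent inputs: no check should fail.
err, out = checked_f(jnp.array([0, 1]), jnp.array([0, 5]))
print(err.get(), out)

# x == 0 but y != 0: only the true-branch check should fail.
err, out = checked_f(jnp.array([0]), jnp.array([7]))
print(err.get())
```

Compared with guarding each check with `jnp.logical_or`, this leaves the checks themselves unchanged and also stops the untaken branch from doing anything unsafe with a bad input (the infinite-loop concern above), at the cost of picking suitable safe values by hand.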