Is it possible to return early with same-shape values in jitted functions with branches? #12822
-
Let's say there is some function we want to jit, and it has 3 locations where it can return; it returns early depending on abstract (traced) bools in order to save compute. Each return value has the same shape, so is it somehow possible to write early returns in an easy manner, without a ton of `jax.lax.cond` calls nested in each other? E.g. here's a random function:
Is there an easy way to write the above to make it jit compatible?
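For illustration, here is a hypothetical function with this structure. The name `early_return` and its computations are invented for this sketch and are not the poster's function. It runs eagerly, but tracing it with `jax.jit` raises a concretization error, because a Python `if` needs a concrete boolean while the conditions are traced values inside `jit`.

```python
import jax
import jax.numpy as jnp

def early_return(x):
    # Hypothetical example: three possible exits, each returning an array shaped like x.
    a = jnp.sin(x)
    if a.sum() > 0:      # Python `if` calls bool() on the comparison result
        return a
    b = a * 2.0
    if b.mean() < -1:
        return b
    c = jnp.cos(b)
    if c.max() > 0.5:
        return c
    # If no condition holds, Python implicitly returns None.

# Eager (un-jitted) execution works because the comparisons are concrete.
eager_out = early_return(jnp.ones(3))
print(eager_out.shape)  # (3,)

# Under jit the comparisons are tracers, so the first `if` raises.
try:
    jax.jit(early_return)(jnp.ones(3))
    jit_ok = True
except jax.errors.ConcretizationTypeError as err:
    jit_ok = False
    print("jit failed with", type(err).__name__)
```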
Replies: 1 comment
-
As written, no, this function is not JIT compatible: if none of the conditions evaluates to true, the return value is `None`, and `None` is not a valid JAX value.
If you had some fallback return value of the same shape, then one way to make this JIT compatible would be to replace the Python `if` statements, which cannot branch on traced values, with nested calls to `lax.cond`.
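A minimal sketch of the nested `lax.cond` approach, assuming a hypothetical three-exit function (the stages `sin`, `* 2.0`, `cos` and the name `early_return` are invented here) and assuming `jnp.zeros_like` as the fallback when no condition holds. Each "remaining work" step becomes the false branch of the previous `cond`, and every branch returns an array with the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def early_return(x):
    # Hypothetical stages; each one may "return early" if its condition holds.
    a = jnp.sin(x)

    def after_a(a):
        b = a * 2.0

        def after_b(b):
            c = jnp.cos(b)
            # Last check: return c, or an assumed fallback of zeros with c's shape/dtype.
            return lax.cond(c.max() > 0.5, lambda c: c, jnp.zeros_like, c)

        # Return b early, otherwise continue with the remaining stage.
        return lax.cond(b.mean() < -1.0, lambda b: b, after_b, b)

    # Return a early, otherwise continue with the remaining stages.
    return lax.cond(a.sum() > 0, lambda a: a, after_a, a)

print(early_return(jnp.ones(3)))  # sin(1) > 0, so the first branch (sin(x)) is returned
```

The nesting depth grows with the number of exits, but each level stays local to one condition. When the function is not batched, `lax.cond` under `jit` executes only the selected branch, so the skipped stages are not computed; under `jax.vmap` with a batched predicate, `cond` is lowered to a select that evaluates both branches, so the compute savings do not carry over.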