nested control flow #12970
yiminghwang asked this question in Q&A
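The answer depends on whether the conditions are static or dynamic. For example, assuming `op` is a static argument, you can write ordinary Python `if`/`elif`/`else` inside the jitted function:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['op'])
def f(x, y, op=3):
    # op is static, so this Python branch is resolved at trace time.
    if op == 2:
        return x * jnp.where(y == 0, 1j, -1j)
    elif op == 3:
        return x * jnp.where(y == 1, -1, 1)
    else:
        raise NotImplementedError(f"{op=} not supported")
```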
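For the dynamic case, where `op` is a traced value rather than a static argument, a Python `if` on it fails inside `jit`. A minimal sketch, assuming a hypothetical `f_dynamic` with the same two branches as `f`, dispatches with `jax.lax.switch` instead. Since `lax.switch` requires every branch to return the same shape and dtype, both branches are cast to `complex64`. An unsupported `op` cannot raise at run time, so in this sketch it simply falls through to the second branch (the index mapping is an assumption, not part of the original answer):

```python
import jax
import jax.numpy as jnp

def _op2(x, y):
    # Same expression as the `op == 2` branch of f, cast to a fixed dtype.
    return (x * jnp.where(y == 0, 1j, -1j)).astype(jnp.complex64)

def _op3(x, y):
    # Same expression as the `op == 3` branch of f; the cast makes both
    # branches return complex64, as lax.switch requires matching outputs.
    return (x * jnp.where(y == 1, -1, 1)).astype(jnp.complex64)

@jax.jit
def f_dynamic(x, y, op):
    # `op` is traced here, so the branch is chosen at run time.
    # Hypothetical mapping: op == 2 -> branch 0, any other value -> branch 1
    # (an unsupported op cannot raise NotImplementedError inside jit).
    index = jnp.where(op == 2, 0, 1)
    return jax.lax.switch(index, [_op2, _op3], x, y)

x = jnp.array([1.0, 2.0])
y = jnp.array([0, 1])
print(f_dynamic(x, y, 2))  # values 1j and -2j
print(f_dynamic(x, y, 3))  # values 1 and -2, stored as complex
```

Because `op` is traced, calls with `op=2` and `op=3` reuse one compiled function, whereas the static version compiles once per distinct `op` value.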
Hi there,
In order to use `jit`, I have to use `jax.lax.cond` to implement the control flow. Is there an elegant way to write nested `if` statements?
For example, the original version:
and the version using `lax.cond`:
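As a hypothetical illustration of the pattern being asked about (the function names and thresholds are invented here, not taken from the question), a nested `if`/`elif`/`else` on a traced scalar maps onto `jax.lax.cond` by placing the inner `cond` in the false branch of the outer one. For elementwise array code, `jnp.select` expresses the same chain in a single call, at the cost of evaluating every branch:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical nested condition on a traced value x:
#   if x < 0:    result = -x
#   elif x < 1:  result = x ** 2
#   else:        result = 2 * x

@jax.jit
def nested_cond(x):
    # The inner lax.cond is simply the false branch of the outer one.
    return lax.cond(
        x < 0,
        lambda v: -v,
        lambda v: lax.cond(v < 1, lambda w: w ** 2, lambda w: 2 * w, v),
        x,
    )

@jax.jit
def with_select(x):
    # jnp.select picks the first true condition, like if/elif/else, and
    # works elementwise on arrays; all three branch values are computed.
    return jnp.select([x < 0, x < 1], [-x, x ** 2], default=2 * x)

print(nested_cond(-2.0), nested_cond(0.5), nested_cond(3.0))
print(with_select(jnp.array([-2.0, 0.5, 3.0])))
```

With more than two or three levels, computing an integer branch index and calling `jax.lax.switch` once is another way to keep the nesting flat.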