I'm trying to use `lax.select` to avoid unnecessary executions of `lax.while_loop`:

```python
import jax
import jax.numpy as jnp

case1 = 1 == 0
u0 = 1
u1 = 0 * u0
ii = 0

def raise_error(n):
    raise RuntimeError(f'Failed to find a solution in {n} loops')

def cond(state):
    u0, u1, ii = state
    return (jnp.abs(u1 - u0) > 10 ** (-6)).all()

def body(state):
    u0, u1, ii = state
    print('iterating')
    u0 = u1
    ii += 1
    jax.lax.cond(ii < 10000, lambda _: None, lambda ii: jax.debug.callback(raise_error, ii), ii)
    return (u0, u1, ii)

# This is the line that fails: the bypass value is a tuple, not an array.
u0, u1, ii = jax.lax.select(case1, (u0, u1, ii), jax.lax.while_loop(cond_fun=cond, body_fun=body, init_val=tuple([u0, u1, ii])))
```

However, `lax.select` does not accept a tuple as the value for the bypass case. How should I do that?
`lax.select` can only work with single array arguments. To do what you have in mind, you would need three `select` statements:

```python
result = jax.lax.while_loop(cond_fun=cond, body_fun=body, init_val=tuple([u0, u1, ii]))
u0 = jax.lax.select(case1, u0, result[0])
u1 = jax.lax.select(case1, u1, result[1])
ii = jax.lax.select(case1, ii, result[2])
```

But please note that `lax.select` will in general execute both branches unconditionally (see the `lax.select` docs): its operands are computed before the selection happens, so the `while_loop` runs whether or not `case1` is true. If you want to condition in a way that will only execute one branch, you can try `lax.cond`. Something like this should work for that case:

```python
init_val = (u0, u1, ii)
u0, u1, ii = jax.lax.cond(case1,
                          lambda init_val: init_val,  # bypass: return the state unchanged
                          lambda init_val: jax.lax.while_loop(cond_fun=cond, body_fun=body, init_val=init_val),
                          init_val)
```
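A self-contained sketch of both approaches, runnable on its own, is below. It is not from the thread: it assumes float32 scalar state and swaps in a hypothetical halving update (`body_fun`) so the loop actually converges; `run`, `cond_fun`, `body_fun` and `TOL` are illustrative names. It also generalizes the three `select` calls with `jax.tree_util.tree_map`, which applies `lax.select` leaf by leaf to the whole tuple.

```python
import jax
import jax.numpy as jnp

TOL = 1e-6  # convergence tolerance, as in the question


def cond_fun(state):
    u_prev, u_curr, _ = state
    # Keep looping while successive iterates differ by more than TOL.
    return jnp.abs(u_curr - u_prev) > TOL


def body_fun(state):
    _, u_curr, ii = state
    # Hypothetical update rule, for illustration only: halve the iterate.
    return (u_curr, 0.5 * u_curr, ii + 1)


def run(skip, init):
    """Return (select_result, cond_result) for the same bypass flag."""
    pred = jnp.asarray(skip)  # boolean scalar predicate

    # Approach 1: lax.select on every leaf of the tuple via tree_map.
    # Both operands exist before the selection, so the loop always runs.
    looped = jax.lax.while_loop(cond_fun, body_fun, init)
    selected = jax.tree_util.tree_map(
        lambda a, b: jax.lax.select(pred, a, b), init, looped)

    # Approach 2: lax.cond; only the chosen branch is executed.
    conditioned = jax.lax.cond(
        pred,
        lambda s: s,                                          # bypass: keep state
        lambda s: jax.lax.while_loop(cond_fun, body_fun, s),  # iterate
        init)
    return selected, conditioned


init = (jnp.float32(1.0), jnp.float32(0.5), jnp.int32(0))
print(run(True, init))   # both results equal init
print(run(False, init))  # both results come from the loop
```

The `tree_map` form keeps the `select` version working for any tuple or pytree without one `select` per element. One caveat on `lax.cond`: if the predicate is batched under `jax.vmap`, JAX converts the `cond` into a `select`, so both branches run in that case too.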