
lax.select only works on single array arguments, not on tuples or other pytrees. To do what you have in mind, you would need three select statements:

result = jax.lax.while_loop(cond_fun=cond, body_fun=body, init_val=(u0, u1, ii))
u0 = jax.lax.select(case1, u0, result[0])
u1 = jax.lax.select(case1, u1, result[1])
ii = jax.lax.select(case1, ii, result[2])
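Here is a self-contained sketch of the three-select pattern. The `cond` and `body` functions below are hypothetical stand-ins for the asker's loop (double `u0`, accumulate into `u1`, count with `ii` until it reaches 5), and `tree_select` is a hypothetical helper showing how the same idea generalizes to any tuple or pytree via `jax.tree_util.tree_map`, so you don't have to write one `select` per element.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical loop, for illustration only.
def cond(state):
    u0, u1, ii = state
    return ii < 5

def body(state):
    u0, u1, ii = state
    return (u0 * 2.0, u1 + u0, ii + 1)

@jax.jit
def run_select(u0, u1, ii, case1):
    # The while_loop always runs; select only chooses which values are kept.
    result = lax.while_loop(cond, body, (u0, u1, ii))
    u0 = lax.select(case1, u0, result[0])
    u1 = lax.select(case1, u1, result[1])
    ii = lax.select(case1, ii, result[2])
    return u0, u1, ii

def tree_select(pred, on_true, on_false):
    # Hypothetical helper: applies lax.select leaf by leaf to two pytrees
    # with identical structure, shapes, and dtypes.
    return jax.tree_util.tree_map(lambda t, f: lax.select(pred, t, f), on_true, on_false)
```

Passing explicit `jnp.float32`/`jnp.int32` values keeps the loop-carried types stable, which `while_loop` requires.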

Note, however, that lax.select does not skip any computation: both of its value arguments are computed before the selection happens, so in the snippet above the while_loop runs even when case1 is True (see the lax.select docs). If you want only one branch to execute, you can use lax.cond instead. Something like this should work in that case:

init_val = (u0, u1, ii)
u0, u1, ii = jax.lax.cond(case1,
    lambda init_val: init_val,
    lambda init_val: jax.lax.while_loop(cond, body, init_val),
    init_val)
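A self-contained sketch of the lax.cond approach follows. The loop (`cond` and `body`) is hypothetical and only illustrates the pattern. `lax.cond(pred, true_fun, false_fun, *operands)` passes the operands to whichever branch is taken, and both branches must return the same structure and dtypes, which holds here because the identity branch and the `while_loop` both return the `(u0, u1, ii)` tuple.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical loop, for illustration only.
def cond(state):
    u0, u1, ii = state
    return ii < 5

def body(state):
    u0, u1, ii = state
    return (u0 * 2.0, u1 + u0, ii + 1)

@jax.jit
def run_cond(u0, u1, ii, case1):
    init_val = (u0, u1, ii)
    # With a scalar, unbatched predicate only the chosen branch executes,
    # so the loop is skipped entirely when case1 is True.
    return lax.cond(
        case1,
        lambda state: state,                              # keep inputs unchanged
        lambda state: lax.while_loop(cond, body, state),  # run the loop
        init_val,
    )
```

One caveat: if `case1` is batched under `jax.vmap`, lax.cond is converted to a select, so both branches run again in that setting.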

Answer selected by andreok
Category
Q&A
2 participants