crash in nested scan #14816
jecampagne asked this question in Q&A
Well, I found one after posting. I guess it can be improved, but I am sharing the solution. The trick is to scan over an array of the required length instead of passing an integer `length` argument:

    import jax
    import jax.numpy as jnp

    def func(argt):
        nturns, m = argt

        def noop(carry):
            return carry

        def step_once(carry):
            i, m = carry
            i += 1
            return (i, m)

        def body_fn(carry, dummy):
            i, m = carry
            # Increment i only while it is still below m.
            return jax.lax.cond(i < m, step_once, noop, carry), None

        i = 80
        carry = (i, m)
        # carry, _ = jax.lax.scan(body_fn, carry, None, length=nturns)  # do not use: nturns is traced here
        carry, _ = jax.lax.scan(body_fn, carry, nturns)  # instead: scan over the nturns array (static shape)
        i, m = carry
        return i

    def body_fn(carry, dummy):
        # decode carry
        nturns, m, iold = carry
        argt = (nturns, m)
        inew = func(argt)
        # encode carry
        carry = (nturns, m, inew)
        return carry, None

    nturns, m, i0 = 20, 100, 999
    carry = (jax.lax.iota(jnp.int32, nturns), m, i0)  # instead of (nturns, m, i0)
    carry, _ = jax.lax.scan(body_fn, carry, None, length=1)
    print("i end:", carry[2])

So I have a solution if I know the number of loops of the nested scan from the beginning... otherwise, is there no hope?
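When the number of iterations is not known in advance, a common workaround (sketched here under stated assumptions, not an official JAX API) is to choose a static upper bound on the iterations, run `scan` for exactly that many steps, and freeze the state once a stopping condition fails. The code posted above already does this in effect with `lax.cond(i < m, ...)`. Below, the same idea is written as a generic, hypothetical helper `bounded_while_loop`, using `jnp.where` instead of `lax.cond` (both work under `jax.jacrev`; `where` just avoids writing two branch functions). The names `bounded_while_loop`, `halve_until_small` and the bound `max_steps=50` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def bounded_while_loop(cond_fn, body_fn, init, max_steps):
    """Hypothetical helper: emulate lax.while_loop with at most max_steps steps.

    max_steps must be a Python int (static), so lax.scan can use it as length.
    Once cond_fn(state) is False, the state is carried through unchanged.
    """

    def step(state, _):
        keep_going = cond_fn(state)
        new_state = body_fn(state)
        # Select the updated state only while the condition still holds.
        state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_state, state
        )
        return state, None

    final_state, _ = jax.lax.scan(step, init, None, length=max_steps)
    return final_state


def halve_until_small(x):
    """Hypothetical example: halve x while it is larger than 1.0."""
    return bounded_while_loop(
        cond_fn=lambda s: s > 1.0,
        body_fn=lambda s: s / 2.0,
        init=x,
        max_steps=50,
    )


x0 = jnp.float32(10.0)
y = halve_until_small(x0)
# Reverse mode works because the loop is a scan with a static length.
dy_dx = jax.jacrev(halve_until_small)(x0)
print(y, dy_dx)
```

The trade-off is that every call pays for `max_steps` iterations, and if the bound is too small the loop silently stops early. Also, `body_fn` is still evaluated after the condition fails, so it should not produce NaN or inf there, since those can leak into gradients through `jnp.where`.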
Hello,
Below is a nested-scan use case, a simplified version of code that originally used lax.while_loop and lax.fori_loop. I would like to switch to lax.scan so that I can apply jax.jacrev. This matters mostly because of the while loop that was inside the "func" function, since lax.while_loop does not support reverse-mode differentiation.
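For reference, a minimal sketch of why the loop primitive matters for `jax.jacrev`. The JAX documentation for `lax.fori_loop` states that when the trip count is static (for example, Python `int` bounds), the loop is implemented with `scan` and reverse-mode differentiation is supported, unlike `lax.while_loop`. The function `scaled_power` is a hypothetical toy, not code from this thread.

```python
import jax
import jax.numpy as jnp


def scaled_power(x):
    """Hypothetical toy: apply x -> 1.1 * x five times.

    The bounds 0 and 5 are Python ints, so the trip count is static and
    lax.fori_loop is lowered to lax.scan, which supports reverse mode.
    """
    return jax.lax.fori_loop(0, 5, lambda k, acc: 1.1 * acc, x)


x = jnp.array([1.0, 2.0])
jac = jax.jacrev(scaled_power)(x)
# The map is linear, so the Jacobian is 1.1**5 times the identity matrix.
print(jac)
```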
My problem is the crash caused by "length=nturns": nturns is traced there, while scan's length must be a concrete Python int. Do you see a way to rewrite "func" using scan?
Thanks
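One possible way around the `length=nturns` crash, sketched under the assumption that `nturns` is a fixed Python `int` known when the code is traced: `nturns` only becomes a tracer because it is stored in the outer scan's carry. Values that never change between iterations do not need to be in the carry; they can be captured by closure instead, so `length=nturns` stays concrete. The names `make_func` and `outer_body` are hypothetical, the inner loop reuses the `i < m` counting logic posted in this thread, and unlike that code the starting value of `i` is taken from the carry (here `i0 = 80`).

```python
import jax
import jax.numpy as jnp


def make_func(nturns):
    """Hypothetical helper: build the inner loop for a static trip count.

    nturns is a Python int captured by closure, so it is never traced and
    can be passed to scan's `length` argument.
    """

    def func(i, m):
        def body_fn(carry, _):
            i, m = carry
            # Increment i while it is below m, otherwise leave it unchanged.
            i = jnp.where(i < m, i + 1, i)
            return (i, m), None

        (i, m), _ = jax.lax.scan(body_fn, (i, m), None, length=nturns)
        return i

    return func


nturns, m, i0 = 20, 100, 80
func = make_func(nturns)


def outer_body(carry, _):
    # Only the values that actually change are kept in the carry.
    i, m = carry
    return (func(i, m), m), None


(i_end, _), _ = jax.lax.scan(outer_body, (jnp.int32(i0), jnp.int32(m)), None, length=1)
print("i end:", i_end)
```

This only helps when the trip count is known at trace time; if it genuinely depends on traced data, a static upper bound with masking is still needed.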