I've got a minimal repro:

import jax.numpy as jnp
import jax

id_f = lambda x: x

def b(x):
    return jax.lax.cond(x[0], id_f, id_f, x)  # jnp.where(x[0], id_f(x), id_f(x)) is ok

f = jax.jacrev(lambda x: b(b(x)))  # jacfwd is ok

def repro():
    jax.vmap(f)(jnp.ones((1, 1)))

def ok():
    f(jnp.ones((1,)))
    print('ok')

if __name__ == "__main__":
    with jax.disable_jit():
        ok()
        repro()
Thanks, this is definitely a bug! Actually, it might fit better as a bug issue than as a 'discussion'.
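The comments in the repro note that selecting with jnp.where instead of branching with jax.lax.cond avoids the failure. Below is a minimal sketch of that workaround, assuming it behaves the same in your JAX version; b_where, f_where and batched_jac are hypothetical names. Note that jnp.where evaluates both branches, so it is only a drop-in replacement when both branches are cheap and safe to compute.

```python
import jax
import jax.numpy as jnp

id_f = lambda x: x

def b_where(x):
    # Hypothetical workaround: select with jnp.where instead of jax.lax.cond.
    # jnp.where treats a non-boolean condition as "nonzero", and both
    # id_f(x) branches are computed before the selection.
    return jnp.where(x[0], id_f(x), id_f(x))

# Reverse-mode Jacobian of the composed function, as in the repro.
f_where = jax.jacrev(lambda x: b_where(b_where(x)))

# Batch the Jacobian over a leading axis of size 1, as repro() does.
batched_jac = jax.vmap(f_where)(jnp.ones((1, 1)))
print(batched_jac.shape)
```

With identity branches, each per-example Jacobian is the 1x1 identity, so the batched result has shape (1, 1, 1).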
Inside a nested scan, when using jax.lax.cond, I get an error message that is hard to understand. The original code was very complex; sadly, the minimal example I could make is still quite complex.
Bizarrely, when I change small details, the code runs fine (see the commented lines).
Working fine:
x * d_www[0, 0, 0]  # Ok
x * f_www(x)[0, 0, 0, 0]  # Ok
Throwing a hard-to-understand error:
jnp.array([1.0]) * f_www(x)[0, 0, 0, 0] * d_www[0, 0, 0]  # AttributeError
This makes me think that both the d_www creation and the f_www creation work fine. I've found no similar issues when searching.
~~R
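The original code is not included in this thread. As a hypothetical starting point for shrinking the repro further, here is a sketch of jax.lax.cond inside a nested jax.lax.scan that runs as expected; inner_step, outer_step and run are illustrative names, not taken from the reporter's code. Swapping the failing expression into inner_step one factor at a time (first f_www(x)[0, 0, 0, 0], then the product with d_www[0, 0, 0]) may help isolate which combination triggers the AttributeError.

```python
import jax
import jax.numpy as jnp

def inner_step(carry, x):
    # Hypothetical inner body: lax.cond with a boolean scalar predicate.
    # Both branches return a scalar of the same shape and dtype.
    y = jax.lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)
    return carry + y, y

def outer_step(carry, xs_row):
    # Inner scan over one row; the carry is created with the row's dtype.
    total, ys = jax.lax.scan(inner_step, jnp.zeros((), dtype=xs_row.dtype), xs_row)
    return carry + total, ys

@jax.jit
def run(xs):
    # Outer scan over the rows of xs.
    return jax.lax.scan(outer_step, jnp.zeros((), dtype=xs.dtype), xs)

xs = jnp.array([[1.0, -2.0], [3.0, -4.0]], dtype=jnp.float32)
total, ys = run(xs)
print(total, ys)
```

The scan carries are created with the input's dtype so that the carry type stays fixed across iterations, which lax.scan requires; lax.cond likewise requires both branches to return matching shapes and dtypes.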