Error using jit for speed up #14393
PoulomiPradhan asked this question in Q&A
Replies: 1 comment 2 replies
-
    453     n_after = self.rndx[i][ref_wl]
--> 454     if z_dir_after < 0:
    455         n_after = -n_after
    456     ifc.delta_n = n_after - n_before

There is some general discussion of this here: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow. If you share a short reproduction of the issue, we could give more pointed advice for your particular situation, but assuming lines 454-455 are the only problem, you may be able to replace them with this:

    n_after *= jnp.sign(z_dir_after)
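A minimal sketch of that suggestion follows. The functions `signed_delta_n_sign` and `signed_delta_n_where` are hypothetical stand-ins for the surrounding code; only the names `n_after`, `n_before` and `z_dir_after` come from the traceback. One caveat: `jnp.sign(0.0)` is `0.0`, so multiplying by the sign zeroes `n_after` when `z_dir_after` is exactly zero, whereas the original `if` leaves it unchanged. If that case can occur, `jnp.where` reproduces the original branch exactly. Both variants avoid Python control flow on traced values, so they compile under `jax.jit` and stay differentiable with `jax.grad`.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the code around lines 453-456 (not the original class method).
def signed_delta_n_sign(n_after, n_before, z_dir_after):
    # Multiply by the sign instead of branching on a traced value.
    # Caveat: jnp.sign(0.0) == 0.0, so n_after becomes 0 when z_dir_after == 0.
    n_after = n_after * jnp.sign(z_dir_after)
    return n_after - n_before


def signed_delta_n_where(n_after, n_before, z_dir_after):
    # jnp.where selects between the two values without Python control flow and
    # matches `if z_dir_after < 0: n_after = -n_after` for every input, including 0.
    n_after = jnp.where(z_dir_after < 0, -n_after, n_after)
    return n_after - n_before


fast_sign = jax.jit(signed_delta_n_sign)
fast_where = jax.jit(signed_delta_n_where)

# Gradient with respect to n_after (argument 0), also compiled with jit.
grad_where = jax.jit(jax.grad(signed_delta_n_where, argnums=0))

print(fast_sign(1.5, 1.0, -1.0))   # → -2.5
print(fast_where(1.5, 1.0, -1.0))  # → -2.5
print(fast_where(1.5, 1.0, 0.0))   # → 0.5 (the sign version gives -1.0 here)
print(grad_where(1.5, 1.0, -1.0))  # → -1.0
```

Of the two, `jnp.where` is the safer drop-in replacement; the sign version is shorter and is fine when `z_dir_after` is guaranteed to be nonzero.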
-
I have a function to which I apply grad, and I get values back. But when I use jit to speed it up, I get a JAX error.
This works fine, but the code below throws an error:
Error:
Can someone kindly help here?
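Since the failing code and error are not shown, here is a hypothetical minimal reproduction of the pattern that the traceback in the reply points to: a made-up function `f` that branches with a Python `if` on the value of its argument. Under `jax.grad` alone the traced argument still carries a concrete value, so `x < 0` can be turned into a Python boolean. Under `jax.jit` the argument is an abstract tracer with only a shape and dtype, so the `if` cannot be decided and JAX raises a `ConcretizationTypeError` (recent versions raise its subclass `TracerBoolConversionError`).

```python
import jax


# Hypothetical function: Python control flow on the *value* of x.
def f(x):
    if x < 0:  # requires a concrete boolean at trace time
        x = -x
    return 3.0 * x


# grad without jit: x has a concrete value, so the branch can be chosen.
print(jax.grad(f)(-2.0))  # → -3.0
print(jax.grad(f)(2.0))   # → 3.0

# jit: x is abstract during tracing, so bool(x < 0) fails.
try:
    jax.jit(f)(-2.0)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed with", type(err).__name__)
```

The usual fix is to express the branch with `jnp.where` (or `jax.lax.cond` for larger branches) so that no Python `if` depends on a traced value.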