In many cases, you can use the jit-compatible three-argument form of `jnp.where`: `y = jnp.where(x < 0, 0, 1)`. There are more examples and discussion in the documentation for the error your original code raises: https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
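A minimal sketch contrasting the two forms of `jnp.where` under `jax.jit`. The function names `step` and `negative_indices` are illustrative only. The three-argument form always returns an array with the shape of `x`. The one-argument form returns indices whose count depends on the data, so its output shape is not static under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Three-argument form: elementwise select between 0 and 1.
    # The output always has the shape of x, so it is static under jit.
    return jnp.where(x < 0, 0, 1)

@jax.jit
def negative_indices(x):
    # One-argument form: returns the indices where x < 0. How many there are
    # depends on the values in x, so jit cannot determine the output shape.
    return jnp.where(x < 0)

x = jnp.array([-2.0, -0.5, 0.0, 3.0])
print(step(x))  # [0 0 1 1]

try:
    negative_indices(x)
except jax.errors.ConcretizationTypeError:
    print("one-argument jnp.where fails under jit")
```

If you really do need the indices, `jnp.where` also accepts a static `size=` argument (optionally with `fill_value=`) that pads or truncates the result to a fixed length, which makes the output shape known at trace time.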
Hello,
I want to use the following code in a function decorated with jit:
However, jnp.where raises an error with abstract (traced) values. So I used the jax.lax.cond function instead, but with that my code is slower than before. Is there another way to program these 3 lines without using jax.lax.cond?
Thanks in advance.
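The snippet from the question is not shown, so this sketch assumes (hypothetically) that the three lines were a scalar branch like `if x < 0: y = 0 else: y = 1`. It compares a `jax.lax.cond` version with a three-argument `jnp.where` version; `with_cond` and `with_where` are illustrative names, not from the original post.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical reconstruction of the question's code: the original three
# lines are not shown, so assume a scalar `if x < 0: y = 0 else: y = 1`.

@jax.jit
def with_cond(x):
    # lax.cond requires a scalar predicate and executes only the chosen branch.
    return lax.cond(x < 0, lambda v: 0, lambda v: 1, x)

@jax.jit
def with_where(x):
    # jnp.where computes both candidate values and selects between them,
    # with no control flow; it also works elementwise on arrays.
    return jnp.where(x < 0, 0, 1)

for v in [-2.0, 0.0, 3.0]:
    print(v, int(with_cond(jnp.float32(v))), int(with_where(jnp.float32(v))))
```

Both versions give the same result here. `lax.cond` only pays off when the branches are expensive enough that skipping one matters; for trivial branches like these, a select via `jnp.where` avoids control flow, though whether it is actually faster depends on the backend, so it is worth timing both.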