condition function on jax #7201
Unanswered
lucasliunju
asked this question in
Q&A
Replies: 1 comment 2 replies
Hi, thanks for the question. I think we'll need more information to give you a useful answer. Can you give a short, self-contained example of code that demonstrates the issue you're having?
Hi, I would like to use a conditional function (if ... else ...) to control data flow.
I have tried `jax.numpy.where` and `jax.lax.cond`.
I find:
1. `jax.numpy.where` cannot be used to calculate gradients.
2. `jax.lax.cond` can be used to calculate gradients, but it seems to compute both branches, which costs extra time.

So I would like to ask whether there is a conditional function that computes only the selected branch.
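For concreteness, here is a minimal sketch of the two approaches being compared. The function names (`f_where`, `f_cond`) and the toy branches (`x ** 2` and `-x`) are assumptions made up for illustration, not the code from the actual use case. In this toy scalar case `jax.grad` appears to work through both constructs, so the gradient problem described above may depend on details this sketch does not capture.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_where(x):
    # Hypothetical toy function: both x ** 2 and -x are computed,
    # then jnp.where selects one of them based on the condition.
    return jnp.where(x > 0, x ** 2, -x)


def f_cond(x):
    # Same toy function written with lax.cond, which takes a scalar
    # predicate, a true-branch function, a false-branch function,
    # and the operand passed to whichever branch is chosen.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)


grad_where = jax.grad(f_where)
grad_cond = jax.grad(f_cond)

# Gradients on each side of the condition.
g_where_pos = grad_where(3.0)   # derivative of x ** 2 at x = 3
g_cond_pos = grad_cond(3.0)
g_where_neg = grad_where(-2.0)  # derivative of -x at x = -2
g_cond_neg = grad_cond(-2.0)

print(g_where_pos, g_cond_pos, g_where_neg, g_cond_neg)
```

As far as I understand, `jax.lax.cond` traces both branch functions but executes only the selected one at run time, except under `jax.vmap` with a batched predicate, where it is converted to a select that evaluates both branches. `jax.numpy.where` always evaluates both of its value arguments.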
Thank you!