

In general, a zero gradient means the output doesn't depend on the input through a chain of floating-point or complex operations. For example, if part of your computation is carried out in integer arithmetic, that part won't propagate gradients. This isn't specific to JAX.
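
Here is a minimal sketch of the effect. It assumes a standard JAX install with `jax.grad`, and the functions `through_int` and `mixed` are hypothetical, written only for illustration. Casting the input to an integer dtype breaks the differentiable chain, so the gradient along that path is zero. A floating-point term computed alongside it still contributes normally.

```python
import jax
import jax.numpy as jnp

def through_int(x):
    # Hypothetical example: the whole path goes through an integer dtype,
    # so JAX treats its tangent as zero and no gradient flows back to x.
    n = x.astype(jnp.int32)
    return (n * 2).astype(jnp.float32)

def mixed(x):
    # Hypothetical example: only the floating-point term x**2 contributes
    # to the gradient; the integer-cast term contributes nothing.
    return x ** 2 + x.astype(jnp.int32).astype(jnp.float32)

print(jax.grad(through_int)(3.0))  # 0.0
print(jax.grad(mixed)(3.0))        # 6.0, i.e. d(x**2)/dx at x = 3
```

If a real model returns all-zero gradients, looking for a cast like this (or any integer-valued intermediate) between the parameters and the loss is a reasonable first check.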

That said, we don't have time to debug large pieces of code like this, but we may be able to help if you can reduce your question to a small, self-contained piece of JAX code. Can you make the reproduction smaller?

Answer selected by mjhoover1