Hello, I am working on converting some highway simulation code from C++ using the Adept library to Python using JAX autograd. I believe I have made all the changes needed for value_and_grad to work, yet the gradient comes back as 0. The function returns the expected output value, so I don't think the forward computation is wrong, but the gradient is not correct. Any advice on what might be happening here, or on how to debug a gradient of 0 in JAX when the output depends on the input? Thanks!
Replies: 1 comment
In general, a 0 gradient means the output doesn't depend on the input via a chain of floating-point or complex operations. For example, if part of your computation occurs in the integers then that part won't propagate gradients. This isn't specific to JAX.

However, we're not going to have the time to debug large pieces of code like this. But we might be able to help you out if you can reduce your question down to a small, self-contained piece of JAX code. Can you make the reproduction smaller?
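To illustrate the integer point, here is a minimal sketch of how this can show up. It is not taken from the simulation code in question: the function names and the int32 cast are hypothetical, chosen only to show a case where the forward value looks reasonable while the gradient is zero.

```python
import jax
import jax.numpy as jnp

def smooth_loss(x):
    # Stays in floating point from input to output, so gradients propagate.
    return jnp.sum(x ** 2)

def truncated_loss(x):
    # Hypothetical bug: a cast to int32 in the middle of the computation.
    # The forward value is still computed, but no gradient flows through the cast.
    cells = x.astype(jnp.int32)
    return jnp.sum(cells.astype(jnp.float32) ** 2)

x = jnp.array([1.2, 2.7, 3.1])

smooth_value, smooth_grad = jax.value_and_grad(smooth_loss)(x)
truncated_value, truncated_grad = jax.value_and_grad(truncated_loss)(x)

print(smooth_grad)     # nonzero, equal to 2 * x
print(truncated_grad)  # all zeros
```

A small comparison like this is also roughly the size of reproduction that is easy to look at. One way to narrow down a larger function is to take the gradient of intermediate results one step at a time and see where it first becomes zero.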