Gradient of arctan2 fails but unclear why #15865
-
I have a large network that, at some point, uses the … at the line containing the call to … The function performs fine in the forward pass and does not produce any … Looking at the gradient of …, the gradients are always evaluated well.

It's a bit tricky to know how to start resolving this problem, since clearly my diagnosis is wrong. What are the next steps toward understanding how this issue arises? Going through the full code in debug mode and breaking when we reach that line, it doesn't seem like any of the values are 0, but I'm still struggling to peer into the JAX internals.

Adding some more details: printing the jaxpr gives me the following …

So everything seems fine until that `div` step, which is where I expect the error is coming from, but I don't quite understand why my minimal example doesn't fail.
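A minimal sketch of two standard JAX tools for this kind of debugging. It assumes the relevant part of the network reduces to something like the hypothetical `loss` below, with made-up inputs: index 1 puts zeros in `x` and `y` at the same location. `jax.make_jaxpr` applied to `jax.grad(...)` shows the backward-pass operations, including the `div` that only exists in the gradient. The `jax_debug_nans` config flag makes JAX raise a `FloatingPointError` at the first operation that produces a NaN, instead of letting it propagate silently.

```python
import jax
import jax.numpy as jnp


def loss(x, y):
    # Hypothetical stand-in for the part of the network that calls arctan2.
    return jnp.sum(jnp.arctan2(x, y))


# Made-up inputs: index 1 is a location where x and y are both zero.
x = jnp.array([1.0, 0.0])
y = jnp.array([2.0, 0.0])

grad_fn = jax.grad(loss, argnums=(0, 1))

# The forward pass is finite, because arctan2(0, 0) returns 0.
print(loss(x, y))

# Print the jaxpr of the gradient, not only of the forward pass: the division
# by x**2 + y**2 introduced by differentiating arctan2 appears only here.
jaxpr_text = str(jax.make_jaxpr(grad_fn)(x, y))
print(jaxpr_text)

# Ask JAX to raise as soon as any operation produces a NaN.
jax.config.update("jax_debug_nans", True)
raised = False
try:
    grad_fn(x, y)
except FloatingPointError as err:
    raised = True
    print("NaN detected:", err)
finally:
    # Restore the default so later code is unaffected.
    jax.config.update("jax_debug_nans", False)
```

If the inputs of a minimal example never place zeros in `x` and `y` at the same index, its gradient stays finite, which would be consistent with the minimal example not failing.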
-
Thanks for the question!
Operationally, the issue is that the gradient of `atan2(x, y)` contains the term `x / (x ** 2 + y ** 2)`, so if `x` and `y` have zeros in the same location, the gradient will be `NaN`.

More conceptually, `atan2` is not continuous at the point $(0, 0)$, so its derivative at that point is ill-defined.
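For reference, with `atan2(x, y)` taken in the argument order of `jnp.arctan2` (the angle of the point $(y, x)$, i.e. $\arctan(x / y)$ up to quadrant), the partial derivatives are

$$\frac{\partial}{\partial x}\operatorname{atan2}(x, y) = \frac{y}{x^2 + y^2}, \qquad \frac{\partial}{\partial y}\operatorname{atan2}(x, y) = -\frac{x}{x^2 + y^2}.$$

Both share the denominator $x^2 + y^2$, so at $(0, 0)$ each evaluates to $0 / 0$. The sketch below checks this with `jax.grad`. It also adds a hypothetical `safe_arctan2` helper (not part of JAX, and not something the answer above proposes) using the common "double `where`" pattern: the inner `where` substitutes a harmless input at the origin so the backward pass never divides by zero, and the outer `where` selects the value returned there. The resulting gradient of 0 at the origin is a convention, since the true derivative is undefined.

```python
import math

import jax
import jax.numpy as jnp


def arctan2_grads(x, y):
    # Partial derivatives of jnp.arctan2(x, y) with respect to x and to y.
    return jax.grad(jnp.arctan2, argnums=(0, 1))(x, y)


# Away from the origin the gradient is finite:
# d/dx = y / (x**2 + y**2) = 0.5 and d/dy = -x / (x**2 + y**2) = -0.5.
print(arctan2_grads(1.0, 1.0))

# At the origin both partials evaluate 0 / 0, so both are NaN.
print(arctan2_grads(0.0, 0.0))


def safe_arctan2(x, y):
    """Hypothetical helper (not part of JAX) using the double-`where` pattern."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    at_origin = (x == 0) & (y == 0)
    # Inner where: at the origin, evaluate arctan2 at the harmless point (0, 1)
    # so the backward pass divides by 1 instead of 0 (x is already 0 there).
    y_safe = jnp.where(at_origin, 1.0, y)
    # Outer where: return 0.0 at the origin, matching arctan2(0, 0); the
    # cotangent flowing into the arctan2 branch is then zero at that point.
    return jnp.where(at_origin, 0.0, jnp.arctan2(x, y_safe))


# The gradient at the origin is now (0.0, 0.0) instead of NaN, while values
# and gradients away from the origin are unchanged.
print(jax.grad(safe_arctan2, argnums=(0, 1))(0.0, 0.0))
print(safe_arctan2(1.0, 1.0), math.pi / 4)
```

Whether a zero gradient at the origin is acceptable depends on the model; the point of the pattern is only that neither branch of the `where` ever computes $0 / 0$, because a NaN in an unselected branch would still leak into the gradient.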