NaN values while using jax.numpy.sqrt #6810
Unanswered
aonurdasdemir asked this question in Q&A
Replies: 1 comment 1 reply
-
The best way to deal with this is probably via custom JVP/VJP rules. This situation is very similar to the Numerical Stability section here: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html#numerical-stability. If you define an appropriate custom JVP/VJP for your function, you can control its behavior at values that are problematic for the raw floating-point implementation.
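A minimal sketch of that approach, assuming the gradient magnitude is factored into a hypothetical helper `safe_magnitude(gx, gy)` (the name and the test inputs are illustrative, not from the thread). The custom JVP uses the analytic derivative (gx·ġx + gy·ġy) / G and picks 0 where G == 0. That zero is a convention chosen here, not something the notebook prescribes. Because the output tangent is linear in the input tangents, JAX can transpose the rule, so `jax.vjp` works without writing a separate custom VJP.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_magnitude(gx, gy):
    # Hypothetical helper: sqrt(gx**2 + gy**2), as in the Sobel filter.
    return jnp.sqrt(jnp.square(gx) + jnp.square(gy))


@safe_magnitude.defjvp
def safe_magnitude_jvp(primals, tangents):
    gx, gy = primals
    gx_dot, gy_dot = tangents
    g = safe_magnitude(gx, gy)
    # Replace zero magnitudes by 1.0 only in the denominator. Where g == 0,
    # gx and gy are 0 as well, so the coefficients below become 0 there
    # instead of 0 / 0 = NaN.
    g_safe = jnp.where(g > 0, g, 1.0)
    # Output tangent is linear in (gx_dot, gy_dot), so reverse mode works.
    g_dot = (gx / g_safe) * gx_dot + (gy / g_safe) * gy_dot
    return g, g_dot


# Usage: some entries of both inputs are exactly zero.
gx = jnp.array([0.0, 3.0, 0.0])
gy = jnp.array([0.0, 4.0, 2.0])
out, vjp_fn = jax.vjp(safe_magnitude, gx, gy)
dgx, dgy = vjp_fn(jnp.ones_like(out))
print(out, dgx, dgy)
```

At the all-zero entry the cotangents come out as 0 rather than NaN. At the other entries they match the ordinary derivative gx / G and gy / G.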
-
I am implementing a Sobel filter inside my optimization package, and I need to use vjp on that function. The Sobel filter contains the term
G = jnp.sqrt(jnp.square(G_x) + jnp.square(G_y))
where G_x and G_y are arrays, some of whose elements are zero. When I use vjp on this function, the result is an array containing NaN values. From what I see in previous questions, the problem is taking sqrt of zero values. Are there any possible solutions or workarounds for this? Here is a minimized example that produces NaN in the vjp of a function that uses sqrt when some elements of the input array are zero:
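The NaN comes from the chain rule. The derivative of sqrt(u) is 1 / (2·sqrt(u)), which is infinite at u = 0. Reverse mode then multiplies it by the derivative of the squares, 2·G_x = 0, and inf · 0 = NaN. The sketch below reproduces this on a hypothetical test image (not the poster's minimized example). Flat regions of that image give exactly zero Sobel responses. The sketch also shows a different workaround from the custom JVP in the reply: the "double where" pattern described in the JAX FAQ. The kernel names `KX`/`KY`, the function names, and the image are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

# Standard 3x3 Sobel kernels. convolve2d flips them, which only changes the
# sign of G_x and G_y and leaves the magnitude unchanged.
KX = jnp.array([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
KY = KX.T


def sobel_naive(img):
    gx = convolve2d(img, KX, mode="valid")
    gy = convolve2d(img, KY, mode="valid")
    # Same expression as in the question; its gradient is NaN where G == 0.
    return jnp.sqrt(jnp.square(gx) + jnp.square(gy))


def sobel_double_where(img):
    gx = convolve2d(img, KX, mode="valid")
    gy = convolve2d(img, KY, mode="valid")
    sq = jnp.square(gx) + jnp.square(gy)
    nonzero = sq > 0
    # Inner where: feed sqrt a harmless value (1.0) where sq == 0, so the
    # derivative of sqrt stays finite on the branch that gets discarded.
    safe_sq = jnp.where(nonzero, sq, 1.0)
    # Outer where: put the true magnitude 0 back at those positions.
    return jnp.where(nonzero, jnp.sqrt(safe_sq), 0.0)


# Hypothetical image: zero left half, bright right half, so some Sobel
# responses are exactly zero (flat regions) and some are not (the edge).
img = jnp.zeros((6, 6)).at[:, 3:].set(1.0)

grads = {}
for fn in (sobel_naive, sobel_double_where):
    out, vjp_fn = jax.vjp(fn, img)
    (grads[fn.__name__],) = vjp_fn(jnp.ones_like(out))
    print(fn.__name__, "has NaN:", bool(jnp.isnan(grads[fn.__name__]).any()))
```

The naive version's image gradient contains NaN, while the double-where version's gradient is finite everywhere. Both inner and outer `where` are needed. With only the outer one, the discarded sqrt branch still contributes inf · 0 = NaN to the gradient.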