jax.grad yields nan of jnp.sqrt(array)? #17263
-
I don't quite understand the question – here's the code I constructed based on what you provided in your question:

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

multiplier = jnp.float64(-3.55448731e-13)

# The 120-element input from the question: all zeros except three entries.
arr = (jnp.zeros(120)
       .at[25].set(1.86344989)
       .at[56].set(0.02569974)
       .at[94].set(0.36218982))

value_inside_exp = multiplier * jnp.sqrt(arr)
result = jnp.exp(value_inside_exp)
print(result)
# [1. 1. 1. ... 1.]  (all 120 entries are 1.0)
```

This seems like reasonable output – what should I be looking for?
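If the concern is the gradient rather than the value, a quick check with the same `multiplier` and `arr` shows where things go wrong (a sketch, assuming the loss is simply the sum of `result`):

```python
def f(arr):
    return jnp.sum(jnp.exp(multiplier * jnp.sqrt(arr)))

print(jax.grad(f)(arr))
# -inf at every zero entry: d/dx sqrt(x) = 1 / (2 * sqrt(x)) diverges
# at x == 0, and it turns into nan as soon as a zero factor meets
# that inf elsewhere in the chain.
```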
-
I'm working in 64-bit and have isolated the problem down to one line, which I abbreviate here (the line and example input are essentially the code reconstructed in the reply above).
If the problem were with the value itself, I could use jnp.where or similar to retain a sensible value at the extremes. But assuming this is about extreme values, it's the grad that returns nan: I've tried clipping `value_inside_exp`, but that only changes the value, while the grad still contains nan. Is there a way to control the grad output (i.e., set it to zero when the values or gradients are really small)?
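For reference, the standard way to get a well-defined gradient here is the "double where" trick from the JAX FAQ. A sketch, assuming a zero gradient is the desired behavior at the zero entries:

```python
def f_safe(arr):
    positive = arr > 0
    safe_arr = jnp.where(positive, arr, 1.0)   # keep 0 out of sqrt in the backward pass
    root = jnp.where(positive, jnp.sqrt(safe_arr), 0.0)
    return jnp.sum(jnp.exp(multiplier * root))

print(jax.grad(f_safe)(arr))  # finite everywhere; zero grad at the zero entries
```

The inner `jnp.where` ensures the backward pass never differentiates `jnp.sqrt` at zero; the outer one restores the correct forward value.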
I do think the nans are showing up in the gradient array due to the scale of the numbers: the input `arr` above is operated on elsewhere without issue, while the entries of the grad-wrt-variables array that come back as nan correspond to the variables used in building this padded array... unless I'm missing another reason for the nans.
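That would explain nan rather than inf. If the padded array is built with something like a mask multiply (a hypothetical construction, since the real one isn't shown here), the backward pass multiplies the mask's zero by the inf from sqrt's derivative, and 0 * inf is nan:

```python
mask = jnp.zeros(4).at[0].set(1.0)    # hypothetical padding mask
v = jnp.array([2.0, 3.0, 5.0, 7.0])   # hypothetical variables

def g(v):
    padded = mask * v                 # zero wherever the mask is zero
    return jnp.sum(jnp.sqrt(padded))

print(jax.grad(g)(v))
# [0.35355339 nan nan nan]: at the masked-out entries the backward
# pass computes 0 (from the mask) * inf (from sqrt at 0) = nan.
```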
Every function involved is wrapped in @jax.jit.
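Since everything is jitted, one way to pin down where the first nan arises is JAX's nan-debugging flag, which re-runs the offending operation un-jitted and raises an error at the first nan it produces:

```python
jax.config.update('jax_debug_nans', True)  # raises FloatingPointError at the first nan
```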