Get nan when calculating partial derivatives of a smooth function using jax.jacfwd #7313
Replies: 1 comment 1 reply
This is expected behavior because the gradient calculation here can be numerically problematic. Remember that $(f/g)' = (f'g - g'f)/g^2$. The $g^2$ here grows to a very large number when the input is close to 2.2, and for standard 32-bit floats we indeed get an inf value. At the same time, the numerator of the derivative has problems of its own, since we end up having to evaluate inf - inf, which gives nan. Unless we are very careful about how we apply the chain rule, we can end up getting a nan.

Increasing the precision to 64 bits is sufficient to sidestep the problem for the given range because 64 bits are sufficient to represent these values (the problem reemerges if we extend the input range to 4).

Now, you also noted that jacrev does not suffer from the inf problem. Although jacfwd and jacrev yield the same results under infinite-precision arithmetic, this no longer holds under finite precision. jacrev and jacfwd order operations differently, and thus one algorithm may avoid these inf operations. Even where jacrev does not output nan, it is not numerically accurate either: it predicts a derivative of -70 around 2.2 while the function is close to being constant.

These numerical instability problems are known. The typical way to solve them is to rewrite the function in an equivalent form that removes the instability. For example, we could replace your f function with

    import jax.numpy as jnp

    # ring_width, ring_center and nu are the constants defined in the question.
    def f(r):
        part = ring_width**2*(jnp.sqrt(2.)+2.*jnp.exp((r - ring_center)**2/(2*ring_width**2))*jnp.sqrt(jnp.pi)*ring_width)
        return 3*nu*(2.*jnp.sqrt(2.)*r**2-2*jnp.sqrt(2.)*r*ring_center)/(2.*r*part) - 3*nu/(2.*r)
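To make the inf - inf mechanism concrete, here is a minimal sketch. It does not use the function from the question: the value `1e20` and the helpers `quotient_rule`, `naive` and `stable` are hypothetical stand-ins chosen only so that the products overflow in float32. The first part evaluates the quotient rule by hand in both precisions; the second part shows how an algebraically equivalent rewrite (here, of a sigmoid) can keep jacfwd finite.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit support is switched on at startup, before any arrays exist.
jax.config.update("jax_enable_x64", True)


def quotient_rule(f, df, g, dg):
    """Naive evaluation of (f/g)' = (f'g - g'f) / g^2."""
    return (df * g - dg * f) / g**2


def demo(dtype):
    # Hypothetical magnitudes: 1e20 fits in float32, but 1e20 * 1e20 does not.
    big = jnp.asarray(1e20, dtype=dtype)
    return quotient_rule(big, big, big, big)


d32 = demo(jnp.float32)  # both products overflow to inf, so inf - inf gives nan
d64 = demo(jnp.float64)  # 1e40 is representable, so the numerator cancels to 0
print("float32:", d32, "float64:", d64)


# Hypothetical stand-in for "rewrite the function in an equivalent form".
def naive(x):
    # exp(x) overflows in float32 for large x, giving inf / inf.
    return jnp.exp(x) / (1.0 + jnp.exp(x))


def stable(x):
    # Same function algebraically, but exp(-x) underflows harmlessly instead.
    return 1.0 / (1.0 + jnp.exp(-x))


x = jnp.float32(100.0)
grad_naive = jax.jacfwd(naive)(x)
grad_stable = jax.jacfwd(stable)(x)
print("naive:", grad_naive, "stable:", grad_stable)
```

The rewrite in the second part follows the same idea as the replacement `f` suggested above: arrange the expression so that the quantity that blows up only ever appears where dividing by it tends to zero, rather than in a difference of two overflowing terms.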