Second derivative of a simple piecewise function is NaN #7190
mariogeiger asked this question in Q&A; it was answered by jakevdp.
Hi, I have a piecewise C^∞ function that I want JAX to be able to differentiate any number of times.

First attempt:

```python
import jax
import jax.numpy as jnp

def sus(x):
    return jnp.where(x > 0.0, jnp.exp(-1.0 / x), 0)

print(jax.grad(sus)(jnp.zeros(())))  # nan
```

Second attempt, with a custom VJP:

```python
@jax.custom_vjp
def sus(x):
    return jnp.where(x > 0.0, jnp.exp(-1.0 / x), 0)

def sus_fwd(x):
    y = sus(x)
    return y, (x, y)  # keep x and y as residuals for the backward pass

def sus_bwd(res, g):
    x, y = res
    # d/dx exp(-1/x) = exp(-1/x) / x**2 = y / x**2
    return (jnp.where(x > 0.0, g * y / x**2, 0),)

sus.defvjp(sus_fwd, sus_bwd)

print(jax.grad(sus)(jnp.zeros(())))            # 0
print(jax.grad(jax.grad(sus))(jnp.zeros(())))  # nan
```

I would like all the derivatives, including second and higher order, to work.
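A minimal sketch (not from the thread) of why both attempts return nan at x = 0. In reverse mode, `jnp.where` sends a zero cotangent into the branch it did not select, but it still differentiates that branch. For `exp(-1/x)`, the local derivative at x = 0 involves `1/x**2 = inf`, and `0 * inf = nan` in IEEE arithmetic, so the nan leaks through. The custom VJP only moves the problem one level down: `sus_bwd` evaluates `g * y / x**2` at x = 0 before `jnp.where` discards it, and the second derivative has to differentiate exactly that expression. The names `branch`, `f` and `bwd_like` are hypothetical stand-ins for the pieces of the question's code.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(())

# Hypothetical stand-in for the unselected branch of sus.
branch = lambda x: jnp.exp(-1.0 / x)
print(branch(x))            # the forward value is fine: exp(-inf) = 0.0
print(jax.grad(branch)(x))  # nan: the chain rule multiplies 0 by inf

# Same shape as the first attempt: where zeroes the cotangent,
# but zero times an infinite partial derivative is still nan.
f = lambda x: jnp.where(x > 0.0, branch(x), 0.0)
print(jax.grad(f)(x))       # nan

# Same shape as sus_bwd: y / x**2 is 0 / 0 at x = 0, so its own
# derivative (needed for the second derivative of sus) is nan as well.
bwd_like = lambda x: jnp.where(x > 0.0, jnp.exp(-1.0 / x) / x**2, 0.0)
print(jax.grad(bwd_like)(x))  # nan
```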
jakevdp replied on Jul 4, 2021:
Hi - this is a common issue with gradients of functions that use jnp.where. You can find some information on this issue and how to work around it here: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where |
Answer selected by mariogeiger.
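The FAQ entry linked in the answer describes the "double where" workaround: before computing the unsafe branch, replace its input with a harmless value wherever that branch will not be selected, so the branch and all of its derivatives stay finite everywhere. Below is a minimal sketch applying it to `sus` from the question; it is our adaptation, not code from the thread. The placeholder `1.0` is an arbitrary assumption (any value where `exp(-1/x)` is smooth works), and no `custom_vjp` is needed because plain autodiff now sees only finite intermediates.

```python
import jax
import jax.numpy as jnp

def sus(x):
    # Inner where: feed the exp branch a safe input (1.0) wherever x <= 0,
    # so neither the branch nor any of its derivatives hits 1/0.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    # Outer where: select the actual piecewise value, as in the original.
    return jnp.where(x > 0.0, jnp.exp(-1.0 / safe_x), 0.0)

x0 = jnp.zeros(())
d1 = jax.grad(sus)
d2 = jax.grad(d1)
d3 = jax.grad(d2)

# All three are finite (zero) at x = 0 instead of nan.
print(d1(x0), d2(x0), d3(x0))

# For x > 0 the usual formula is recovered: d/dx exp(-1/x) = exp(-1/x) / x**2.
print(d1(1.0), jnp.exp(-1.0))
```

The same idea could be used inside `sus_bwd` (computing `y / x**2` on a safe `x`) if one prefers to keep the custom VJP, but the plain version above is simpler and stays differentiable to any order.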