Unclear why gradients are NaN when including a sigmoid #15617
I have a function that applies a sigmoid to Gumbel samples, and I want to differentiate it with respect to the parameter `lamb`. But when I apply `jax.grad`, the gradients come back as nan, and for these values the magic value of `lamb` at which this happens is 0.5. If I remove the sigmoid from the function, then differentiating is completely fine at any value of `lamb`.
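Here is a stripped-down sketch of the shape of the problem (the function name, the Gumbel-sigmoid form, and the input values are illustrative stand-ins, not my actual code):

```python
import jax
import jax.numpy as jnp

def relaxed_sample(lamb, logits, gumbel_noise):
    # Temperature-scaled Gumbel sample pushed through a hand-written sigmoid.
    gumbel_samp = (logits + gumbel_noise) / lamb
    return 1.0 / (1.0 + jnp.exp(-gumbel_samp))

logits = jnp.float32(-50.0)        # illustrative values only
gumbel_noise = jnp.float32(-2.15)

grad_fn = jax.grad(lambda lamb: relaxed_sample(lamb, logits, gumbel_noise))
print(grad_fn(jnp.float32(1.0)))   # finite
print(grad_fn(jnp.float32(0.5)))   # nan
```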
Thanks for the question!
The issue is a numerical stability one in the expression `1/(1 + jnp.exp(-gumbel_samp))`: setting `lamb=0.5`, we get values of `gumbel_samp` down to `-104.296745`, which in turn means we're evaluating something like `jnp.exp(104.296745)`, which gives us floating point's infinity. That inf appears in a Jacobian coefficient, and so leads to the nan when we multiply it by a zero cotangent.
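Here's a minimal sketch of that failure mode in isolation, with JAX's default float32 (the hand-written `naive_sigmoid` below is a stand-in for the expression above):

```python
import jax
import jax.numpy as jnp

def naive_sigmoid(x):
    # jnp.exp(-x) overflows float32 to inf once -x exceeds roughly 88.7.
    return 1.0 / (1.0 + jnp.exp(-x))

x = jnp.float32(-104.296745)
print(naive_sigmoid(x))            # 0.0: the forward pass survives (1/inf == 0)
print(jax.grad(naive_sigmoid)(x))  # nan: the backward pass multiplies inf by 0
```

The forward value is fine because `1/inf` rounds cleanly to zero; it's only the Jacobian coefficient that picks up the inf.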
(Tangentially: I found the culprit by using the `with jax.debug_nans(True):` context manager (and also `with jax.debug_infs(True):`), which raised an exception as soon as an operation produced a nan (or inf) and dropped me into a post-mortem debugger. That showed me exactly where things were going wrong, but I just realized…)

Back to how to fix this particular issue: basically, use `jax.nn.sigmoid` instead of writing the sigmoid out by hand; it's implemented with numerical stability in mind. See this section of the docs for more. Using `jax.nn.sigmoid`, the gradient comes out finite (see the sketch below). What do you think?
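A quick sketch putting both pieces together, reusing the stand-in `naive_sigmoid` from above:

```python
import jax
import jax.numpy as jnp

def naive_sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

x = jnp.float32(-104.296745)

# Locate the nan: inside this context, JAX raises a FloatingPointError at
# the first operation whose output contains a nan.
try:
    with jax.debug_nans(True):
        jax.grad(naive_sigmoid)(x)
except FloatingPointError as e:
    print("caught:", e)

# The fix: jax.nn.sigmoid is written to be numerically stable, so its
# gradient stays finite even at extreme inputs.
print(jax.nn.sigmoid(x))            # 0.0
print(jax.grad(jax.nn.sigmoid)(x))  # 0.0: finite, no nan
```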