Thanks for the question!

The issue is one of numerical stability in the expression 1/(1 + jnp.exp(-gumbel_samp)): with lamb=0.5, gumbel_samp takes values as low as -104.296745, which means we're evaluating something like jnp.exp(104.296745). In float32 (JAX's default precision) that overflows to floating point's infinity. That inf appears in a Jacobian coefficient, and so produces the nan when we multiply it by a zero cotangent.
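To make the inf-times-zero mechanism concrete, here is a minimal sketch that differentiates the same expression at the extreme value quoted above. The function name naive_sigmoid is hypothetical, and the comparison against jax.nn.sigmoid is my own suggestion for a numerically safer form, not necessarily the fix used in the original code.

```python
import jax
import jax.numpy as jnp

def naive_sigmoid(x):
    # Same form as the expression in the question; jnp.exp(-x) overflows
    # to inf in float32 once -x is larger than roughly 88.7.
    return 1 / (1 + jnp.exp(-x))

# The extreme sample value reported above (hypothetical stand-in for gumbel_samp).
x = jnp.float32(-104.296745)

# Forward pass is fine (1 / inf == 0), but the backward pass multiplies
# a zero cotangent by the inf Jacobian coefficient of exp.
naive_grad = jax.grad(naive_sigmoid)(x)

# jax.nn.sigmoid computes the same function with a derivative rule that
# never forms the inf intermediate.
stable_grad = jax.grad(jax.nn.sigmoid)(x)

print("naive gradient: ", naive_grad)
print("stable gradient:", stable_grad)
```

The forward values of the two functions agree here; only the gradients differ, which is why the problem shows up under differentiation and not in the plain evaluation.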

(Tangentially: I found that by using the with jax.debug_nans() context manager (and also with jax.debug_infs()), which raised an exception as soon as an operation produced a nan (or inf), along with a post-mortem debugger. That showed me exactly where things were going wrong, but I just realiz…
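As a rough illustration of that debugging approach, the sketch below wraps a gradient call in the jax.debug_nans context manager and catches the resulting exception. The function naive_sigmoid and the input value are stand-ins taken from the expression discussed above, not the original code, and I pass True explicitly since the context manager takes the new flag value as its argument.

```python
import jax
import jax.numpy as jnp

def naive_sigmoid(x):
    # Hypothetical stand-in for the problematic expression.
    return 1 / (1 + jnp.exp(-x))

x = jnp.float32(-104.296745)

caught = False
message = ""
# While the flag is on, JAX checks the outputs of each operation and
# raises FloatingPointError as soon as one of them contains a nan.
with jax.debug_nans(True):
    try:
        jax.grad(naive_sigmoid)(x)
    except FloatingPointError as err:
        caught = True
        message = str(err)

print("nan detected:", caught)
print(message)
```

Running the failing call under a post-mortem debugger instead of catching the exception lets you inspect the stack at the point where the nan first appears.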

Answer selected by Peter-Vincent