
The main way to deal with this kind of thing is to define a custom JVP. For example, the jax.scipy.special.expit function is the logistic sigmoid, 1 / (1 + exp(-x)), and it avoids the NaN issue this way: https://github.com/google/jax/blob/4e219220558a2279de2abc9dc4140dedeb61703f/jax/_src/scipy/special.py#L98-L104
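A minimal sketch of the custom-JVP approach follows. It assumes the NaN comes from differentiating a hand-written sigmoid `1 / (1 + exp(-x))` at large negative inputs, where `exp(-x)` overflows to `inf` and the chain rule ends up computing `0 * inf`. The names `naive_sigmoid` and `stable_sigmoid` are hypothetical. The derivative rule `y * (1 - y)` is the standard identity for the logistic function and is similar in spirit to the linked `expit` code, but this is an illustration, not a copy of it.

```python
import jax
import jax.numpy as jnp


def naive_sigmoid(x):
    # Plain logistic function. Autodiff differentiates through exp(-x),
    # which overflows to inf for large negative x and yields 0 * inf = nan.
    return 1.0 / (1.0 + jnp.exp(-x))


@jax.custom_jvp
def stable_sigmoid(x):
    # Same forward computation as naive_sigmoid.
    return 1.0 / (1.0 + jnp.exp(-x))


@stable_sigmoid.defjvp
def stable_sigmoid_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    y = stable_sigmoid(x)
    # sigmoid'(x) = y * (1 - y). Written in terms of the output y, it stays
    # finite (it just goes to 0) even when exp(-x) overflows.
    return y, x_dot * y * (1.0 - y)


print(jax.grad(naive_sigmoid)(-1000.0))   # nan
print(jax.grad(stable_sigmoid)(-1000.0))  # 0.0
```

The forward pass is unchanged; only the derivative is replaced, so values are identical and only the gradients differ.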

Even easier would be to rewrite your function in terms of expit, so that you can take advantage of this custom JVP without having to re-implement it yourself.

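from jax.scipy.special import expit

def sigmoid2(x, x_offset):
    # Steepness of the transition around x_offset.
    sigmoid_slope = 32.
    # expit(z) = 1 / (1 + exp(-z)) with a numerically safe derivative,
    # so gradients stay finite even when |z| is large.
    return expit(sigmoid_slope * (x - x_offset))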
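A quick check of the `expit`-based version, assuming the goal is gradients with respect to `x` that stay finite far from `x_offset`. The test points and the use of `jax.vmap` are illustrative choices; at `x == x_offset` the derivative should equal the slope times the sigmoid's peak derivative, 32 * 0.25 = 8.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit


def sigmoid2(x, x_offset):
    sigmoid_slope = 32.
    return expit(sigmoid_slope * (x - x_offset))


# Gradient with respect to x (argnums=0 by default), vectorized over x
# while x_offset is held fixed (in_axes=None for the second argument).
dsig_dx = jax.vmap(jax.grad(sigmoid2), in_axes=(0, None))

xs = jnp.array([-100.0, 0.0, 100.0])
grads = dsig_dx(xs, 0.0)
print(grads)  # all finite; the middle entry is 32 * 0.25
```

At x = -100 the argument to `expit` is -3200, which is exactly the regime where a hand-written sigmoid would produce a NaN gradient.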

Answer selected by mjhoover1