Hello, I have a function that calculates a sigmoid, taking in values `x` and an `x_offset`. Note that this error only occurs when using `@jit`.
```python
import jax.numpy as jnp

def sigmoid2(x, x_offset):
    sigmoid_slope = 32.
    return 1. / (1. + jnp.exp(-sigmoid_slope * (x - x_offset)))
```

The input `x` is
So we usually have an `inf` showing up here.
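A minimal reproduction sketch of where the `inf` comes from, assuming default float32 precision and hypothetical inputs `x = 0.0`, `x_offset = 10.0`, so the exponent is `32 * 10 = 320` and `jnp.exp` overflows. The forward value survives (`1 / (1 + inf) = 0`), but differentiating the division scales by `1 / inf**2 = 0` and then multiplies by the overflowed `exp` term, and `0 * inf` is NaN. This sketch uses plain `jax.grad` and does not try to explain why the NaN was only seen under `@jit`.

```python
import jax
import jax.numpy as jnp

def sigmoid2(x, x_offset):
    sigmoid_slope = 32.
    return 1. / (1. + jnp.exp(-sigmoid_slope * (x - x_offset)))

# Hypothetical inputs: x sits far below x_offset, so the exponent is 320.
x = jnp.float32(0.0)
x_offset = jnp.float32(10.0)

print(jnp.exp(jnp.float32(320.0)))      # exceeds the float32 range, so it overflows to inf
print(sigmoid2(x, x_offset))            # forward pass is fine: 1 / (1 + inf) = 0
print(jax.grad(sigmoid2)(x, x_offset))  # the backward pass evaluates 0 * inf, giving nan
```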
The main way to deal with this kind of thing is to define a custom JVP. For example, the `jax.scipy.special.expit` function is essentially equivalent to a sigmoid, and deals with the NaN issue this way: https://github.com/google/jax/blob/4e219220558a2279de2abc9dc4140dedeb61703f/jax/_src/scipy/special.py#L98-L104

Even easier would be to rewrite your function in terms of `expit`, so that you can take advantage of this custom JVP without having to re-implement it yourself:

```python
from jax.scipy.special import expit

def sigmoid2(x, x_offset):
    sigmoid_slope = 32.
    return expit(sigmoid_slope * (x - x_offset))
```
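If you do want the custom-JVP route yourself, here is a minimal sketch using `jax.custom_jvp`. The names `stable_sigmoid` and `sigmoid2_custom` are hypothetical, and the rule is written from the identity `sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))`, not copied from the linked JAX source. Because the derivative is expressed through the (finite) output `ans`, it never multiplies by the overflowed `exp` term.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_sigmoid(z):
    # Naive forward pass: exp(-z) may overflow to inf, but 1 / (1 + inf) = 0 is still correct.
    return 1. / (1. + jnp.exp(-z))

@stable_sigmoid.defjvp
def stable_sigmoid_jvp(primals, tangents):
    (z,), (z_dot,) = primals, tangents
    ans = stable_sigmoid(z)
    # Derivative written in terms of the output, so no inf * 0 product can appear.
    return ans, z_dot * ans * (1. - ans)

def sigmoid2_custom(x, x_offset):
    sigmoid_slope = 32.
    return stable_sigmoid(sigmoid_slope * (x - x_offset))

far_below = jax.grad(sigmoid2_custom)(jnp.float32(0.0), jnp.float32(10.0))
at_offset = jax.grad(sigmoid2_custom)(jnp.float32(10.0), jnp.float32(10.0))
jitted = jax.jit(jax.grad(sigmoid2_custom))(jnp.float32(0.0), jnp.float32(10.0))
print(far_below, at_offset, jitted)
```

At `x == x_offset` the gradient is `32 * 0.5 * 0.5 = 8`, and far from the offset it is a finite `0` instead of NaN. Rewriting in terms of `expit` gets the same effect without maintaining this rule yourself.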