Replies: 2 comments
Your solution seems like a reasonable approach.
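Thanks for your feedback. As I said, the situation is quite complicated in my work. Working out the limit values analytically is a lot of effort, though I do know where the problematic points are. A wrapper that linearly interpolates across each problematic point avoids needing them:

    import jax.numpy as jnp

    def linearize_wrapper(fn, a, margin=1e-3):
        # Evaluate fn just outside the problematic point a, on both sides.
        x0, x1 = a - margin, a + margin
        y0, y1 = fn(x0), fn(x1)
        def wrapper_fn(x):
            # True when x is within `margin` of the problematic point.
            cond = jnp.abs(a - x) < margin
            # Feed fn a safe input there so the unused branch stays finite.
            x_ = jnp.where(cond, x0, x)
            return jnp.where(
                cond,
                # Straight line through (x0, y0) and (x1, y1).
                y0 + (x - x0) * (y1 - y0) / (x1 - x0),
                fn(x_)
            )
        return wrapper_fn

Example:

    def f(alpha):
        b_minus_a = 5.
        beta = 1.
        return (jnp.exp(- alpha * b_minus_a) - jnp.exp(- beta * b_minus_a)) / (alpha - beta)

    f_ = linearize_wrapper(f, 1.)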
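As a quick sanity check, a sketch like the following compares `jax.grad` on the raw and the wrapped function at the problematic point. The wrapper and `f` are repeated so the snippet runs on its own; writing `b_minus_a` and `beta` as keyword defaults and the names `raw_grad` and `wrapped_grad` are choices made for this illustration only.

```python
import jax
import jax.numpy as jnp

def linearize_wrapper(fn, a, margin=1e-3):
    # Same wrapper as above, repeated so this snippet is self-contained.
    x0, x1 = a - margin, a + margin
    y0, y1 = fn(x0), fn(x1)
    def wrapper_fn(x):
        cond = jnp.abs(a - x) < margin
        x_ = jnp.where(cond, x0, x)
        return jnp.where(cond, y0 + (x - x0) * (y1 - y0) / (x1 - x0), fn(x_))
    return wrapper_fn

def f(alpha, b_minus_a=5., beta=1.):
    return (jnp.exp(-alpha * b_minus_a) - jnp.exp(-beta * b_minus_a)) / (alpha - beta)

f_ = linearize_wrapper(f, 1.)

raw_grad = jax.grad(f)(1.0)       # 0/0 at alpha == beta, so this is NaN
wrapped_grad = jax.grad(f_)(1.0)  # slope of the interpolating line, finite
print(raw_grad, wrapped_grad)
```

Inside the margin the wrapped gradient is the constant slope of the line, so it only approximates the true derivative (analytically $\tfrac{1}{2}(b-a)^2 e^{-\beta (b-a)}$ at $\alpha = \beta$), and in single precision the result can be somewhat noisy for small `margin`.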
Hi all,
What would be the best way to implement this function in JAX:
For $\alpha \neq \beta$:
$f(a, b, \alpha, \beta) = (e^{-\alpha (b-a)}-e^{-\beta (b-a)})/(\alpha-\beta)$
For $\alpha = \beta$:
$f(a, b, \alpha, \beta) = -e^{-\alpha(b-a)}(b-a)$
I want the function to be differentiable with respect to all arguments.
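For reference, the $\alpha = \beta$ case is the limit of the general expression. Writing $d = b - a$ (a shorthand introduced here), the general expression is a difference quotient of $e^{-\alpha d}$ in $\alpha$:

$$
\lim_{\alpha \to \beta} \frac{e^{-\alpha d} - e^{-\beta d}}{\alpha - \beta}
= \left.\frac{\partial}{\partial \alpha} e^{-\alpha d}\right|_{\alpha = \beta}
= -d\, e^{-\beta d}
$$

Expanding $e^{-\alpha d}$ to second order around $\beta$ in the same way gives $\partial f / \partial \alpha \,|_{\alpha = \beta} = \tfrac{1}{2} d^2 e^{-\beta d}$, which is a handy value to check an implementation's gradient against.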
My current solution is using the double where trick:
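A minimal sketch of the double-where pattern for this function might look like the following. It is an illustration under assumptions, not necessarily the original snippet: the exact-equality test, the placeholder denominator `1.0`, and the local names `same`, `safe_diff`, `general` and `limit` are all choices made here.

```python
import jax
import jax.numpy as jnp

def f(a, b, alpha, beta):
    d = b - a
    same = alpha == beta
    # Inner where: swap in a harmless denominator so the general branch
    # never evaluates 0/0, even when it is not the branch selected.
    safe_diff = jnp.where(same, 1.0, alpha - beta)
    general = (jnp.exp(-alpha * d) - jnp.exp(-beta * d)) / safe_diff
    # Analytic limit for alpha == beta.
    limit = -jnp.exp(-alpha * d) * d
    # Outer where: pick the analytic limit at alpha == beta.
    return jnp.where(same, limit, general)

# Gradients with respect to all four arguments at alpha == beta.
grads = jax.grad(f, argnums=(0, 1, 2, 3))(0.0, 5.0, 1.0, 1.0)
print(grads)
```

With this sketch `jax.grad` returns finite values at $\alpha = \beta$. At that exact point the gradient is taken through the limit branch only, so the individual partials with respect to $\alpha$ and $\beta$ may not match the true ones.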
This seems to work: the gradient with respect to $\alpha$ is not `NaN` when $\alpha=\beta$.

But I would love other viewpoints on this. Are there better ways to implement this?

Is there a general solution for defining functions with limits? I ask because this `f` structure is actually encapsulated in a much more complicated function in my code.

I noticed the implementation of `sinc` in NumPy doesn't need an analytical expression for the limit. Wouldn't there be a similar way to automatically handle this kind of limit value in JAX?