The forward derivative of tanh #19052
Using jaxlib 0.4.23+cuda12.cudnn89 I tried computing the derivative of `jnp.tanh` like this:

```python
import jax
import jax.numpy as jnp

x = 2.34567
f = jnp.tanh

f_x, df_dx = jax.value_and_grad(f)(x)
print(df_dx, 1 - f_x**2)
print("Difference: ", df_dx - (1 - f_x**2))
```

The output shows a small but nonzero difference between `df_dx` and `1 - f_x**2`. Why are the exact derivatives not matching?
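To put the size of that mismatch in perspective, you can compare the two values against single-precision machine epsilon. This is a rough sketch using standard `jax.numpy` utilities (not part of the snippet above), and the exact numbers may vary by platform:

```python
import jax
import jax.numpy as jnp

x = 2.34567
f_x, df_dx = jax.value_and_grad(jnp.tanh)(x)
manual = 1 - f_x**2

# Relative difference between the two derivative values, compared to float32 epsilon.
rel_diff = jnp.abs(df_dx - manual) / jnp.abs(manual)
print(rel_diff)                    # expected to be tiny (on the order of 1e-7 or less)
print(jnp.finfo(jnp.float32).eps)  # ~1.19e-07
```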
Replies: 1 comment
If you're ever curious about the exact sequence of operations that are used to compute an automatic gradient (or any other operation), you can see them using `make_jaxpr`:

```python
import jax
import jax.numpy as jnp

x = 2.34567
f = jnp.tanh

def df1(x):
    return jax.grad(f)(x)

def df2(x):
    return 1 - f(x)**2

print(jax.make_jaxpr(df1)(x))
# { lambda ; a:f32[]. let
#     b:f32[] = tanh a
#     c:f32[] = sub 1.0 b
#     d:f32[] = mul 1.0 c
#     e:f32[] = mul d b
#     f:f32[] = add_any d e
#   in (f,) }

print(jax.make_jaxpr(df2)(x))
# { lambda ; a:f32[]. let
#     b:f32[] = tanh a
#     c:f32[] = integer_pow[y=2] b
#     d:f32[] = sub 1.0 c
#   in (d,) }
```

These are two different ways of computing what would be the same value in real-valued arithmetic, but in floating point the errors accumulate differently. The results you're seeing differ only at the level of float32 rounding error.

If you want to see how this autodiff rule is defined in the code, you can find it here: https://github.com/google/jax/blob/c172be137911ee77b8f1327b98d2b0c0f8b459ea/jax/_src/lax/lax.py#L1800-L1801
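For a concrete feel for how the two operation orders can round differently, here is a minimal NumPy sketch of the same two float32 sequences. It won't be bit-identical to what XLA computes (NumPy's `tanh` may round differently), but it illustrates the effect:

```python
import numpy as np

x = np.float32(2.34567)
t = np.tanh(x)                  # float32 tanh, as in both jaxprs above

# Operation order from jax.grad(f): (1 - t) + (1 - t) * t  ==  (1 - t) * (1 + t)
d = np.float32(1) - t
via_grad = d + d * t

# Operation order of the hand-written rule: 1 - t**2
via_formula = np.float32(1) - t * t

print(via_grad, via_formula)
print("Difference:", via_grad - via_formula)          # zero or tiny: float32 rounding
print("float64 reference:", 1 - np.tanh(2.34567)**2)  # higher-precision comparison
```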
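If you want to experiment with that factored form of the rule at the user level, a sketch using `jax.custom_jvp` (my own illustration here, not JAX's internal primitive rule, which lives at the link above) could look like this:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def my_tanh(x):
    return jnp.tanh(x)

@my_tanh.defjvp
def my_tanh_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = my_tanh(x)
    # Factored form (1 - ans) * (1 + ans): algebraically equal to 1 - ans**2,
    # but a different floating-point operation sequence.
    return ans, x_dot * (1 - ans) * (1 + ans)

print(jax.grad(my_tanh)(2.34567))   # uses the factored rule above
print(jax.grad(jnp.tanh)(2.34567))  # uses JAX's built-in rule
```

Swapping the tangent expression for `x_dot * (1 - ans**2)` would reproduce the hand-written formula instead, and the two variants may again differ in the last few bits.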