If you're ever curious about the exact sequence of operations used to compute an automatic gradient (or any other computation), you can see them using jax.make_jaxpr:

import jax
import jax.numpy as jnp

x = 2.34567
f = jnp.tanh

def df1(x):
  # derivative of tanh via automatic differentiation (jax.grad)
  return jax.grad(f)(x)

def df2(x):
  # closed-form derivative: d/dx tanh(x) = 1 - tanh(x)**2
  return 1 - f(x)**2

print(jax.make_jaxpr(df1)(x))
# { lambda ; a:f32[]. let
#     b:f32[] = tanh a
#     c:f32[] = sub 1.0 b
#     d:f32[] = mul 1.0 c
#     e:f32[] = mul d b
#     f:f32[] = add_any d e
#   in (f,) }

print(jax.make_jaxpr(df2)(x))
# { lambda ; a:f32[]. let
#     b:f32[] = tanh a
#     c:f32[] = integer_pow[y=2] b
#     d:f32[] = sub 1.0 c
#   in (d,) }

These are two different ways of computing …
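
Walking through the jaxpr for df1: with b = tanh(a), it evaluates (1 - b) + (1 - b) * b = 1 - b**2, which is algebraically the same quantity df2 computes directly. As a quick sanity check (an addition to the snippet above, reusing its x, df1, and df2), you can confirm the two routes agree numerically:

# Minimal sketch: verify the autodiff route and the analytic formula
# give the same value, up to float32 rounding.
print(jnp.allclose(df1(x), df2(x)))
# True

Note that a jaxpr shows the program before compilation, so XLA may still simplify either version further when the function is run under jax.jit.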
