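JAX performs symbolic automatic differentiation at the primitive level: each primitive, such as sin, has a known derivative rule, and jax.grad composes those rules as it traces the function. You can see this by printing the jaxpr of the transformed function: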

import jax
import jax.numpy as jnp

def a(x):
    return jnp.sin(x)

x = jnp.float32(1.0)

print(jax.make_jaxpr(a)(x))
# { lambda ; a:f32[]. let b:f32[] = sin a in (b,) }

print(jax.make_jaxpr(jax.grad(a))(x))
# { lambda ; a:f32[]. let
#     _:f32[] = sin a
#     b:f32[] = cos a
#     c:f32[] = mul 1.0 b
#   in (c,) }
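
As a quick sanity check on the jaxpr above, the sketch below compares the value of jax.grad(a) with jnp.cos, then does the same for a composite function. The function b and its hand-written derivative are hypothetical additions for illustration, not part of the original answer. The printed jaxpr is left without an expected output because its exact form can vary between JAX versions.

```python
import jax
import jax.numpy as jnp

def a(x):
    return jnp.sin(x)

def b(x):
    # Hypothetical composite function, added only for illustration.
    return jnp.sin(x) ** 2

x = jnp.float32(1.0)

# The gradient of sin should agree with cos, matching the `cos a`
# equation in the jaxpr of jax.grad(a).
grad_a = jax.grad(a)(x)
a_matches = bool(jnp.allclose(grad_a, jnp.cos(x)))
print(a_matches)

# For the composite, the chain rule gives 2 * sin(x) * cos(x).
grad_b = jax.grad(b)(x)
b_matches = bool(jnp.allclose(grad_b, 2 * jnp.sin(x) * jnp.cos(x)))
print(b_matches)

# Inspect how the primitive rules were composed for the composite function.
print(jax.make_jaxpr(jax.grad(b))(x))
```

Both comparisons use jnp.allclose rather than exact equality, since float32 results computed along different paths may differ in the last bits.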

Answer selected by jakubMitura14