The JIT does common subexpression elimination (CSE). Here's an example:

```python
import jax
import jax.numpy as jnp

def foo(x):
    def f(x):
        return jnp.sin(x)

    def g(x):
        return f(x) + x

    def h(x):
        return f(x) + 1/x

    return jax.grad(g)(x), jax.grad(h)(x)

jitted = jax.jit(foo)
compiled = jitted.lower(1.).compile()
print(compiled.as_text())
```

This outputs:
As you can see, there's only one call to …
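If you want to check this yourself without reading the full HLO dump, one rough approach is to count cosine ops in the jaxpr that JAX traces (before any XLA passes) and in the optimized HLO that XLA actually runs. This is a minimal sketch: `count_cos_ops` is a hypothetical helper, and matching op names in printed text is a heuristic, so exact counts may differ across JAX/XLA versions and backends (fusion, for instance, can duplicate cheap ops).

```python
import re

import jax
import jax.numpy as jnp


def foo(x):
    # Same structure as the example above: f is differentiated twice,
    # once through g and once through h.
    def f(x):
        return jnp.sin(x)

    def g(x):
        return f(x) + x

    def h(x):
        return f(x) + 1 / x

    return jax.grad(g)(x), jax.grad(h)(x)


def count_cos_ops(fn, x):
    """Rough, text-based count of cosine ops before and after XLA optimization.

    Hypothetical helper: it pattern-matches printed IR, so treat the numbers
    as a hint rather than a guarantee.
    """
    # Jaxpr: the program JAX traces and hands to XLA (no XLA passes yet).
    # The regex tolerates optional params, e.g. "cos a" or "cos[accuracy=None] a".
    jaxpr_text = str(jax.make_jaxpr(fn)(x))
    n_jaxpr = len(re.findall(r"= cos\b", jaxpr_text))
    # Optimized HLO: what actually runs after XLA's passes, including CSE.
    hlo_text = jax.jit(fn).lower(x).compile().as_text()
    n_hlo = hlo_text.count(" cosine(")
    return n_jaxpr, n_hlo


if __name__ == "__main__":
    n_jaxpr, n_hlo = count_cos_ops(foo, 1.0)
    print(f"cos in jaxpr: {n_jaxpr}, cosine in compiled HLO: {n_hlo}")
```

Each `jax.grad` traces `f` separately, so the jaxpr is expected to contain one `cos` per gradient; comparing that with the compiled count shows whether XLA merged them.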
Hi all,
I was wondering how JAX handles the chain rule when I want to perform certain operations on the gradients. Say I have the following code:
Can I be certain that JAX won't compute the gradient of `f` twice?
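If you'd rather not rely on XLA's CSE pass, one option (a sketch, not the only way) is to make the sharing explicit at trace level: evaluate `f` once inside a single function that returns both `g` and `h`, then differentiate it with `jax.vjp`. The forward pass, including the `cos(x)` residual that `sin`'s derivative needs, is then traced once and reused for each cotangent. The names `g_and_h` and `grads_shared` are hypothetical; `f`, `g` and `h` mirror the example in the reply above.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x)


def g_and_h(x):
    # Hypothetical helper: evaluate f once and reuse the result for both
    # outputs, so sin(x) is traced a single time.
    fx = f(x)
    return fx + x, fx + 1 / x


def grads_shared(x):
    """Return (dg/dx, dh/dx) from one shared forward pass."""
    # jax.vjp runs the forward pass once and keeps the residuals it needs
    # (here cos(x)); the returned pullback reuses them on every call.
    (gx, hx), pullback = jax.vjp(g_and_h, x)
    # Pulling back a one-hot cotangent selects the gradient of one output.
    (dg,) = pullback((jnp.ones_like(gx), jnp.zeros_like(hx)))
    (dh,) = pullback((jnp.zeros_like(gx), jnp.ones_like(hx)))
    return dg, dh


if __name__ == "__main__":
    print(jax.jit(grads_shared)(1.0))
```

`jax.jacrev(g_and_h)(x)` should return the same pair in one call; the explicit `vjp` form just makes the single forward pass visible.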