Hi all,

I'm working on code which involves taking derivatives w.r.t. some parameter. For reasons that would be off topic here, I cannot rely on autodifferentiation alone. Writing However, I can use autodifferentiation for
Since the gradients of

Thanks in advance!
Replies: 2 comments 1 reply
Without more details it's hard to say how the compiler will optimize any particular sequence of operations, but in general it should deduplicate identical computations. The most definitive way to check is to print the compiled HLO and look. Here's a simple example:

import jax
import jax.numpy as jnp

def f(x):
    # The same sine is requested twice on purpose.
    y = jnp.sin(x)
    z = jnp.sin(x)
    return y + z

# Print the optimized HLO that XLA produces for a scalar input.
print(jax.jit(f).lower(1.0).compile().as_text())
You can see that in this case, the sine is only computed once.
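If reading the full HLO dump by eye gets tedious, a rough sketch like the following can count how often an opcode shows up in the compiled text. Note that `count_opcode` is a hypothetical helper written for this reply, not part of JAX, and it assumes the HLO text prints each instruction as `opcode(`. That is a formatting detail that may differ between backends and JAX versions, so treat the number as a hint and check the printed text when in doubt.

```python
import re

import jax
import jax.numpy as jnp


def f(x):
    # Same toy function as above: the sine is requested twice.
    y = jnp.sin(x)
    z = jnp.sin(x)
    return y + z


def count_opcode(hlo_text, opcode):
    """Hypothetical helper: count textual occurrences of `opcode(` in HLO text.

    Assumes instructions are printed as `... = <type> opcode(...)`. Names such
    as `%sine.1` are not counted because they are not followed by `(`.
    """
    return len(re.findall(r"\b" + re.escape(opcode) + r"\(", hlo_text))


# Optimized HLO for a scalar float input.
compiled_text = jax.jit(f).lower(1.0).compile().as_text()

print("sine instructions in compiled HLO:", count_opcode(compiled_text, "sine"))
```

If the compiler merged the two calls, the printed count should be smaller than the number of `jnp.sin` calls in the source. The same check can be pointed at your own jitted function and whichever opcode you are worried about recomputing.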