Is there a way to get JIT to treat constant inner function outputs as compile time constants? #9778
-
One of the common patterns I use in JAX is the following:

    f = lambda x, xi: np.dot(H(x), xi)
    df = grad(f)
    L = jit(lambda xi: df(x0, xi))

Essentially, is it possible to accomplish this with the tools JAX offers?
Replies: 4 comments 17 replies
-
Maybe you should check whether

    lowered = L.lower(example_xi)
    print(lowered.compile().compiler_ir()[0].to_string())

If not, you can define a custom VJP for
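The check suggested above can be sketched end to end as follows. This is only an illustration under assumptions: `H` is a hypothetical stand-in for the question's function, `x0` and `example_xi` are made-up inputs, and the IR is printed with the `as_text()` methods of the lowered and compiled objects rather than the `compiler_ir()` call quoted above.

```python
import jax
import jax.numpy as jnp


def H(x):
    # Hypothetical stand-in for the H in the question (assumption).
    return jnp.sin(x) * x


f = lambda x, xi: jnp.dot(H(x), xi)
df = jax.grad(f)  # gradient w.r.t. the first argument, x

x0 = jnp.arange(3.0)  # fixed point, closed over by the jitted function
L = jax.jit(lambda xi: df(x0, xi))

example_xi = jnp.ones(3)
lowered = L.lower(example_xi)

# IR handed to the compiler, before XLA optimizations.
print(lowered.as_text())

# Optimized HLO after compilation; inspect it to see whether the
# x0-dependent work still appears as ops or has become constants.
hlo_text = lowered.compile().as_text()
print(hlo_text)
```

Comparing the two printouts shows what the compiler did with the parts that depend only on `x0`; whether they are folded depends on the function and the backend.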
-
https://github.com/google/jax/blob/main/jax/interpreters/partial_eval.py
-
If you don't need the gradient w.r.t.
https://github.com/google/jax/blob/main/jax/interpreters/partial_eval.py
I think this tool can help you achieve your goal gracefully. `jax` uses it to implement `linearize`. Following the same approach, you can declare `x` as a known value and then use `pe.trace_to_jaxpr`. WDYT?
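A minimal sketch of the same idea using only the public API, since `partial_eval` is an internal module: because `f` is linear in `xi`, `grad(f)(x0, xi)` is the VJP of `H` at `x0` applied to `xi`, so `jax.vjp` can run the `x0`-dependent forward pass once, outside `jit`. Note that this swaps in `jax.vjp` for the `pe.trace_to_jaxpr` call suggested above, and that `H` and `x0` here are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp


def H(x):
    # Hypothetical stand-in for the H in the question (assumption).
    return jnp.sin(x) * x


x0 = jnp.arange(3.0)  # the fixed ("known") input

# Evaluate the x0-dependent part once, eagerly. H_vjp holds the
# residuals of that forward pass as concrete arrays.
_, H_vjp = jax.vjp(H, x0)

# The jitted function closes over H_vjp, so its residuals enter the
# trace as constants and only the xi-dependent part is traced.
L = jax.jit(lambda xi: H_vjp(xi)[0])

xi = jnp.ones(3)
result = L(xi)

# Reference: the original pattern from the question.
reference = jax.grad(lambda x, xi: jnp.dot(H(x), xi))(x0, xi)
```

Here `result` and `reference` should agree up to floating-point tolerance. The trade-off is that `x0` is baked in: a new `x0` means calling `jax.vjp` again and re-jitting.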