Can I get JVP and VJP in one pass? #11009
-
Assume we have a function `fn` and want both its JVP and its VJP at the same point:

    def naive_jvp_and_vjp(fn, a, T_a):
        b, T_b = jax.jvp(fn, (a,), (T_a,))
        b, f_vjp = jax.vjp(fn, a)
        return b, T_b, f_vjp

However, such an implementation needs to calculate
Replies: 2 comments 1 reply
-
I think they'd be deduplicated after JIT compilation.
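One way to look into this is to jit a function that calls both `jax.jvp` and `jax.vjp` and read the optimized HLO. The snippet below is only a sketch: `fn` is a made-up example function, `naive_jvp_and_vjp_applied` is a hypothetical helper name, and the VJP is applied inside the jitted function so that it returns only arrays. The printed text varies across JAX versions and backends, so it is meant for manual inspection rather than as proof of deduplication.

```python
import jax
import jax.numpy as jnp


def fn(x):
    # Hypothetical example function, not from the original question.
    return jnp.sin(x) * x


def naive_jvp_and_vjp_applied(a, T_a, ct):
    # Same structure as the naive version: fn is traced once by jax.jvp
    # and once more by jax.vjp.
    b, T_b = jax.jvp(fn, (a,), (T_a,))
    _, f_vjp = jax.vjp(fn, a)
    (ct_a,) = f_vjp(ct)  # apply the VJP here so only arrays are returned
    return b, T_b, ct_a


a = jnp.arange(3.0)
T_a = jnp.ones(3)
ct = jnp.ones(3)

# Ahead-of-time lower and compile, then dump the optimized HLO as text.
compiled = jax.jit(naive_jvp_and_vjp_applied).lower(a, T_a, ct).compile()
hlo_text = compiled.as_text()
print(hlo_text)

# Run the compiled executable to confirm it still computes the right values.
b, T_b, ct_a = compiled(a, T_a, ct)
```

If the duplicated forward computation has been merged, an operation such as the sine of the input should show up once in the dump instead of once per transform.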
-
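You can verify that it is deduped by inspecting the compiled IR, for example:

    jax.jit(lambda x: x * x).lower(2.0).compile().compiler_ir()[0].to_string()

Alternatively, you can linearize once and transpose the resulting linear function:

    def jvp_and_vjp(fn, a, T_a):
        b, jvp_f = jax.linearize(fn, a)
        T_b = jvp_f(T_a)
        # linear_transpose takes an example of jvp_f's *input*; only T_a's shape and dtype are used
        vjp_f = jax.linear_transpose(jvp_f, T_a)
        return b, T_b, vjp_f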
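As a sanity check, here is a small sketch comparing the linearize-and-transpose approach against separate `jax.jvp` and `jax.vjp` calls. The function `fn` and the test values are made up for illustration. The input and output shapes differ on purpose (3 vs. 2): `jax.linear_transpose` expects an example of the linear function's input, so the tangent `T_a` is passed, and only its shape and dtype are used.

```python
import jax
import jax.numpy as jnp


def jvp_and_vjp(fn, a, T_a):
    # Linearize once: evaluates fn(a) and returns the linear map T_a -> T_b.
    b, jvp_f = jax.linearize(fn, a)
    T_b = jvp_f(T_a)
    # Transpose the linear map to get the cotangent map ct_b -> (ct_a,).
    vjp_f = jax.linear_transpose(jvp_f, T_a)
    return b, T_b, vjp_f


def fn(x):
    # Hypothetical example function mapping R^3 -> R^2.
    return jnp.stack([jnp.sin(x[0]) * x[1], x[2] ** 2])


a = jnp.array([0.5, 2.0, -1.0])
T_a = jnp.array([1.0, 0.0, 0.5])
ct = jnp.array([1.0, -2.0])  # cotangent with the output's shape

b, T_b, vjp_f = jvp_and_vjp(fn, a, T_a)
(ct_a,) = vjp_f(ct)  # one cotangent per input, returned as a tuple

# Reference values from the separate transforms.
b_ref, T_b_ref = jax.jvp(fn, (a,), (T_a,))
_, f_vjp_ref = jax.vjp(fn, a)
(ct_a_ref,) = f_vjp_ref(ct)

print(jnp.allclose(b, b_ref), jnp.allclose(T_b, T_b_ref), jnp.allclose(ct_a, ct_a_ref))
```

Like the function returned by `jax.vjp`, the transposed function returns a tuple with one cotangent per input.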