
You can verify that the compiler deduplicates the repeated computation by inspecting the optimized HLO of the compiled function:

import jax
jax.jit(lambda x: x * x).lower(2.0).compile().compiler_ir()[0].to_string()
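A minimal sketch of that check, assuming the concern is a forward pass computed twice when jax.jvp and jax.vjp are both applied to the same function inside one jit. The names f and jvp_and_vjp_naive are hypothetical. Lowered.as_text() gives the traced program before XLA optimization, and Compiled.as_text() gives the optimized HLO as a string, so the two can be compared by eye.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example: sin(x) is needed by both the JVP and the VJP
    return jnp.sin(x) * x

def jvp_and_vjp_naive(x, t):
    # jax.jvp traces f once for the primal output and the tangent
    y, y_dot = jax.jvp(f, (x,), (t,))
    # jax.vjp traces f again, so sin(x) is emitted a second time
    _, vjp_fn = jax.vjp(f, x)
    (x_bar,) = vjp_fn(y_dot)
    return y, y_dot, x_bar

x = jnp.arange(3.0)
t = jnp.ones(3)

lowered = jax.jit(jvp_and_vjp_naive).lower(x, t)
before = lowered.as_text()           # traced program, duplicates still present
after = lowered.compile().as_text()  # HLO after XLA's optimization passes
print(before)
print(after)
```

In the first listing you should expect to see the sine operation twice. In the second, check whether XLA has merged the duplicate.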

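Alternatively, you can run the forward pass only once with jax.linearize and obtain the VJP by transposing the linear JVP function: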

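import jax

def jvp_and_vjp(fn, a, T_a):
  b, jvp_f = jax.linearize(fn, a)  # one forward pass; jvp_f maps T_a -> T_b
  T_b = jvp_f(T_a)
  # only T_a's shape and dtype are used (it is an example *input* of jvp_f);
  # vjp_f(T_b_bar) returns a 1-tuple (T_a_bar,)
  vjp_f = jax.linear_transpose(jvp_f, T_a)
  return b, T_b, vjp_f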
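A usage sketch on a hypothetical function f from R^3 to R^2, where the input and output shapes differ, cross-checked against jax.vjp. The helper is repeated so the snippet runs on its own. It passes T_a to jax.linear_transpose because that argument only has to match the shape and dtype of an input to jvp_f.

```python
import jax
import jax.numpy as jnp

def jvp_and_vjp(fn, a, T_a):
    b, jvp_f = jax.linearize(fn, a)           # single forward pass
    T_b = jvp_f(T_a)                          # JVP: J @ T_a
    vjp_f = jax.linear_transpose(jvp_f, T_a)  # VJP: ct -> (J.T @ ct,)
    return b, T_b, vjp_f

def f(x):
    # Hypothetical example: R^3 -> R^2
    return jnp.stack([jnp.sum(x ** 2), jnp.prod(x)])

a = jnp.array([1.0, 2.0, 3.0])
T_a = jnp.array([1.0, 0.0, 0.0])
ct = jnp.array([1.0, 1.0])

b, T_b, vjp_f = jvp_and_vjp(f, a, T_a)
(a_bar,) = vjp_f(ct)

# Reference result from jax.vjp
_, ref_vjp = jax.vjp(f, a)
(ref_bar,) = ref_vjp(ct)
print(b, T_b, a_bar, ref_bar)
```

Here the Jacobian has rows 2x = [2, 4, 6] and prod(x)/x_i = [6, 3, 2]. So T_b is [2, 6] and a_bar is [8, 7, 8]. Because linear_transpose reuses the linearized program, fn is never evaluated a second time, and the compiler has no duplicate to remove.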

Answer selected by qsh-zh
Category: Q&A