Jacobian-Vector Product of JAX AutoDiff Principle #10227
Hello JAX Community; While trying to use and understand As far as I understand,
What I understand from these conditions is that there is a mutual relationship between primals and tangents: each tangent vector must lie in the tangent space of the domain of its corresponding primal. In other words, if the function takes two arguments (primals) in, for example, R^5 and R, we need to pass two tangent vectors, each a member of the tangent space of its corresponding primal's domain, that is T(R^5) and T(R). In this case, what I expected from I am adding example code about this:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
W_key, b_key, input_key, v1_key, v2_key = random.split(key, 5)
b = random.normal(b_key, (1,))
W = random.normal(W_key, (5,))
x = random.normal(input_key, (1, 5))
# One tangent vector per primal: tan_vec1 in T(R^5), tan_vec2 in T(R)
tan_vec1 = random.normal(v1_key, W.shape)
tan_vec2 = random.normal(v2_key, b.shape)

single_neuron = lambda W, b: x @ W + b

# jax.jvp returns (f(primals), J @ tangents); the second output is a
# tangent in the output space, not a cotangent
y, tangent_out = jax.jvp(single_neuron, [W, b], [tan_vec1, tan_vec2])

# Jacobians w.r.t. each argument, with shapes (1, 5) and (1, 1)
jacobian_fun = jax.jacrev(single_neuron, argnums=(0, 1))
jacobian1, jacobian2 = jacobian_fun(W, b)
jvp1 = jacobian1 @ tan_vec1
jvp2 = jacobian2 @ tan_vec2
print("JVP w.r.t. W, in R: ", jvp1)
print("JVP w.r.t. b, in R: ", jvp2)
print("Sum of the two JVPs: ", jvp1 + jvp2)
print("What jax.jvp() returns: ", tangent_out)
```
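Because a JVP is linear in the tangents, the single value returned by `jax.jvp` can be split back into per-argument contributions by calling it once per argument with the other argument's tangent set to zero. Below is a minimal sketch, assuming the same shapes as the example above; the names `tan_W`, `tan_b`, `only_W` and `only_b` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import random

# Same setup as the example above (shapes assumed from it)
key = random.PRNGKey(0)
W_key, b_key, input_key, v1_key, v2_key = random.split(key, 5)
b = random.normal(b_key, (1,))
W = random.normal(W_key, (5,))
x = random.normal(input_key, (1, 5))
tan_W = random.normal(v1_key, W.shape)   # tangent in T(R^5)
tan_b = random.normal(v2_key, b.shape)   # tangent in T(R)

def single_neuron(W, b):
    return x @ W + b  # output shape (1,)

# Full JVP: both tangents at once give one output tangent in T(R)
_, full = jax.jvp(single_neuron, (W, b), (tan_W, tan_b))

# Per-argument contributions: zero out the other argument's tangent
_, only_W = jax.jvp(single_neuron, (W, b), (tan_W, jnp.zeros_like(tan_b)))
_, only_b = jax.jvp(single_neuron, (W, b), (jnp.zeros_like(tan_W), tan_b))

print("W contribution:", only_W)
print("b contribution:", only_b)
print("sum           :", only_W + only_b)
print("jax.jvp       :", full)
```

By linearity, `only_W + only_b` matches `full`, which is the same sum the Jacobian-based computation above produces.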
Replies: 1 comment 1 reply
You can think of it as JAX treating `primals` and `tangents` as a single flattened vector, i.e. in your R^5 and R example, JAX treats the input as R^6.
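To make this concrete, here is a minimal sketch, assuming the shapes from the question, that flattens `(W, b)` into one vector in R^6 with `jax.flatten_util.ravel_pytree`. Viewed this way, the Jacobian is a single 1×6 matrix J = [J_W  J_b], and J · [v1; v2] = J_W v1 + J_b v2, which is why `jax.jvp` returns one summed tangent rather than one per argument. The helper `flat_neuron` is hypothetical, introduced only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree

key = random.PRNGKey(0)
W_key, b_key, input_key, v1_key, v2_key = random.split(key, 5)
b = random.normal(b_key, (1,))
W = random.normal(W_key, (5,))
x = random.normal(input_key, (1, 5))
tan_W = random.normal(v1_key, W.shape)
tan_b = random.normal(v2_key, b.shape)

def single_neuron(W, b):
    return x @ W + b

# Flatten the primals (W, b) into one vector in R^6; `unravel` maps back
theta, unravel = ravel_pytree((W, b))
# Flatten the tangents in the same order: one tangent vector in T(R^6)
theta_dot, _ = ravel_pytree((tan_W, tan_b))

# Hypothetical helper: the same function viewed as R^6 -> R^1
def flat_neuron(theta):
    W_, b_ = unravel(theta)
    return single_neuron(W_, b_)

J = jax.jacfwd(flat_neuron)(theta)   # one (1, 6) Jacobian: [J_W | J_b]
jvp_via_matrix = J @ theta_dot       # shape (1,)

_, jvp_out = jax.jvp(single_neuron, (W, b), (tan_W, tan_b))
print("J shape:", J.shape)
print("J @ flattened tangent:", jvp_via_matrix)
print("jax.jvp:", jvp_out)
```

The two printed results agree: a JVP always produces one tangent in the output space, computed against the whole (flattened) input.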