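I'll go ahead and answer this myself. The problem is right there in the name: it is a Jacobian-vector product, so the rule's tangent output has to be the Jacobian applied to the input tangent `x_dot`, not the Jacobian (gradient) on its own. 🤦‍♂️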
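Written out for the quadratic form in question, the gradient is $(A + A^\top)x$, and the JVP contracts it with the tangent $\dot{x}$ to give a scalar. That scalar is exactly the second value the rule needs to return:

```latex
f(x) = x^\top A x, \qquad
\nabla_x f(x) = (A + A^\top)\,x, \qquad
\operatorname{jvp}_f(x;\,\dot{x}) = \nabla_x f(x)^\top \dot{x} = \big((A + A^\top)\,x\big)^\top \dot{x}
```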

So the correct function is given by:

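```python
from functools import partial

import jax
import jax.numpy as jnp


# A is marked non-differentiable, so the JVP rule only supplies tangents for x.
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def f(x, A):
    # Quadratic form x^T A x
    return jnp.dot(x, jnp.dot(A, x))


@f.defjvp
def f_jvp(A, primals, tangents):
    # Non-differentiable arguments come first in the JVP rule's signature.
    x, = primals
    x_dot, = tangents
    primal_out = jnp.dot(x, jnp.dot(A, x))
    # The gradient of x^T A x is (A + A^T) x; the JVP contracts it with x_dot.
    tangent_out = jnp.dot(jnp.dot(A + A.T, x), x_dot)
    return primal_out, tangent_out
```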
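A quick way to check a rule like this is to compare it against JAX's own autodiff of the plain, undecorated expression. The sketch below does that in forward mode (`jax.jvp`) and reverse mode (`jax.grad`). The matrix `A`, point `x`, tangent `v` and the helper `f_ref` are arbitrary illustrative choices, not from the original post. `A` is deliberately non-symmetric so that a rule using `2 * A` instead of `A + A.T` would be caught.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def f(x, A):
    return jnp.dot(x, jnp.dot(A, x))


@f.defjvp
def f_jvp(A, primals, tangents):
    x, = primals
    x_dot, = tangents
    primal_out = jnp.dot(x, jnp.dot(A, x))
    tangent_out = jnp.dot(jnp.dot(A + A.T, x), x_dot)
    return primal_out, tangent_out


# Illustrative inputs; A is non-symmetric on purpose.
A = jnp.array([[1.0, 2.0], [0.0, 3.0]])
x = jnp.array([1.0, -1.0])
v = jnp.array([0.5, 2.0])


# Reference: the same expression without a custom rule (hypothetical helper).
def f_ref(x):
    return jnp.dot(x, jnp.dot(A, x))


# Forward mode: jax.jvp invokes the custom rule and returns (f(x), J(x) v).
y, y_dot = jax.jvp(lambda x: f(x, A), (x,), (v,))
y_ref, y_dot_ref = jax.jvp(f_ref, (x,), (v,))

# Reverse mode also works, because the tangent output is linear in x_dot,
# so JAX can transpose the rule. grad differentiates w.r.t. x (argument 0).
g = jax.grad(f)(x, A)
g_ref = jax.grad(f_ref)(x)

print(y, y_dot)  # 2.0 -8.0
print(g)         # [ 0. -4.]
```

If the rule returned the gradient `(A + A.T) @ x` as the tangent instead of contracting it with `x_dot`, the tangent output would have the shape of `x` rather than the scalar shape of `f(x)`, which is the mismatch the JVP name points at.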

Answer selected by mathDR
Category
Q&A