Forward over forward mode HVPs #7669
Replies: 1 comment 1 reply
-
Thanks for the question! Actually, there's a bug in the code. Check the dimension of these results:

print(jnp.ndim(correct_answer))             # 2
print(jnp.ndim(hvp_fwdfwd(f, (X,), (V,))))  # 0

A hint is that in the implementation above you're using:

def hvp_fwdfwd(f, primals, tangents):
    g = lambda primals: jvp(f, (primals,), tangents)[1]
    return jvp(g, primals, tangents)[1]

Another clue is dimensionality: each application of `jvp` produces an output with the same shape as the function's output, so two nested `jvp`s of a scalar-valued function can only produce a scalar — never an array shaped like the input. In general I don't think we can get an HVP using just two applications of `jvp`; the outer forward pass has to build a full Jacobian with `jacfwd`, as below:

from jax import jvp, grad, hessian, jacfwd
from jax import random
import jax.numpy as jnp
# Fixed PRNG seed so the example is reproducible.
key = random.PRNGKey(0)
def f(X):
    """Scalar test objective: sum over all entries of tanh(X)**2."""
    squared_tanh = jnp.tanh(X) ** 2
    return squared_tanh.sum()
# Derive independent subkeys so X and V get different random draws.
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))  # primal point
V = random.normal(subkey2, (30, 40))  # tangent (direction) vector
# Reference HVP: contract the full Hessian of f at X with V over its last two axes.
correct_answer = jnp.tensordot(hessian(f)(X), V, 2)
def hvp_fwdrev(f, primals, tangents):
    """Hessian-vector product via forward-over-reverse: jvp of grad(f)."""
    _, tangent_out = jvp(grad(f), primals, tangents)
    return tangent_out
def hvp_revfwd(f, primals, tangents):
    """Hessian-vector product via reverse-over-forward.

    Note: grad is taken w.r.t. the whole primal tuple, so the result is a
    tuple matching `primals` (a 1-tuple for a single-argument f).
    """
    def directional_derivative(primal_tuple):
        # Scalar <grad f, v> evaluated at the given primal tuple.
        return jvp(f, primal_tuple, tangents)[1]
    return grad(directional_derivative)(primals)
def hvp_revrev(f, primals, tangents):
    """Hessian-vector product via reverse-over-reverse."""
    (point,) = primals
    (direction,) = tangents

    def grad_dot_v(p):
        # Scalar <grad f(p), v>; differentiating it again gives H @ v.
        return jnp.vdot(grad(f)(p), direction)

    return grad(grad_dot_v)(point)
def hvp_fwdfwd(f, primals, tangents):
    """Hessian-vector product via forward-over-forward.

    The outer pass must be a full Jacobian (jacfwd) of the scalar
    directional derivative -- two plain jvp's would only yield a scalar.
    """
    def directional_derivative(*args):
        # Scalar v . grad f evaluated at `args`.
        return jvp(f, args, tangents)[1]
    return jacfwd(directional_derivative)(*primals)
# Validate each HVP variant against the dense-Hessian reference value.
for label, hvp in [
    ("Forward over reverse, correct", hvp_fwdrev),
    ("Reverse over forward, correct", hvp_revfwd),
    ("Reverse over reverse, correct", hvp_revrev),
    ("Forward over forward, correct", hvp_fwdfwd),
]:
    print(label, jnp.allclose(correct_answer, hvp(f, (X,), (V,)), 1e-4, 1e-4))
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
Uh oh!
There was an error while loading. Please reload this page.
-
Following the Autodiff cookbook, there are three variants for computing Hessian-vector products (HVPs): a) forward-over-reverse, b) reverse-over-forward and c) reverse-over-reverse. Clearly, a fourth option is missing: d) forward-over-forward. I tried to implement this in the same style; however, it does not work. Is this a conceptual problem or is there a mistake in the code?
Here is a minimal example (jax.version = 0.2.12):
# Imports for the minimal example (jax.__version__ == 0.2.12 per the post).
from jax import jvp, grad, hessian
from jax import random
import jax.numpy as jnp  # FIX: jnp is used below (f, correct_answer) but was missing from the snippet

# Fixed PRNG seed so the example is reproducible.
key = random.PRNGKey(0)
def f(X):
    """Sum of elementwise tanh(X)**2 -- a scalar function of the array X."""
    return (jnp.tanh(X) ** 2).sum()
# Split into independent subkeys so X and V use different randomness.
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))  # primal point
V = random.normal(subkey2, (30, 40))  # tangent (direction) vector
# Reference HVP: contract the full Hessian of f at X with V over two axes.
correct_answer = jnp.tensordot(hessian(f)(X), V, 2)
def hvp_fwdrev(f, primals, tangents):
    """Forward-over-reverse HVP: push tangents through grad(f) with jvp."""
    primal_out, tangent_out = jvp(grad(f), primals, tangents)
    return tangent_out
def hvp_revfwd(f, primals, tangents):
    """Reverse-over-forward HVP.

    grad is taken w.r.t. the primal tuple itself, so the return value is a
    tuple with the same structure as `primals`.
    """
    def tangent_of_f(p):
        return jvp(f, p, tangents)[1]
    return grad(tangent_of_f)(primals)
def hvp_revrev(f, primals, tangents):
    """Reverse-over-reverse HVP: grad of the scalar <grad f, v>."""
    (point,) = primals
    (direction,) = tangents
    inner = lambda p: jnp.vdot(grad(f)(p), direction)
    return grad(inner)(point)
def hvp_fwdfwd(f, primals, tangents):
    # NOTE(review): this is the broken attempt the discussion is about.
    # jvp returns an output shaped like f's output, so nesting jvp twice on
    # a scalar-valued f yields a scalar (ndim 0) rather than an HVP shaped
    # like X (ndim 2) -- hence "Forward over forward, correct False" below.
    g = lambda primals: jvp(f, (primals,), tangents)[1]
    return jvp(g, primals, tangents)[1]
# Compare every HVP implementation against the dense-Hessian reference.
for label, hvp in [
    ("Forward over reverse, correct", hvp_fwdrev),
    ("Reverse over forward, correct", hvp_revfwd),
    ("Reverse over reverse, correct", hvp_revrev),
    ("Forward over forward, correct", hvp_fwdfwd),
]:
    print(label, jnp.allclose(correct_answer, hvp(f, (X,), (V,)), 1e-4, 1e-4))
Output:
Forward over reverse, correct True
Reverse over forward, correct True
Reverse over reverse, correct True
Forward over forward, correct False
Beta Was this translation helpful? Give feedback.
All reactions