Best practice for experimental.jet to evaluate the Laplacian of a scalar-valued MLP
#9598
-
Hi @YouJiacheng! I'll let the experts answer as well, but let me add some quick pointers; maybe these are already helpful :)
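A quick way to see what `jet` computes: for g(t) = f(x + t·v), passing the input series (v, 0) returns [g'(0), g''(0)], and g''(0) is the quadratic form v^T H(x) v. A minimal toy sketch of my own (not from the paper):

```python
import jax.numpy as jnp
from jax.experimental import jet

f = lambda x: jnp.sum(jnp.sin(x))  # toy scalar function; its Hessian is diag(-sin(x))
x = jnp.array([0.5, 1.0, 2.0])
v = jnp.array([1.0, 0.0, 0.0])  # probe direction

# the k-th output series term is the k-th derivative of t -> f(x + t*v) at t = 0
primal_out, (df, d2f) = jet.jet(f, (x,), ((v, jnp.zeros_like(x)),))
# df  == cos(0.5)   (directional derivative v^T grad f(x))
# d2f == -sin(0.5)  (quadratic form v^T H(x) v)
```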
-
This paper says it is 10x faster when the order of differentiation is 2. I don't know whether that means a Laplacian-like differentiation operator (the result is a scalar) or a Hessian-like operator (the result is a higher-order tensor). Since the (k+1)-th term of the `jet` output is the k-th order directional derivative, I use the following method to compute the Laplacian of an MLP (`exp` activation, as the paper suggests); it is about 1.4x slower than the forward-over-reverse version (`laplacian_2` below):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import jet

# jet.fact = lambda n: jax.lax.prod(range(1, n + 1))


def f(ws, wo, x):
    # MLP with exp non-linearities and a scalar output
    for w in ws:
        x = jax.lax.exp(x @ w)
    return jnp.reshape(x @ wo, ())


@jax.jit
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_1(ws, wo, x):
    # jet-based: the second-order jet coefficient along v is v^T H v,
    # so summing over the standard basis gives the trace of the Hessian
    fun = partial(f, ws, wo)

    @jax.vmap
    def hvv(v):
        return jet.jet(fun, (x,), ((v, jnp.zeros_like(x)),))[1][1]

    return jnp.sum(hvv(jnp.eye(x.shape[0], dtype=x.dtype)))


@jax.jit
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_2(ws, wo, x):
    # forward-over-reverse: JVPs of grad give Hessian columns
    fun = partial(f, ws, wo)
    in_tangents = jnp.eye(x.shape[0], dtype=x.dtype)
    pushfwd = partial(jax.jvp, jax.grad(fun), (x,))
    _, hessian = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
    return jnp.trace(hessian)


@jax.jit
@partial(jax.vmap, in_axes=(None, None, 0))
def laplacian_3(ws, wo, x):
    # baseline: materialize the full Hessian, then take its trace
    fun = partial(f, ws, wo)
    return jnp.trace(jax.hessian(fun)(x))


def timer(f):
    from time import time

    f()  # warm up / compile
    t = time()
    for _ in range(3):
        f()
    print((time() - t) / 3)


d = 256
ws = [jnp.zeros((d, d)) for _ in range(64)]
wo = jnp.zeros((d, 1))
x = jnp.zeros((512, d))

timer(lambda: jax.block_until_ready(laplacian_1(ws, wo, x)))
timer(lambda: jax.block_until_ready(laplacian_2(ws, wo, x)))
timer(lambda: jax.block_until_ready(laplacian_3(ws, wo, x)))
```
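The same trick extends beyond second order: feed the direction as the first-order series term, pad the higher orders with zeros, and read off the k-th output coefficient. A sketch under the same coefficient convention as `hvv` above (hypothetical helper, untested):

```python
import jax.numpy as jnp
from jax.experimental import jet

def kth_directional_derivative(fun, x, v, k):
    # k-th derivative of t -> fun(x + t*v) at t = 0 (hypothetical helper)
    series = ([v] + [jnp.zeros_like(x)] * (k - 1),)
    return jet.jet(fun, (x,), series)[1][k - 1]
```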
-
This paper shows that `jet` can accelerate high-order differentiation of a two-layer MLP with `exp` non-linearities, but I cannot find the example code of this paper. How can I use `jet` to accelerate Laplacian computation of an MLP, even if the non-linearities are not `exp`? @mattjj
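For concreteness, here is the pattern I would like to generalize: a minimal sketch of a `jet`-based Laplacian with `tanh` instead of `exp` (untested; it assumes the installed JAX version has a Taylor rule for `tanh`):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import jet

def mlp_tanh(w1, w2, x):
    # two-layer MLP with tanh non-linearity and a scalar output
    return jnp.reshape(jnp.tanh(x @ w1) @ w2, ())

def laplacian(fun, x):
    # one jet call per basis direction: summing e_i^T H e_i gives trace(H)
    hvv = jax.vmap(lambda v: jet.jet(fun, (x,), ((v, jnp.zeros_like(x)),))[1][1])
    return jnp.sum(hvv(jnp.eye(x.shape[0], dtype=x.dtype)))

d = 8
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w1 = jax.random.normal(k1, (d, d))
w2 = jax.random.normal(k2, (d, 1))
print(laplacian(partial(mlp_tanh, w1, w2), jnp.ones((d,))))
```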