Hey everyone,

This question is actually twofold. Generally, I want to implement a quantization scheme that utilizes Hessian information (e.g. https://arxiv.org/pdf/1911.03852.pdf or https://arxiv.org/pdf/2104.00903.pdf), and I am facing two problems right now:
1. For the rounding function used in quantization I use a straight-through estimator (see code below). However, it causes an error when computing the diagonal of the Hessian ("can't apply forward-mode autodiff (jvp) to a custom_vjp function."). Is there any way around that in JAX so that I can keep the straight-through behaviour? (See the sketch right after this list.)
2. I want Hessian diagonals not only for the weights/params but also for intermediate activations (e.g. x1 and x2 in the code below). In PyTorch this can be done roughly like this: https://github.com/cvlab-yonsei/EWGS/blob/56c654cb893d53563eb352dd591d1450c34bdd15/ImageNet/utils.py#L161 (the graph is retained, so grad can then be called on arbitrary tensors in it). Any idea how that would be realizable in JAX? (A sketch follows the code at the bottom.)
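For problem 1, one workaround I have been experimenting with (just a sketch, I am not sure it is the intended way) is to define the straight-through estimator with `jax.custom_jvp` instead of `jax.custom_vjp`: the tangent simply passes through, JAX can derive reverse mode from the JVP rule by transposition, and forward mode (and thus forward-over-reverse Hessian computations) no longer errors:

```python
import jax
import jax.numpy as jnp

# Straight-through estimator via custom_jvp: round in the primal,
# pass the tangent through unchanged. Reverse mode is derived
# automatically by transposing the (linear) JVP rule.
@jax.custom_jvp
def roundpass_ste(x):
    return jnp.round(x)

@roundpass_ste.defjvp
def roundpass_ste_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return roundpass_ste(x), x_dot
```

With this, the Hessian-diagonal code at the bottom runs, but I am unsure whether I lose anything that custom_vjp gives me.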
At the bottom of this post you can find some code illustrating the challenges I am facing. Any help or hint is truly appreciated. JAX has made my life a lot easier so far, and I feel these problems should be easily solvable, but I somehow haven't found a way to do it yet.
Many thanks and go JAX!
Clemens
```python
import jax
import jax.numpy as jnp
from typing import Any, Callable
from jax.flatten_util import ravel_pytree

Array = jnp.ndarray

# Rounding with straight-through estimator
@jax.custom_vjp
def roundpass(x):
    return jnp.round(x)

def roundpass_fwd(x):
    return roundpass(x), (None,)

def roundpass_bwd(res, g):
    return (g,)

roundpass.defvjp(roundpass_fwd, roundpass_bwd)

# simple NN with two layers
rng = jax.random.PRNGKey(0)
rng_p1, rng_p2, rng_p3, rng_p4 = jax.random.split(rng, 4)
inputs = jax.random.normal(rng_p1, (10, 8))
params = [
    jax.random.normal(rng_p2, (8, 9)),
    jax.random.normal(rng_p3, (9, 11)),
]
targets = jax.random.normal(rng_p4, (10, 11))

def loss_fn(params, x):
    x1 = jnp.dot(x, params[0])
    xi = roundpass(x1)  # replace this line with xi = x1 for functioning code
    x2 = jnp.dot(xi, params[1])
    return jnp.sum((x2 - targets) ** 2)

def loss_wrt_params(p):
    return loss_fn(p, inputs)

# compute diagonal of the Hessian, based on
# https://github.com/deepmind/optax/blob/master/optax/_src/second_order.py
def ravel(p: Any) -> Array:
    return ravel_pytree(p)[0]

_, unravel_fn = ravel_pytree(params)
vs = jnp.eye(ravel(params).size)

def comp(v):
    return jnp.vdot(v, ravel(jax.jvp(jax.grad(loss_wrt_params), [params], [unravel_fn(v)])[1]))

hess_diag_wrt_weights = jax.vmap(comp)(vs)
```
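For problem 2, the closest thing I have come up with (again only a sketch; names like `loss_wrt_x1_probe` are mine) is to expose the intermediate activation as an explicit argument: add a zero-valued probe to x1 and differentiate with respect to the probe, which at probe = 0 coincides with differentiating with respect to x1 itself:

```python
# Sketch: Hessian diagonal w.r.t. the intermediate activation x1,
# reusing inputs/params/targets and ravel_pytree from above.
def loss_wrt_x1_probe(probe):
    x1 = jnp.dot(inputs, params[0]) + probe  # probe == 0 leaves x1 unchanged
    x2 = jnp.dot(x1, params[1])  # roundpass left out here (or use a custom_jvp version)
    return jnp.sum((x2 - targets) ** 2)

probe0 = jnp.zeros((10, 9))  # same shape as x1
flat_probe, unravel_probe = ravel_pytree(probe0)

def comp_act(v):
    # v^T H v with a basis vector v picks out one diagonal entry of H
    _, hv = jax.jvp(jax.grad(loss_wrt_x1_probe), (probe0,), (unravel_probe(v),))
    return jnp.vdot(v, ravel_pytree(hv)[0])

hess_diag_wrt_x1 = jax.vmap(comp_act)(jnp.eye(flat_probe.size))
```

This avoids retaining a graph like in PyTorch, but I do not know whether it is the idiomatic JAX way of doing it.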