how to compute jvp and vjp for sparse (BCOO) arguments #16012
stefanozampini asked this question in Q&A
-
Given this, if you want to use `jvp` with a sparse (BCOO) argument, you can apply it to a function of the raw `data` and `indices` buffers. For reference, here's the dense version:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return x.sum()

x = jnp.arange(6.0).reshape(2, 3)
dx = jnp.ones_like(x)

primals_out, tangents_out = jax.jvp(f, (x,), (dx,))
print(primals_out)
# 15.0
print(tangents_out)
# 6.0
```

and here's roughly the sparse equivalent:

```python
from jax.experimental import sparse

x_bcoo = sparse.BCOO.fromdense(x)  # stores only the five nonzero elements

def f_raw(data, indices, shape=x.shape):
    x_bcoo = sparse.BCOO((data, indices), shape=shape)
    return f(x_bcoo)

x_dot_bcoo = sparse.BCOO((jnp.ones_like(x_bcoo.data), x_bcoo.indices), shape=x_bcoo.shape)

primals = (x_bcoo.data, x_bcoo.indices)
# indices are integer-valued and non-differentiable, so their tangent
# must be a zero array with the special float0 dtype
tangents = (x_dot_bcoo.data, np.zeros_like(x_dot_bcoo.indices, dtype=jax.dtypes.float0))

primals_out, tangents_out = jax.jvp(f_raw, primals, tangents)
print(primals_out)
# 15.0
print(tangents_out)
# 5.0
```

(The tangent here is 5.0 rather than 6.0 because `fromdense` keeps only the five nonzero elements of `x`, so the tangent has one entry per stored element.) All that boilerplate is effectively what the sparse-aware wrappers like `sparse.grad` and `sparse.value_and_grad` handle for you.
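The same raw-buffer trick works for reverse mode with `jax.vjp`. Below is a minimal sketch under the same setup (`f`, `x_bcoo`, and `f_raw` are redefined so the snippet is self-contained); as in the forward-mode case, the integer `indices` receive a `float0` cotangent:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def f(x):
    return x.sum()

x = jnp.arange(6.0).reshape(2, 3)
x_bcoo = sparse.BCOO.fromdense(x)  # five stored elements

def f_raw(data, indices, shape=x.shape):
    return f(sparse.BCOO((data, indices), shape=shape))

# Reverse mode: vjp through the raw (data, indices) buffers
primals_out, f_vjp = jax.vjp(f_raw, x_bcoo.data, x_bcoo.indices)
data_bar, indices_bar = f_vjp(jnp.ones_like(primals_out))

print(data_bar)     # d(sum)/d(data): one entry per stored element
print(indices_bar)  # float0 zeros: indices are non-differentiable
```

If you want the cotangent as a sparse array, you can rewrap it with the primal's sparsity pattern: `sparse.BCOO((data_bar, x_bcoo.indices), shape=x_bcoo.shape)`.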
-
JAX provides experimental support for `grad` and `value_and_grad` for scalar functions with sparse arguments, and `jacfwd` and `jacrev` for (possibly) vector functions. Can you provide an example of how to compute `jvp` and `vjp` for vector functions with sparse arguments?
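For context, the existing sparse-aware `grad` support mentioned above can be used directly on a function of a BCOO argument. A minimal sketch (the `loss` function here is just an illustration):

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = sparse.BCOO.fromdense(jnp.arange(6.0).reshape(2, 3))

def loss(M):
    return M.sum()

# sparse.grad differentiates with respect to the sparse argument and
# returns the gradient as a BCOO array with the same sparsity pattern
grad_M = sparse.grad(loss)(M)
print(grad_M.todense())
```

Because the gradient shares `M`'s sparsity pattern, entries of `M` that are not stored (here, the zero at position `[0, 0]`) never appear in the gradient.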