Way around BCOO batch dimension contraction in example? #8454
-
Hi, I'm working on writing some code using the BCOO type and am running into an issue I'm not sure how to get around. Mathematically, what I'm trying to do is write a function that computes the following:
For …

When trying to use …

This however doesn't work unless I set …

(summing all elements as opposed to only over the 0 axis to get a scalar output) I can't take the …

Is there a different way I can write this so that I can also grad through this computation?
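For context, here is a minimal sketch of building a `BCOO` array with a leading batch dimension, which is the kind of object involved here. This is not taken from the original question; the array shape and the `n_batch=1` choice are illustrative assumptions only:

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.arange(24.).reshape(2, 3, 4)           # leading axis of size 2
batched = sparse.BCOO.fromdense(dense, n_batch=1)  # treat that leading axis as a batch dimension
print(batched.shape, batched.n_batch)              # (2, 3, 4) 1
```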
-
Thanks for the question. First of all, in trying to answer your question I realized there is a bug in how `A @ B` is implemented for sparse matrices. You should be able to just write `A @ X @ B` and have it dispatch to the proper operation; I'm planning to fix the bug in #8455. That said, it works correctly right now within `sparsify`, so you can do this:

```python
import jax.numpy as jnp
from jax.experimental import sparse

A = jnp.arange(24.).reshape(2, 3, 4)
X = jnp.arange(20.).reshape(4, 5)
B = jnp.arange(60.).reshape(2, 5, 6)

@sparse.sparsify
def f(A, X, B):
    return A @ X @ B

Asp = sparse.BCOO.fromdense(A, n_batch=1)
Bsp = sparse.BCOO.fromdense(B, n_batch=1)

print(f(A, X, B).shape)
# (2, 3, 6)
print(f(Asp, X, Bsp).shape)
# (2, 3, 6)
```

As for the gradient: unfortunately, when you run something like

```python
from jax import grad

gsp = grad(lambda X: f(Asp, X, B).sum())
print(gsp(X).shape)
```

the reverse-mode autodiff requires doing a sparse matmul where the contraction dimension is over the array's batch dimension, which is not yet implemented – there's no deep reason why this shouldn't be possible, it's just not something we've written the code for yet (it's on the list).

Fortunately, forward-mode autodiff does not require such a sparse-sparse matrix product, so if you can re-express your gradient computation in terms of forward-mode autodiff, then you will be able to compute the result. For example:

```python
from jax import grad, jacfwd
g = grad(lambda X: f(A, X, B).sum())
print(g(X))
# [[ 9540. 11700. 13860. 16020. 18180.]
# [10170. 12546. 14922. 17298. 19674.]
# [10800. 13392. 15984. 18576. 21168.]
# [11430. 14238. 17046. 19854. 22662.]]
gsp = jacfwd(lambda X: f(Asp, X, B).sum())
print(gsp(X))
# [[ 9540. 11700. 13860. 16020. 18180.]
# [10170. 12546. 14922. 17298. 19674.]
# [10800. 13392. 15984. 18576. 21168.]
# [11430. 14238. 17046. 19854. 22662.]]
```

(Note that for a scalar-valued function like this one, `jacfwd` computes the same values as `grad`; it just does so with forward-mode rather than reverse-mode autodiff.) Hopefully we will be able to get those unimplemented matmul modes implemented very soon, so you won't need this workaround any more.
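If you only need directional derivatives rather than the full gradient array, another forward-mode option is `jax.jvp`. Here is a minimal sketch, assuming the same `f`, `Asp`, `X`, and `B` defined above; the direction `v` chosen below is just an illustrative value:

```python
import jax
import jax.numpy as jnp

# Scalar loss as in the examples above, using the sparse batched operand Asp.
loss = lambda X_: f(Asp, X_, B).sum()

# Forward-mode directional derivative <grad(loss), v> via jax.jvp, which also
# avoids the reverse-mode sparse matmul that is not yet implemented.
v = jnp.ones_like(X)
_, dloss_v = jax.jvp(loss, (X,), (v,))
print(dloss_v)
```

With `v = jnp.ones_like(X)` this returns the sum of the gradient's entries, so it is only a substitute for `grad` when a directional derivative (or a handful of them) is all you need.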