Sparse tensordot with sparse outputs? #8407
-
Hi, I've just started playing around with the BCOO sparse arrays and I'm having trouble performing the following operation. Basically, I want to compute a linear combination of a list of sparse matrices, where the list of sparse matrices is specified as a 3D BCOO array, the coefficients are a dense 1D array, and the output should be another BCOO. To be exact, I want to do:

If everything is dense I would just do:

An alternative approach I've tried is using

Any ideas? Thanks!
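Because the contraction runs over the leading axis, one option that only uses the BCOO buffers is to scale each stored value by the coefficient of the matrix it belongs to, then drop the leading index. This is a minimal sketch, assuming `Msp` comes from `BCOO.fromdense` with no batch or dense dimensions (so `Msp.indices` has shape `(nse, 3)`); `lincomb_bcoo` is a hypothetical helper, not a JAX API, and the data is random.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse


def lincomb_bcoo(Msp, c):
    """Hypothetical helper: return sum_k c[k] * Msp[k] as a BCOO array.

    Assumes n_batch == 0 and n_dense == 0, i.e. Msp.indices has shape
    (nse, Msp.ndim) and Msp.data has shape (nse,).
    """
    assert Msp.n_batch == 0 and Msp.n_dense == 0
    k = Msp.indices[:, 0]          # which matrix each stored entry belongs to
    data = Msp.data * c[k]         # scale each nonzero by its coefficient
    indices = Msp.indices[:, 1:]   # drop the leading (contracted) index
    # Entries from different k may now share an (i, j) index; BCOO permits
    # duplicate indices, and todense() sums them.
    return sparse.BCOO((data, indices), shape=Msp.shape[1:])


rng = np.random.default_rng(0)
M = rng.random((10, 3, 4)).astype(np.float32)
M[M < 0.5] = 0                     # make M sparse-ish
c = rng.random(10).astype(np.float32)

Msp = sparse.BCOO.fromdense(jnp.array(M))
out = lincomb_bcoo(Msp, jnp.array(c))
print(type(out).__name__, out.shape)
```

Since the result may hold duplicate indices, `todense()` (which sums them) gives the expected dense result; recent JAX versions also provide `BCOO.sum_duplicates()` to merge them explicitly.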
-
I just realized I can actually do:

though this seems a little excessive as the
-
It was added relatively recently, but if you do the (generalized) dot product of two sparse matrices, the result will be sparse. Using the main branch, you could do something like this:

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

c = jnp.array(np.random.rand(10))
M = jnp.array(np.random.rand(10, 3, 4))
M = M.at[M < 0.5].set(0)  # zero out roughly half the entries

Msp = sparse.BCOO.fromdense(M)
csp = sparse.BCOO.fromdense(c)

def f(M, c):
  # Contract the leading axis of M with c: sum_k c[k] * M[k]
  return jnp.tensordot(M, c, axes=(0, 0))

# Dense reference
print(f(M, c))
# [[1.0761551  1.1919075  2.4007194  1.0579658 ]
#  [2.5757928  0.77669764 1.0757202  1.4997834 ]
#  [0.14920855 0.8391593  1.8824157  1.1219863 ]]

# Same computation on BCOO inputs; the result is a BCOO array
print(sparse.sparsify(f)(Msp, csp).todense())
# [[1.0761552  1.1919074  2.4007192  1.0579658 ]
#  [2.5757926  0.77669764 1.0757202  1.4997835 ]
#  [0.14920855 0.8391593  1.8824158  1.1219863 ]]
```

Currently it's limited to cases where one of the matrices has a single sparse dimension, and only sparse dimensions, not batch dimensions, can be contracted.
-
Thanks for the detailed replies. As I'm working on a feature in a package that uses JAX as a dependency (and so I want to stick to things already in a release), I think for now I'll proceed with:

as I'm already seeing a substantial speed improvement over dense arrays (on CPU) in a toy example. I'll make sure to update it once the element-wise broadcasted multiplication makes it into a release. Thanks for the help, Jake! JAX is awesome!
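Once element-wise broadcasted multiplication between a BCOO array and a dense array is available, the linear combination can be written as multiply-then-sum under `sparse.sparsify`, keeping `c` dense. This is a sketch assuming a JAX version whose sparsify rules cover element-wise `mul` (sparse times dense) and `reduce_sum`; `lincomb` is a hypothetical name and the data is random.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse


@sparse.sparsify
def lincomb(M, c):
    # Broadcast c over the trailing (3, 4) dims, multiply element-wise,
    # then reduce over the leading axis: sum_k c[k] * M[k].
    return (M * c[:, None, None]).sum(axis=0)


rng = np.random.default_rng(0)
M = rng.random((10, 3, 4)).astype(np.float32)
M[M < 0.5] = 0                     # make M sparse-ish
c = rng.random(10).astype(np.float32)

Msp = sparse.BCOO.fromdense(jnp.array(M))
out = lincomb(Msp, jnp.array(c))   # c stays a dense array
print(type(out).__name__, out.shape)
```

Multiplying a sparse array by a dense one only touches the stored entries, so the product keeps the sparsity pattern of `M`, and the sum over axis 0 then yields a BCOO of shape `(3, 4)`.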