Support for this was added relatively recently: if you take the (generalized) dot product of two sparse arrays, the result is also sparse. Using the main branch, you could do something like this:

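import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

# Random dense inputs (no seed is set, so the printed values vary between runs).
c = jnp.array(np.random.rand(10))
M = jnp.array(np.random.rand(10, 3, 4))
# Zero out entries below 0.5 so that M is actually sparse.
M = M.at[M < 0.5].set(0)

# Convert both operands to the batched-COO (BCOO) sparse format.
Msp = sparse.BCOO.fromdense(M)
csp = sparse.BCOO.fromdense(c)

def f(M, c):
  # Contract axis 0 of M with axis 0 of c: out[i, j] = sum_k M[k, i, j] * c[k],
  # which gives an array of shape (3, 4).
  return jnp.tensordot(M, c, axes=(0, 0))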
print(f(M, c))
# [[1.0761551  1.1919075  2.4007194  1.0579658 ]
#  [2.5757928  0.77669764 1.0757202  1.4997834 ]
#  [0.14920855 0.8391593  1.8824157  1.1219863 ]]
print(sparse.sparsify(f)(Msp, csp).todense())
# [[…
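To check that the sparse-sparse contraction really returns a sparse result with the same values as the dense computation, here is a minimal sketch. It makes a few assumptions that are not part of the snippet above: it uses a fixed NumPy `default_rng` seed for reproducibility, it zeroes entries with `jnp.where` instead of `.at[...].set`, and it assumes a JAX version recent enough that `sparse.sparsify` handles a `tensordot` between two `BCOO` operands.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

# Fixed seed (an assumption for reproducibility; the snippet above sets none).
rng = np.random.default_rng(0)
c = jnp.array(rng.random(10))
M = jnp.array(rng.random((10, 3, 4)))
M = jnp.where(M < 0.5, 0.0, M)  # zero out roughly half of the entries

Msp = sparse.BCOO.fromdense(M)
csp = sparse.BCOO.fromdense(c)

def f(M, c):
  # Contract the leading axis of M with the only axis of c -> shape (3, 4).
  return jnp.tensordot(M, c, axes=(0, 0))

dense_out = f(M, c)                          # ordinary dense jax.Array
sparse_out = sparse.sparsify(f)(Msp, csp)    # both operands sparse

print(isinstance(sparse_out, sparse.BCOO))   # True: the result stays sparse
print(sparse_out.shape)                      # (3, 4)
# Densify the sparse result and compare it with the dense computation.
print(np.allclose(sparse_out.todense(), dense_out, rtol=1e-5, atol=1e-5))  # True
```

The comparison uses a small tolerance because the sparse and dense code paths may sum the ten products in a different order, which can change the last bits of a float32 result.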

Replies: 3 comments 3 replies

0 replies
3 replies
@DanPuzzuoli

@jakevdp

@jakevdp

Answer selected by DanPuzzuoli
0 replies
Category
Q&A
Labels
None yet
2 participants