Thanks for the question. First of all, while trying to answer it I realized there is a bug in how A @ B is implemented for sparse matrices. You should be able to write A @ X @ B directly and have it dispatch to the proper operation; I'm planning to fix the bug in #8455.

That said, the operation already works correctly inside sparse.sparsify, so you can do this:

import jax.numpy as jnp
from jax import grad
from jax.experimental import sparse

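A = jnp.arange(24.).reshape(2, 3, 4)  # batch of two 3x4 matrices
X = jnp.arange(20.).reshape(4, 5)     # single 4x5 matrix, broadcast over the batch
B = jnp.arange(60.).reshape(2, 5, 6)  # batch of two 5x6 matrices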

@sparse.sparsify
def f(A, X, B):
  return A @ X @ B

# n_batch=1 stores the leading axis as a batch dimension of the BCOO array
Asp = sparse.BCOO.fromdense(A, n_batch=1)
Bsp = sparse.BCOO.fromdense(B, n_batch=1)

print(f(A, X, B).s…
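To see what A @ X @ B should return for these batched shapes (and so what the sparsified call can be checked against), here is a minimal dense-only sketch. It reuses the shapes from the snippet above; the einsum and loop formulations are just illustrative equivalents of the broadcasting rules of jnp.matmul, not part of the sparse API.

```python
import jax.numpy as jnp

# Same shapes as above: A is a batch of two 3x4 matrices, X a single 4x5
# matrix, B a batch of two 5x6 matrices.
A = jnp.arange(24.).reshape(2, 3, 4)
X = jnp.arange(20.).reshape(4, 5)
B = jnp.arange(60.).reshape(2, 5, 6)

# matmul broadcasts X across the leading batch axis of A and pairs the
# batches of A and B, i.e. out[b] = A[b] @ X @ B[b].
expected = jnp.einsum('bij,jk,bkl->bil', A, X, B)

# The same computation written as an explicit loop over the batch axis.
looped = jnp.stack([A[b] @ X @ B[b] for b in range(A.shape[0])])

out = A @ X @ B
print(out.shape)  # (2, 3, 6)
print(jnp.allclose(out, expected), jnp.allclose(out, looped))  # True True
```

Any result produced from the BCOO inputs Asp and Bsp can be compared against expected (after calling .todense() if the result is sparse).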

Answer selected by DanPuzzuoli
Category: Q&A