JAX doesn't have fully supported sparse computation, but the experimental jax.experimental.sparse module provides bcoo_dot_general_sampled, which is the operation you're looking for.

That said, you shouldn't expect that function to be particularly performant: it's really just a reference implementation for exploring the interaction between sparse computation and JAX transformations.

With that caveat, here's what it would look like to call that API:

import jax
import jax.numpy as jnp
from jax.experimental import sparse

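# Dense operands: x has shape (2, 3) and y has shape (3, 4).
x = jnp.arange(6).reshape(2, 3)
y = jnp.arange(12).reshape(3, 4)

# Binary mask selecting which entries of the (2, 4) product to keep.
mask = jnp.array([[1, 0, 0, 1],
                  [0, 1, 0, 1]])
# Reference result: compute the full dense product, then zero out the unmasked entries.
result_direct = (x @ y) * mask
print(result_direct)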
#…
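
For illustration, here is a minimal, self-contained sketch of how the sampled call might be wired up. It is not the elided remainder of the snippet above. It assumes the signature bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers), which returns only the data values at the requested indices, and it assumes those values can be wrapped back into a BCOO array with the same indices. The names mask_sparse, result_data, and result_sampled are illustrative, and because the module is experimental, the details may differ in your installed JAX version.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Float operands (an assumption made here to stay clear of dtype edge cases).
x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
y = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

mask = jnp.array([[1, 0, 0, 1],
                  [0, 1, 0, 1]])

# Dense reference: full product, then masking.
result_direct = (x @ y) * mask

# Convert the mask to BCOO only to obtain the indices of its nonzero entries.
mask_sparse = sparse.BCOO.fromdense(mask)

# Same convention as lax.dot_general: contract axis 1 of x with axis 0 of y,
# with no batch dimensions.
dimension_numbers = (((1,), (0,)), ((), ()))

# Product values evaluated only at the masked positions.
result_data = sparse.bcoo_dot_general_sampled(
    x, y, mask_sparse.indices, dimension_numbers=dimension_numbers)

# Wrap the sampled values into a BCOO array so they can be compared densely.
result_sampled = sparse.BCOO((result_data, mask_sparse.indices), shape=mask.shape)

print(result_sampled.todense())
```

Densifying at the end is only for checking against result_direct; in real use you would typically keep the result in BCOO form.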

Answer selected by kkew3