I want to do a very large dense-dense matrix multiplication, but I only need a very small number of elements of the result matrix. So the overall operation is
JAX doesn't have any fully supported sparse computation, but there is `jax.experimental.sparse`, and `bcoo_dot_general_sampled` is the equivalent operation you're looking for. That said, you shouldn't expect that function to be particularly performant: it's really just a reference implementation for exploring the interaction between sparse computation and JAX transformations.

With that caveat, here's what it would look like to call that API:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

x = jnp.arange(6).reshape(2, 3)
y = jnp.arange(12).reshape(3, 4)
mask = jnp.array([[1, 0, 0, 1],
                  [0, 1, 0, 1]])

# Dense reference: compute the full product, then zero out unwanted entries.
result_direct = (x @ y) * mask
print(result_direct)
# [[20  0  0 29]
#  [ 0 68  0 92]]

# Sampled version: compute only the entries at the mask's nonzero indices.
indices = sparse.BCOO.fromdense(mask).indices
data = sparse.bcoo_dot_general_sampled(
    x, y, indices, dimension_numbers=(([1], [0]), ([], [])))
result_sparse = sparse.BCOO((data, indices), shape=mask.shape)
print(result_sparse.todense())
# [[20  0  0 29]
#  [ 0 68  0 92]]
```
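If the overhead of the experimental sparse routine is a concern, one alternative worth trying is to skip sparse formats entirely: gather only the rows of `x` and the columns of `y` that the mask touches, then compute one dot product per requested entry. This is a plain `jax.numpy` technique, not part of `jax.experimental.sparse`; `sampled_matmul` is a hypothetical helper name, and this sketch has not been benchmarked against `bcoo_dot_general_sampled`. It uses the same toy inputs as above:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sampled_matmul(x, y, rows, cols):
    # Hypothetical helper: for each requested (row, col) pair, compute only
    # that single dot product, out[n] = sum_k x[rows[n], k] * y[k, cols[n]].
    # Work is O(nnz * K) instead of O(M * N * K) for the full product.
    return jnp.einsum("nk,kn->n", x[rows], y[:, cols])


x = jnp.arange(6).reshape(2, 3)
y = jnp.arange(12).reshape(3, 4)
mask = jnp.array([[1, 0, 0, 1],
                  [0, 1, 0, 1]])

# Coordinates of the nonzero mask entries, in row-major order. This is done
# outside jit because the number of nonzeros depends on the data.
rows, cols = jnp.nonzero(mask)

values = sampled_matmul(x, y, rows, cols)
print(values)  # [20 29 68 92]

# Scatter the sampled values back into a dense matrix for comparison.
dense = jnp.zeros(mask.shape, dtype=values.dtype).at[rows, cols].set(values)
print(dense)
```

Note that the gathered operands `x[rows]` and `y[:, cols]` have shapes `(nnz, K)` and `(K, nnz)`, so memory grows with the number of sampled entries times the contraction size; for very large problems you may need to process the index list in chunks.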