-
Is there currently a way to do the following operation with sparse arrays?
import jax.numpy as jnp
from jax.experimental import sparse
probs = jnp.array(
    [
        [0.5, 0.0, 0.0, 0.6],
        [0.8, 0.0, 0.9, 0.1],
    ]
)
probs_sp = sparse.BCOO.fromdense(probs)
(1 - probs).prod(axis=1)
# returns [0.2, 0.018 ]
(1 - probs_sp).prod(axis=1)
# NotImplementedError: Subtraction between sparse and dense array.
For obvious reasons, Example: Thanks!
Replies: 2 comments 1 reply
-
I believe I could do since but there must be a better way...
-
The reason this operation fails is that, in general,
1 - spmat
is a fully dense matrix, and JAX avoids implicit conversion to dense. If converting your matrix to dense is actually what you want, you can do it explicitly:
(1 - probs_sp.todense()).prod(axis=1)
But I understand that for your specific computation, the implicit zeros of the sparse matrix become ones, which do not contribute to the final product. JAX sparse does not have any support for this specific type of computation; you might be able to proceed by working in log-space, where the ones become zeros and the multiplications become additions.
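For reference, here is a rough sketch of what the log-space idea might look like. It is not a built-in JAX feature, just one way to combine existing pieces. It assumes every stored probability is in [0, 1], and it relies on log1p(-0) == 0, so applying the transform to the stored values only leaves the implicit zeros correct. The names log_q, row_log_sums, and result are made up for this example.

```python
import jax.numpy as jnp
from jax.experimental import sparse

probs = jnp.array(
    [
        [0.5, 0.0, 0.0, 0.6],
        [0.8, 0.0, 0.9, 0.1],
    ]
)
probs_sp = sparse.BCOO.fromdense(probs)

# Apply log(1 - p) to the stored values only. Unstored entries are zeros,
# and log(1 - 0) == 0, so the sparsity pattern is unchanged.
log_q = sparse.BCOO(
    (jnp.log1p(-probs_sp.data), probs_sp.indices),
    shape=probs_sp.shape,
)

# Row sums via a sparse @ dense product with a vector of ones; the result is dense.
row_log_sums = log_q @ jnp.ones(log_q.shape[1], dtype=log_q.dtype)

# Back from log-space: the product over each row of (1 - probs).
result = jnp.exp(row_log_sums)
print(result)
```

Only a length-2 dense vector is ever materialized here, never the full (1 - probs) matrix. A stored probability of exactly 1 turns into -inf in log-space and comes back as 0 after the exp, which is the right product for that row.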