This operation fails because, in general, 1 - spmat is a fully dense matrix, and JAX avoids implicit conversion to dense. If converting your matrix to dense is actually what you want, you can do it this way:

(1 - probs_sp.todense()).prod(axis=1)
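For reference, here is a self-contained version of the dense route. The matrix values are made up for illustration, and probs_sp is assumed to be a BCOO matrix (the question does not say which sparse format was used). This materializes the full dense matrix, so it is only practical when that fits in memory.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative probabilities (hypothetical data); implicit zeros contribute factors of 1.
probs = jnp.array([[0.5, 0.0, 0.2],
                   [0.0, 0.0, 0.0]])
probs_sp = sparse.BCOO.fromdense(probs)

# Densify explicitly, then take the row-wise product of (1 - p).
result = (1 - probs_sp.todense()).prod(axis=1)
```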

But I understand that in your specific computation, the implicit zeros of the sparse matrix become ones, which do not contribute to the final product. JAX's sparse module (jax.experimental.sparse) does not support this kind of computation directly; you might be able to proceed by working in log space, where those ones become zeros and the multiplications become additions.
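A minimal sketch of the log-space idea, assuming probs_sp is a 2-D, unbatched BCOO matrix with values in [0, 1]. Because log(1 - 0) = 0, only the stored entries contribute, so each row product becomes exp(sum over stored j of log(1 - p_ij)). The helper name row_prod_one_minus is hypothetical, and the per-row sum is done here with jax.ops.segment_sum on the row indices; that is one possible approach, not a dedicated jax.experimental.sparse API.

```python
import jax.numpy as jnp
from jax import ops
from jax.experimental import sparse


def row_prod_one_minus(mat: sparse.BCOO) -> jnp.ndarray:
    """Hypothetical helper: prod_j (1 - mat[i, j]) for each row i, without densifying."""
    # Merge duplicate indices first: log(1 - (a + b)) != log(1 - a) + log(1 - b).
    mat = mat.sum_duplicates()
    # log1p(-p) = log(1 - p); an entry p == 1 gives -inf, so that row's product is 0.
    log_terms = jnp.log1p(-mat.data)
    # For an unbatched 2-D BCOO, indices has shape (nse, 2); column 0 is the row index.
    rows = mat.indices[:, 0]
    # Sum log-terms per row; rows with no stored entries get 0, i.e. a product of 1.
    log_prod = ops.segment_sum(log_terms, rows, num_segments=mat.shape[0])
    return jnp.exp(log_prod)


# Usage with illustrative data; the result should match (1 - dense).prod(axis=1).
dense = jnp.array([[0.5, 0.0, 0.2],
                   [0.0, 0.0, 0.0],
                   [0.1, 0.9, 0.0]])
probs_sp = sparse.BCOO.fromdense(dense)
result = row_prod_one_minus(probs_sp)
```

Note that exp of a sum of logs can lose a little precision compared with a direct product, which is usually acceptable for probabilities like these.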

Answer selected by edavisau