Thanks for the question. The expected nse (number of specified elements) in this case is A.nse * B.nse, because that is the upper bound on the nse required for the product of arbitrary sparse matrices. To see why, consider a situation like this:

import jax.numpy as jnp

# A is 4x4 with nse=3
A = jnp.array([[1, 0, 0, 0],
               [1, 0, 0, 0],
               [1, 0, 0, 0],
               [0, 0, 0, 0]])

# B is 4x5 with nse=4
B = jnp.array([[1, 1, 1, 1, 0],
               [0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0]])

# C is 4x5 with nse = A.nse * B.nse = 3 * 4 = 12
C = A @ B
print(C)
# [[1 1 1 1 0]
#  [1 1 1 1 0]
#  [1 1 1 1 0]
#  [0 0 0 0 0]]
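One way to see where the bound comes from: every stored entry of the product is a sum of terms A[i, k] * B[k, j], one for each pair of stored entries whose inner index k matches, and there are at most A.nse * B.nse such pairs. The pure-Python sketch below lists those pairs for the matrices above. The helpers `coo` and `candidate_products` are illustrative only and are not part of `jax.experimental.sparse`. Because all of A's entries sit in column 0 and all of B's entries sit in row 0, every pair matches and each one lands on a different output position, which is exactly the worst case.

```python
from itertools import product

def coo(dense):
    """Illustrative helper: the stored (row, col, value) triples of a nested list."""
    return [(i, j, v) for i, row in enumerate(dense)
            for j, v in enumerate(row) if v != 0]

def candidate_products(a_coo, b_coo):
    """Illustrative helper: every term A[i, k] * B[k, j] built from stored entries."""
    return [(i, j, va * vb)
            for (i, k, va), (k2, j, vb) in product(a_coo, b_coo)
            if k == k2]  # only pairs whose inner indices match contribute

A = [[1, 0, 0, 0],
     [1, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 0, 0, 0]]
B = [[1, 1, 1, 1, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]]

a, b = coo(A), coo(B)
pairs = candidate_products(a, b)
print(len(a), len(b), len(pairs))          # 3 4 12
print(len({(i, j) for i, j, _ in pairs}))  # 12 distinct output positions
```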

In the worst case, C.nse will be equal to A.nse * B.nse.…
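When the bound is not reached, some pairs either fail to match on the inner index or land on the same output position. The minimal sketch below assumes a scheme that reserves one slot per (A entry, B entry) pair and merges duplicate coordinates afterwards, and uses a 2x2 example in which only one output position survives out of 4 reserved slots. The helpers `matmul_padded` and `merge_duplicates` are hypothetical and are not JAX's implementation. Reserving the worst-case size up front is one plausible reason for the behavior described here, since under JAX tracing array sizes generally have to be fixed before the values are known.

```python
from collections import defaultdict
from itertools import product

def coo(dense):
    """Hypothetical helper: the stored (row, col, value) triples of a nested list."""
    return [(i, j, v) for i, row in enumerate(dense)
            for j, v in enumerate(row) if v != 0]

def matmul_padded(a_coo, b_coo):
    """Hypothetical: one slot per (A entry, B entry) pair, so len == A.nse * B.nse."""
    buf = []
    for (i, k, va), (k2, j, vb) in product(a_coo, b_coo):
        # A slot holds a product term if the inner indices match, else None (unused).
        buf.append((i, j, va * vb) if k == k2 else None)
    return buf

def merge_duplicates(buf):
    """Hypothetical: drop unused slots and sum terms that share an output position."""
    out = defaultdict(int)
    for entry in buf:
        if entry is not None:
            i, j, v = entry
            out[(i, j)] += v
    return dict(out)

A = [[1, 1],
     [0, 0]]
B = [[1, 0],
     [1, 0]]

buf = matmul_padded(coo(A), coo(B))
print(len(buf))               # 4 reserved slots (2 * 2)
print(merge_duplicates(buf))  # {(0, 0): 2}, i.e. A @ B == [[2, 0], [0, 0]]
```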

Answer selected by CamRuiz
Category: Q&A
2 participants