I am new to JAX and was trying out the …
From what I have read on various other issues, it seems that it is necessary to set … Furthermore, I've seen that it is not possible to set …
Thanks for the question: the expected `nse` in this case is `A.nse * B.nse`, because that's the upper bound of the required `nse` for arbitrary sparse matrices. To see why, consider a situation like this:

```python
import jax.numpy as jnp

# A is 4 x 4 with nse = 3
A = jnp.array([[1, 0, 0, 0],
               [1, 0, 0, 0],
               [1, 0, 0, 0],
               [0, 0, 0, 0]])

# B is 4 x 5 with nse = 4
B = jnp.array([[1, 1, 1, 1, 0],
               [0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0]])

# C is 4 x 5 with nse = A.nse * B.nse = 12
C = A @ B
print(C)
# [[1 1 1 1 0]
#  [1 1 1 1 0]
#  [1 1 1 1 0]
#  [0 0 0 0 0]]
```

In the worst case, `C.nse` will be equal to `A.nse * B.nse`. …
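To make the bound concrete, here is a minimal pure-Python sketch (no JAX required) that stores each matrix as a list of `(row, col, value)` triples, the same COO idea that BCOO's indices/data pair is built on. The helpers `coo_from_dense`, `coo_matmul`, and `report` are hypothetical names for illustration only; they are not part of JAX and are not meant to mirror how `jax.experimental.sparse` computes the product internally. For each pair of inputs the sketch reports the bound `A.nse * B.nse`, how many partial products are actually produced, and how many nonzeros remain after summing duplicates.

```python
def coo_from_dense(rows):
    """Return the nonzero entries of a dense list-of-lists as (row, col, value) triples."""
    return [(i, j, v)
            for i, row in enumerate(rows)
            for j, v in enumerate(row)
            if v != 0]


def coo_matmul(a, b):
    """Multiply two COO matrices given as lists of (row, col, value) triples.

    Returns (pairs, c): `pairs` counts the (A-entry, B-entry) combinations whose
    inner indices match, i.e. the partial products produced before duplicates
    are summed; `c` maps each output coordinate to its summed value.
    """
    pairs = 0
    c = {}
    for i, k_a, va in a:          # every stored entry of A ...
        for k_b, j, vb in b:      # ... is checked against every stored entry of B
            if k_a == k_b:
                pairs += 1
                c[(i, j)] = c.get((i, j), 0) + va * vb
    return pairs, c


def report(a_dense, b_dense):
    """Compare the static bound with what the product actually needs."""
    a, b = coo_from_dense(a_dense), coo_from_dense(b_dense)
    pairs, c = coo_matmul(a, b)
    nnz = sum(1 for v in c.values() if v != 0)
    return {"bound": len(a) * len(b), "pairs": pairs, "nnz": nnz}


B = [[1, 1, 1, 1, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]]

# Worst case from the answer: every entry of A is in column 0, every entry of B in row 0.
A_worst = [[1, 0, 0, 0],
           [1, 0, 0, 0],
           [1, 0, 0, 0],
           [0, 0, 0, 0]]

# Same nse for A (3), but only one of its entries lines up with B's nonzero row.
A_diag = [[1, 0, 0, 0],
          [0, 1, 0, 0],
          [0, 0, 1, 0],
          [0, 0, 0, 0]]

# Two partial products land on (0, 0) and cancel to zero.
A_dup = [[1, 1],
         [0, 0]]
B_dup = [[1, 0],
         [-1, 0]]

print(report(A_worst, B))     # {'bound': 12, 'pairs': 12, 'nnz': 12}
print(report(A_diag, B))      # {'bound': 12, 'pairs': 4, 'nnz': 4}
print(report(A_dup, B_dup))   # {'bound': 4, 'pairs': 2, 'nnz': 0}
```

The bound depends only on `A.nse` and `B.nse`, not on where the nonzeros sit, which is presumably why it can be fixed before the values are known (for example under `jit`, where array sizes must be static). The last two cases show that the real result can need far fewer entries; in JAX, compacting a BCOO result along these lines is roughly what its `sum_duplicates()` method is for.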