This behavior comes from the static shape requirements of jax.jit and other transformations. For a general sparse array, it is impossible to know at compile time how many elements are nonzero in, say, the first row. So when you index the first row, the result uses a padded representation that is always large enough to contain the row's contents; the padding shows up as explicit zeros in the data array.
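Here is a minimal sketch of this behavior, assuming jax.experimental.sparse.BCOO and a small example matrix M of our own (M is hypothetical, not the matrix from the original question). The row m[0] has a number of stored elements (nse) that is fixed without inspecting the data, so its data and indices arrays may hold padding beyond the row's two true nonzeros, while its dense form is still correct.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical 2x4 example matrix with three nonzeros in total.
M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.]])
m = sparse.BCOO.fromdense(M)

# Index the first row. The number of stored elements is not derived from
# this row's contents, so it may exceed the row's two actual nonzeros.
row = m[0]
print(row.nse)
print(row.data)
print(row.indices)

# Padding entries do not change the values the array represents.
print(row.todense())
```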

If you want to remove these explicit zeros, you can do so via the sum_duplicates method:

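```python
# m is the padded sparse (BCOO) row obtained by indexing, as described above.
# sum_duplicates merges repeated indices; remove_zeros=True also drops explicit zeros.
m = m.sum_duplicates(remove_zeros=True)
print(m.data)
print(m.indices)
```

Output:

```
[1. 2.]
[[1]
 [3]]
```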

Note however that because the output of this is dynamically-shaped (the shape of the data and index arrays depend on the array contents) you won't be able to do this opera…
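If this has to run inside jax.jit, one option is to pass a static nse argument to sum_duplicates, which fixes the output size so it no longer depends on the array contents. This is a sketch that assumes you know a safe upper bound on the number of nonzeros ahead of time; the matrix M and the bound of 2 are hypothetical and specific to this example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical example matrix (same as before).
M = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 0.]])
m = sparse.BCOO.fromdense(M)

@jax.jit
def first_row_compact(mat):
    # Index the first row, then drop explicit zeros. nse=2 is an assumed,
    # user-chosen bound that keeps the output shape static under jit.
    return mat[0].sum_duplicates(nse=2, remove_zeros=True)

row = first_row_compact(m)
print(row.data)
print(row.indices)
```

The bound has to be chosen for your data: an nse smaller than the true number of nonzeros cannot represent all of them.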
