I am working on relabeling a large dataset with a large number of classes. Although the matrix storing the class labels is sparse, a naive dense representation would waste memory or run out of it entirely. Hence, I am using a sparse matrix instead.
However, I am facing a problem when indexing or slicing the sparse matrix. When I take a sparse submatrix, that submatrix counts the zero elements as well:

```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.], [3., 0., 0., 0.], [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)
print(M_sp)       # BCOO(float32[3, 4], nse=4)

m = M_sp[0]
print(m)          # BCOO(float32[4], nse=4)
print(m.data)     # Array([1., 2., 0., 0.], dtype=float32)
print(m.indices)  # Array([[1], [3], [4], [4]], dtype=int32)
```

I expect indexing to return a sparse submatrix whose `data` holds only the nonzero elements, with `indices` giving their positions in the new submatrix. The current behavior seems strange to me. As a workaround, whenever I create a sparse submatrix, I have to prune all of the zero values in …
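For reference, the pruning workaround mentioned above can be sketched eagerly (outside `jax.jit`) with boolean masking. `prune_zeros` is a hypothetical helper, not part of JAX, and it assumes an unbatched BCOO (`n_batch=0`) such as `M_sp[0]`. Because the number of entries kept depends on the data, this only works in eager mode.

```python
import jax.numpy as jnp
from jax.experimental import sparse


def prune_zeros(mat: sparse.BCOO) -> sparse.BCOO:
    """Hypothetical helper: drop explicitly stored zeros from an unbatched BCOO.

    Boolean-mask indexing produces an output whose size depends on the data,
    so this works eagerly but not under jax.jit.
    """
    if mat.n_batch != 0:
        raise ValueError("this sketch only handles unbatched BCOO arrays")
    keep = mat.data != 0  # True only for genuinely nonzero entries
    return sparse.BCOO((mat.data[keep], mat.indices[keep]), shape=mat.shape)


M = jnp.array([[0., 1., 0., 2.], [3., 0., 0., 0.], [0., 0., 4., 0.]])
M_sp = sparse.BCOO.fromdense(M)

row = prune_zeros(M_sp[0])
print(row.nse)        # number of stored elements after pruning
print(row.todense())  # same values as M[0]
```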
This behavior comes from the static shape requirements of `jax.jit` and other transformations. For a general sparse array, it is impossible to know at compile time how many elements are nonzero in, say, the first row. So when you index the first row, the code returns a padded representation that will always be able to contain the row's contents.

If you want to remove these explicit zeros, you can do so via the `sum_duplicates` method:

```python
m = m.sum_duplicates(remove_zeros=True)  # drop explicit zeros and padding
print(m.data)
print(m.indices)
```

Note, however, that because the output of this is dynamically shaped (the shapes of the `data` and `indices` arrays depend on the array contents), you won't be able to do this opera…

Alternatively, you can use a structured sparse layout (in this case a BCOO with one batch dimension). Then this sort of indexing is more constrained, because we know a priori that there are at most two nonzero elements per row:

```python
M_sp = sparse.BCOO.fromdense(M, n_batch=1)  # one batch dimension: rows
print(M_sp)
m = M_sp[0]
print(m.data)
print(m.indices)
```

I hope that helps!
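One way to keep the zero-removal step compatible with `jax.jit`, sketched under the assumption that you can bound the number of nonzeros per row in advance, is to pass a static `nse` to `sum_duplicates` so the output shape is fixed. `row_without_padding` and `max_nnz` are hypothetical names. If the row has fewer nonzeros than `max_nnz`, the result is padded again; if it has more, the extra entries are dropped, so the bound must be large enough. The last lines check the `n_batch=1` layout from the answer, where `nse` becomes the maximum number of nonzeros in any row.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0., 2.], [3., 0., 0., 0.], [0., 0., 4., 0.]])


@partial(jax.jit, static_argnames=("i", "max_nnz"))
def row_without_padding(mat, i, max_nnz):
    """Hypothetical helper: row `i` of `mat`, re-packed into exactly `max_nnz` slots.

    A static `nse` keeps the output shape known at trace time, which is
    what jax.jit requires.
    """
    row = mat[i]  # still padded to the full nse of `mat`
    return row.sum_duplicates(nse=max_nnz, remove_zeros=True)


M_sp = sparse.BCOO.fromdense(M)
row0 = row_without_padding(M_sp, 0, 2)
print(row0.data, row0.indices)

# Structured layout: one batch dimension, so nse is the per-row capacity.
M_batched = sparse.BCOO.fromdense(M, n_batch=1)
print(M_batched.nse)
print(M_batched[0].data)
```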