Skip to content

Commit b019f98

Browse files
authored
Fix conversions from non-canonical scipy.sparse arrays. (#861)
1 parent 295226d commit b019f98

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

sparse/numba_backend/_compressed/compressed.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,14 @@ def from_coo(cls, x, compressed_axes=None, idx_dtype=None):
209209

210210
@classmethod
def from_scipy_sparse(cls, x, /, *, fill_value=None):
    """
    Build an instance of ``cls`` from a SciPy sparse array or matrix.

    CSC input keeps its layout (``compressed_axes=(1,)``); every other
    format is converted to CSR (``compressed_axes=(0,)``).  Input that is
    not in canonical format has explicit zeros removed and duplicate
    entries summed before its buffers are handed to ``cls``.

    Parameters
    ----------
    x : scipy.sparse array or matrix
        The object to convert.  It is never modified by this method.
    fill_value : scalar, optional
        Forwarded unchanged to the ``cls`` constructor.

    Returns
    -------
    cls
        Built from ``(x.data, x.indices, x.indptr)`` of the CSR/CSC form.
    """
    source = x
    if x.format == "csc":
        compressed_axes = (1,)
    else:
        compressed_axes = (0,)
        x = x.asformat("csr")

    if not x.has_canonical_format:
        # ``eliminate_zeros`` and ``sum_duplicates`` work in place.  When no
        # format conversion produced a new object, ``x`` is still the
        # caller's own matrix, so canonicalise a private copy instead of
        # silently rewriting the caller's data.
        if x is source:
            x = x.copy()
        x.eliminate_zeros()
        x.sum_duplicates()

    return cls(
        (x.data, x.indices, x.indptr),
        shape=x.shape,
        compressed_axes=compressed_axes,
        fill_value=fill_value,
    )
217220

218221
@classmethod
219222
def from_iter(cls, x, shape=None, compressed_axes=None, fill_value=None, idx_dtype=None):
@@ -903,6 +906,10 @@ def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune
903906
@classmethod
def from_scipy_sparse(cls, x, /, *, fill_value=None):
    """
    Build an instance of ``cls`` from a SciPy sparse array or matrix via CSR.

    The input is converted to CSR without forcing a copy.  Input that is
    not in canonical format has explicit zeros removed and duplicate
    entries summed before its buffers are handed to ``cls``.

    Parameters
    ----------
    x : scipy.sparse array or matrix
        The object to convert.  It is never modified by this method.
    fill_value : scalar, optional
        Forwarded unchanged to the ``cls`` constructor.

    Returns
    -------
    cls
        Built from ``(x.data, x.indices, x.indptr)`` of the CSR form.
    """
    source = x
    x = x.asformat("csr", copy=False)

    if not x.has_canonical_format:
        # ``eliminate_zeros`` and ``sum_duplicates`` work in place.  With
        # ``copy=False`` a CSR input comes back as the very same object, so
        # canonicalise a private copy rather than the caller's matrix.
        if x is source:
            x = x.copy()
        x.eliminate_zeros()
        x.sum_duplicates()

    return cls((x.data, x.indices, x.indptr), shape=x.shape, fill_value=fill_value)
907914

908915
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
@@ -935,6 +942,10 @@ def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune
935942
@classmethod
def from_scipy_sparse(cls, x, /, *, fill_value=None):
    """
    Build an instance of ``cls`` from a SciPy sparse array or matrix via CSC.

    The input is converted to CSC without forcing a copy.  Input that is
    not in canonical format has explicit zeros removed and duplicate
    entries summed before its buffers are handed to ``cls``.

    Parameters
    ----------
    x : scipy.sparse array or matrix
        The object to convert.  It is never modified by this method.
    fill_value : scalar, optional
        Forwarded unchanged to the ``cls`` constructor.

    Returns
    -------
    cls
        Built from ``(x.data, x.indices, x.indptr)`` of the CSC form.
    """
    source = x
    x = x.asformat("csc", copy=False)

    if not x.has_canonical_format:
        # ``eliminate_zeros`` and ``sum_duplicates`` work in place.  With
        # ``copy=False`` a CSC input comes back as the very same object, so
        # canonicalise a private copy rather than the caller's matrix.
        if x is source:
            x = x.copy()
        x.eliminate_zeros()
        x.sum_duplicates()

    return cls((x.data, x.indices, x.indptr), shape=x.shape, fill_value=fill_value)
939950

940951
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:

sparse/numba_backend/_coo/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,10 @@ def from_scipy_sparse(cls, x, /, *, fill_value=None):
447447
True
448448
"""
449449
x = x.asformat("coo")
450+
if not x.has_canonical_format:
451+
x.eliminate_zeros()
452+
x.sum_duplicates()
453+
450454
coords = np.empty((2, x.nnz), dtype=x.row.dtype)
451455
coords[0, :] = x.row
452456
coords[1, :] = x.col

sparse/numba_backend/tests/test_conversion.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import pytest
55

6+
import numpy as np
7+
import scipy.sparse as sps
8+
69
FORMATS_ND = [
710
sparse.COO,
811
sparse.DOK,
@@ -38,3 +41,20 @@ def test_conversion_scalar(format1, format2):
3841
x = sparse.random((), format=format1, fill_value=0.5)
3942
y = x.asformat(format2)
4043
assert_eq(x, y)
44+
45+
46+
def test_non_canonical_conversion():
    """
    Regression test for gh-602.

    Adapted from https://github.com/LiberTEM/sparseconverter/blob/4cfc0ee2ad4c37b07742db8f3643bcbd858a4e85/src/sparseconverter/__init__.py#L154-L183
    """
    # Deliberately non-canonical CSR buffers: row 0 lists column 1 before
    # column 0, and row 1 stores column 1 twice (3.0 + 1.0 == 4.0).
    data = np.array((2.0, 1.0, 3.0, 3.0, 1.0))
    indices = np.array((1, 0, 0, 1, 1), dtype=int)
    indptr = np.array((0, 2, 5), dtype=int)

    x = sps.csr_matrix((data, indices, indptr), shape=(2, 2))
    ref = np.array(((1.0, 2.0), (3.0, 4.0)))

    gcxs_check = sparse.GCXS(x)

    # Separate array-aware assertions report which comparison failed and
    # the mismatching values, unlike a single ``assert a and b``.
    np.testing.assert_array_equal(gcxs_check.todense(), ref)
    np.testing.assert_array_equal(gcxs_check[:1].todense(), ref[:1])
    np.testing.assert_array_equal(gcxs_check[1:].todense(), ref[1:])

0 commit comments

Comments
 (0)