Skip to content

Commit 12a2562

Browse files
authored
Allow for custom fill-values on conversion to/from scipy.sparse. (#685)
1 parent 93dde75 commit 12a2562

File tree

6 files changed

+79
-39
lines changed

6 files changed

+79
-39
lines changed

sparse/numba_backend/_compressed/compressed.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from .._coo.core import COO
1212
from .._sparse_array import SparseArray
1313
from .._utils import (
14+
_zero_of_dtype,
1415
can_store,
1516
check_compressed_axes,
16-
check_zero_fill_value,
17+
check_fill_value,
1718
equivalent,
1819
normalize_axis,
1920
)
@@ -137,7 +138,7 @@ def __init__(
137138
shape=None,
138139
compressed_axes=None,
139140
prune=False,
140-
fill_value=0,
141+
fill_value=None,
141142
idx_dtype=None,
142143
):
143144
from .._common import _is_scipy_sparse_obj
@@ -176,8 +177,11 @@ def __init__(
176177

177178
self.shape = shape
178179

180+
if fill_value is None:
181+
fill_value = _zero_of_dtype(self.data.dtype)
182+
179183
self._compressed_axes = tuple(compressed_axes) if isinstance(compressed_axes, Iterable) else None
180-
self.fill_value = fill_value
184+
self.fill_value = self.data.dtype.type(fill_value)
181185

182186
if prune:
183187
self._prune()
@@ -194,7 +198,7 @@ def copy(self, deep=True):
194198
return _copy.deepcopy(self) if deep else _copy.copy(self)
195199

196200
@classmethod
197-
def from_numpy(cls, x, compressed_axes=None, fill_value=0, idx_dtype=None):
201+
def from_numpy(cls, x, compressed_axes=None, fill_value=None, idx_dtype=None):
198202
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
199203
return cls.from_coo(coo, compressed_axes, idx_dtype)
200204

@@ -204,12 +208,12 @@ def from_coo(cls, x, compressed_axes=None, idx_dtype=None):
204208
return cls(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)
205209

206210
@classmethod
207-
def from_scipy_sparse(cls, x):
211+
def from_scipy_sparse(cls, x, /, *, fill_value=None):
208212
if x.format == "csc":
209-
return cls((x.data, x.indices, x.indptr), shape=x.shape, compressed_axes=(1,))
213+
return cls((x.data, x.indices, x.indptr), shape=x.shape, compressed_axes=(1,), fill_value=fill_value)
210214

211215
x = x.asformat("csr")
212-
return cls((x.data, x.indices, x.indptr), shape=x.shape, compressed_axes=(0,))
216+
return cls((x.data, x.indices, x.indptr), shape=x.shape, compressed_axes=(0,), fill_value=fill_value)
213217

214218
@classmethod
215219
def from_iter(cls, x, shape=None, compressed_axes=None, fill_value=None, idx_dtype=None):
@@ -471,13 +475,20 @@ def todok(self):
471475

472476
return DOK.from_coo(self.tocoo()) # probably a temporary solution
473477

474-
def to_scipy_sparse(self):
478+
def to_scipy_sparse(self, accept_fv=None):
475479
"""
476480
Converts this :obj:`GCXS` object into a :obj:`scipy.sparse.csr_matrix` or `scipy.sparse.csc_matrix`.
481+
482+
Parameters
483+
----------
484+
accept_fv : scalar or list of scalar, optional
485+
The list of accepted fill-values. The default accepts only zero.
486+
477487
Returns
478488
-------
479489
:obj:`scipy.sparse.csr_matrix` or `scipy.sparse.csc_matrix`
480490
The converted Scipy sparse matrix.
491+
481492
Raises
482493
------
483494
ValueError
@@ -487,8 +498,7 @@ def to_scipy_sparse(self):
487498
"""
488499
import scipy.sparse
489500

490-
check_zero_fill_value(self)
491-
501+
check_fill_value(self, accept_fv=accept_fv)
492502
if self.ndim != 2:
493503
raise ValueError("Can only convert a 2-dimensional array to a Scipy sparse matrix.")
494504

@@ -873,9 +883,9 @@ def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune
873883
super().__init__(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)
874884

875885
@classmethod
876-
def from_scipy_sparse(cls, x):
886+
def from_scipy_sparse(cls, x, /, *, fill_value=None):
877887
x = x.asformat("csr", copy=False)
878-
return cls((x.data, x.indices, x.indptr), shape=x.shape)
888+
return cls((x.data, x.indices, x.indptr), shape=x.shape, fill_value=fill_value)
879889

880890
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
881891
axes = normalize_axis(axes, self.ndim)
@@ -905,9 +915,9 @@ def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune
905915
super().__init__(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)
906916

907917
@classmethod
908-
def from_scipy_sparse(cls, x):
918+
def from_scipy_sparse(cls, x, /, *, fill_value=None):
909919
x = x.asformat("csc", copy=False)
910-
return cls((x.data, x.indices, x.indptr), shape=x.shape)
920+
return cls((x.data, x.indices, x.indptr), shape=x.shape, fill_value=fill_value)
911921

912922
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
913923
axes = normalize_axis(axes, self.ndim)

sparse/numba_backend/_coo/core.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .._utils import (
1616
_zero_of_dtype,
1717
can_store,
18+
check_fill_value,
1819
check_zero_fill_value,
1920
equivalent,
2021
normalize_axis,
@@ -425,14 +426,16 @@ def todense(self):
425426
return x
426427

427428
@classmethod
428-
def from_scipy_sparse(cls, x):
429+
def from_scipy_sparse(cls, x, /, *, fill_value=None):
429430
"""
430431
Construct a :obj:`COO` array from a :obj:`scipy.sparse.spmatrix`
431432
432433
Parameters
433434
----------
434435
x : scipy.sparse.spmatrix
435436
The sparse matrix to construct the array from.
437+
fill_value : scalar, optional
438+
The fill-value to use when converting.
436439
437440
Returns
438441
-------
@@ -456,6 +459,7 @@ def from_scipy_sparse(cls, x):
456459
shape=x.shape,
457460
has_duplicates=not x.has_canonical_format,
458461
sorted=x.has_canonical_format,
462+
fill_value=fill_value,
459463
)
460464

461465
@classmethod
@@ -1155,10 +1159,15 @@ def resize(self, *args, refcheck=True, coords_dtype=np.intp):
11551159
if len(self.data) != len(linear_loc):
11561160
self.data = self.data[:end_idx].copy()
11571161

1158-
def to_scipy_sparse(self):
1162+
def to_scipy_sparse(self, /, *, accept_fv=None):
11591163
"""
11601164
Converts this :obj:`COO` object into a :obj:`scipy.sparse.coo_matrix`.
11611165
1166+
Parameters
1167+
----------
1168+
accept_fv : scalar or list of scalar, optional
1169+
The list of accepted fill-values. The default accepts only zero.
1170+
11621171
Returns
11631172
-------
11641173
:obj:`scipy.sparse.coo_matrix`
@@ -1178,7 +1187,7 @@ def to_scipy_sparse(self):
11781187
"""
11791188
import scipy.sparse
11801189

1181-
check_zero_fill_value(self)
1190+
check_fill_value(self, accept_fv=accept_fv)
11821191

11831192
if self.ndim != 2:
11841193
raise ValueError("Can only convert a 2-dimensional array to a Scipy sparse matrix.")

sparse/numba_backend/_dok.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,16 @@ def __init__(self, shape, data=None, dtype=None, fill_value=None):
131131
raise ValueError("data must be a dict.")
132132

133133
@classmethod
134-
def from_scipy_sparse(cls, x):
134+
def from_scipy_sparse(cls, x, /, *, fill_value=None):
135135
"""
136136
Create a :obj:`DOK` array from a :obj:`scipy.sparse.spmatrix`.
137137
138138
Parameters
139139
----------
140140
x : scipy.sparse.spmatrix
141141
The matrix to convert.
142+
fill_value : scalar, optional
143+
The fill-value to use when converting.
142144
143145
Returns
144146
-------
@@ -154,7 +156,7 @@ def from_scipy_sparse(cls, x):
154156
"""
155157
from sparse import COO
156158

157-
return COO.from_scipy_sparse(x).asformat(cls)
159+
return COO.from_scipy_sparse(x, fill_value=fill_value).asformat(cls)
158160

159161
@classmethod
160162
def from_coo(cls, x):

sparse/numba_backend/_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def assert_eq(x, y, check_nnz=True, compare_dtype=True, **kwargs):
3333
if isinstance(x, COO) and isinstance(y, COO) and check_nnz:
3434
assert np.array_equal(x.coords, y.coords)
3535
assert check_equal(x.data, y.data, **kwargs)
36-
assert x.fill_value == y.fill_value
36+
assert x.fill_value == y.fill_value or (np.isnan(x.fill_value) and np.isnan(y.fill_value))
3737
return
3838

3939
if hasattr(x, "todense"):
@@ -528,6 +528,31 @@ def check_compressed_axes(ndim, compressed_axes):
528528
raise ValueError("axis out of range")
529529

530530

531+
def check_fill_value(x, /, *, accept_fv=None) -> None:
    """Raise if the fill-value of ``x`` is not one of the accepted values.

    Parameters
    ----------
    x : SparseArray
        The array to check. Only its ``fill_value`` attribute is read.
    accept_fv : scalar or iterable of scalar, optional
        The accepted fill-value(s). The default accepts only zero.

    Raises
    ------
    ValueError
        If the fill-value of ``x`` matches none of the accepted values.
    """
    if accept_fv is None:
        accept_fv = [0]

    # ``str``/``bytes`` and 0-d arrays satisfy ``isinstance(..., Iterable)``
    # but represent a single fill-value: iterating a string would compare
    # character by character, and iterating a 0-d array raises ``TypeError``.
    is_single_value = (
        not isinstance(accept_fv, Iterable)
        or isinstance(accept_fv, (str, bytes))
        or getattr(accept_fv, "ndim", None) == 0
    )
    if is_single_value:
        accept_fv = [accept_fv]
    elif not isinstance(accept_fv, (list, tuple)):
        # Materialise one-shot iterables so the values are still available
        # for the error message after ``any`` has consumed them.
        accept_fv = list(accept_fv)

    if not any(equivalent(fv, x.fill_value) for fv in accept_fv):
        raise ValueError(f"{x.fill_value=} but should be in {accept_fv}.")
554+
555+
531556
def check_zero_fill_value(*args):
532557
"""
533558
Checks if all the arguments have zero fill-values.

sparse/numba_backend/tests/test_compressed.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import sparse
22
from sparse.numba_backend._compressed import GCXS
3-
from sparse.numba_backend._utils import assert_eq
3+
from sparse.numba_backend._utils import assert_eq, equivalent
44

55
import pytest
66

77
import numpy as np
8-
import scipy
98

109

1110
@pytest.fixture(scope="module", params=["f8", "f4", "i8", "i4"])
@@ -177,18 +176,21 @@ def test_tranpose(a, b):
177176
assert_eq(x.transpose(b), s.transpose(b))
178177

179178

180-
def test_to_scipy_sparse():
181-
s = sparse.random((3, 5), density=0.5, format="gcxs", compressed_axes=(0,))
182-
a = s.to_scipy_sparse()
183-
b = scipy.sparse.csr_matrix(s.todense())
179+
@pytest.mark.parametrize("fill_value_in", [0, np.inf, np.nan, 5, None])
180+
@pytest.mark.parametrize("fill_value_out", [0, np.inf, np.nan, 5, None])
181+
@pytest.mark.parametrize("format", [sparse.COO, sparse._compressed.CSR])
182+
def test_to_scipy_sparse(fill_value_in, fill_value_out, format):
183+
s = sparse.random((3, 5), density=0.5, format=format, fill_value=fill_value_in)
184184

185-
assert_eq(a, b)
185+
if not ((fill_value_in in {0, None} and fill_value_out in {0, None}) or equivalent(fill_value_in, fill_value_out)):
186+
with pytest.raises(ValueError, match=r"fill_value=.* but should be in .*\."):
187+
s.to_scipy_sparse(accept_fv=fill_value_out)
188+
return
186189

187-
s = sparse.random((3, 5), density=0.5, format="gcxs", compressed_axes=(1,))
188-
a = s.to_scipy_sparse()
189-
b = scipy.sparse.csc_matrix(s.todense())
190+
sps_matrix = s.to_scipy_sparse(accept_fv=fill_value_in)
191+
s2 = format.from_scipy_sparse(sps_matrix, fill_value=fill_value_out)
190192

191-
assert_eq(a, b)
193+
assert_eq(s, s2)
192194

193195

194196
def test_tocoo():

sparse/numba_backend/tests/test_coo.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,14 +380,6 @@ def test_reshape_errors(format):
380380
s.reshape((3, 5, 1), order="F")
381381

382382

383-
def test_to_scipy_sparse():
384-
s = sparse.random((3, 5), density=0.5)
385-
a = s.to_scipy_sparse()
386-
b = scipy.sparse.coo_matrix(s.todense())
387-
388-
assert_eq(a, b)
389-
390-
391383
@pytest.mark.parametrize("a_ndim", [1, 2, 3])
392384
@pytest.mark.parametrize("b_ndim", [1, 2, 3])
393385
def test_kron(a_ndim, b_ndim):

0 commit comments

Comments
 (0)