Skip to content

Commit ac5785d

Browse files
committed
BUG: asarray: fix default format for SparseArray input
1 parent 67be471 commit ac5785d

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

sparse/numba_backend/_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,7 +2073,7 @@ def format_to_string(format):
20732073

20742074

20752075
@_check_device
2076-
def asarray(obj, /, *, dtype=None, format="coo", copy=False, device=None):
2076+
def asarray(obj, /, *, dtype=None, format=None, copy=False, device=None):
20772077
"""
20782078
Convert the input to a sparse array.
20792079
@@ -2085,6 +2085,8 @@ def asarray(obj, /, *, dtype=None, format="coo", copy=False, device=None):
20852085
Output array data type.
20862086
format : str, optional
20872087
Output array sparse format.
2088+
Default: existing format if the input is a `SparseArray`,
2089+
else COO.
20882090
device : str, optional
20892091
Device on which to place the created array.
20902092
copy : bool, optional
@@ -2102,7 +2104,7 @@ def asarray(obj, /, *, dtype=None, format="coo", copy=False, device=None):
21022104
<COO: shape=(8, 8), dtype=int64, nnz=8, fill_value=0>
21032105
"""
21042106

2105-
if format not in {"coo", "dok", "gcxs", "csc", "csr"}:
2107+
if format not in {None, "coo", "dok", "gcxs", "csc", "csr"}:
21062108
raise ValueError(f"{format} format not supported.")
21072109

21082110
from ._compressed import CSC, CSR, GCXS
@@ -2111,8 +2113,10 @@ def asarray(obj, /, *, dtype=None, format="coo", copy=False, device=None):
21112113

21122114
format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS, "csc": CSC, "csr": CSR}
21132115

2114-
if isinstance(obj, COO | DOK | GCXS | CSC | CSR):
2115-
return obj.asformat(format)
2116+
if isinstance(obj, SparseArray):
2117+
return obj.asformat(format) if format is not None else obj
2118+
2119+
format = "coo" if format is None else format
21162120

21172121
if _is_scipy_sparse_obj(obj):
21182122
sparse_obj = format_dict[format].from_scipy_sparse(obj)

sparse/numba_backend/tests/test_array_function.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sparse
2+
from sparse.numba_backend import SparseArray
23
from sparse.numba_backend._settings import NEP18_ENABLED
34
from sparse.numba_backend._utils import assert_eq
45

@@ -132,3 +133,6 @@ def test_asarray(self, input, dtype, format):
132133
expected = input.todense() if hasattr(input, "todense") else np.asarray(input)
133134

134135
np.testing.assert_equal(actual, expected)
136+
137+
if isinstance(input, SparseArray):
138+
assert sparse.asarray(input).__class__ is input.__class__

0 commit comments

Comments
 (0)