Skip to content

Commit 63a8593

Browse files
committed
API: Update asarray function
1 parent f95eeb3 commit 63a8593

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

sparse/finch_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from finch import (
77
add,
8+
asarray,
89
astype,
910
bool,
1011
compiled,
@@ -38,6 +39,7 @@
3839

3940
__all__ = [
4041
"add",
42+
"asarray",
4143
"astype",
4244
"bool",
4345
"compiled",

sparse/pydata_backend/_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,16 +2033,16 @@ def asarray(obj, /, *, dtype=None, format="coo", device=None, copy=False):
20332033
<COO: shape=(8, 8), dtype=int64, nnz=8, fill_value=0>
20342034
"""
20352035

2036-
if format not in {"coo", "dok", "gcxs"}:
2036+
if format not in {"coo", "dok", "gcxs", "csc", "csr"}:
20372037
raise ValueError(f"{format} format not supported.")
20382038

2039-
from ._compressed import GCXS
2039+
from ._compressed import CSC, CSR, GCXS
20402040
from ._coo import COO
20412041
from ._dok import DOK
20422042

2043-
format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS}
2043+
format_dict = {"coo": COO, "dok": DOK, "gcxs": GCXS, "csc": CSC, "csr": CSR}
20442044

2045-
if isinstance(obj, COO | DOK | GCXS):
2045+
if isinstance(obj, COO | DOK | GCXS | CSC | CSR):
20462046
return obj.asformat(format)
20472047

20482048
if _is_scipy_sparse_obj(obj):

sparse/pydata_backend/_compressed/compressed.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,8 @@ def isnan(self):
817817

818818

819819
class _Compressed2d(GCXS):
820+
class_compressed_axes: tuple[int]
821+
820822
def __init__(self, arg, shape=None, compressed_axes=None, prune=False, fill_value=0):
821823
if not hasattr(arg, "shape") and shape is None:
822824
raise ValueError("missing `shape` argument")
@@ -847,6 +849,11 @@ def __str__(self):
847849
def ndim(self) -> int:
848850
return 2
849851

852+
@classmethod
853+
def from_numpy(cls, x, fill_value=0, idx_dtype=None):
854+
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
855+
return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype)
856+
850857

851858
class CSR(_Compressed2d):
852859
"""
@@ -857,8 +864,12 @@ class CSR(_Compressed2d):
857864
Sparse supports 2-D CSR.
858865
"""
859866

860-
def __init__(self, arg, shape=None, prune=False, fill_value=0):
861-
super().__init__(arg, shape=shape, compressed_axes=(0,), fill_value=fill_value)
867+
class_compressed_axes: tuple[int] = (0,)
868+
869+
def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune=False, fill_value=0):
870+
if compressed_axes != self.class_compressed_axes:
871+
raise ValueError(f"CSR only accepts rows as compressed axis but got: {compressed_axes}")
872+
super().__init__(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)
862873

863874
@classmethod
864875
def from_scipy_sparse(cls, x):
@@ -882,8 +893,12 @@ class CSC(_Compressed2d):
882893
Sparse supports 2-D CSC.
883894
"""
884895

885-
def __init__(self, arg, shape=None, prune=False, fill_value=0):
886-
super().__init__(arg, shape=shape, compressed_axes=(1,), fill_value=fill_value)
896+
class_compressed_axes: tuple[int] = (1,)
897+
898+
def __init__(self, arg, shape=None, compressed_axes=class_compressed_axes, prune=False, fill_value=0):
899+
if compressed_axes != self.class_compressed_axes:
900+
raise ValueError(f"CSC only accepts columns as compressed axis but got: {compressed_axes}")
901+
super().__init__(arg, shape=shape, compressed_axes=compressed_axes, fill_value=fill_value)
887902

888903
@classmethod
889904
def from_scipy_sparse(cls, x):

sparse/tests/test_backends.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import sparse
22

3+
import pytest
4+
35
import numpy as np
46
import scipy.sparse as sp
57
from numpy.testing import assert_equal
@@ -54,3 +56,12 @@ def my_fun(tns1, tns2):
5456
result = my_fun(finch_dense, finch_arr)
5557

5658
assert_equal(result.todense(), np.sum(2 * np_eye, axis=0))
59+
60+
61+
@pytest.mark.parametrize("format", ["csc", "csr", "coo"])
62+
def test_asarray(backend, format):
63+
arr = np.eye(5)
64+
65+
result = sparse.asarray(arr, format=format)
66+
67+
assert_equal(result.todense(), arr)

0 commit comments

Comments
 (0)