Skip to content

Commit 925112f

Browse files
authored
API: Add set functions [Array API] (#619)
1 parent b8f2717 commit 925112f

File tree

4 files changed

+147
-3
lines changed

4 files changed

+147
-3
lines changed

sparse/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
roll,
5050
tril,
5151
triu,
52+
unique_counts,
53+
unique_values,
5254
where,
5355
)
5456
from ._dok import DOK
@@ -114,6 +116,8 @@
114116
"min",
115117
"max",
116118
"nanreduce",
119+
"unique_counts",
120+
"unique_values",
117121
]
118122

119123
__array_api_version__ = "2022.12"

sparse/_coo/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
stack,
2121
tril,
2222
triu,
23+
unique_counts,
24+
unique_values,
2325
where,
2426
)
2527
from .core import COO, as_coo
@@ -49,4 +51,6 @@
4951
"result_type",
5052
"diagonal",
5153
"diagonalize",
54+
"unique_counts",
55+
"unique_values",
5256
]

sparse/_coo/common.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from collections.abc import Iterable
44
from functools import reduce
5-
from typing import Optional, Tuple
5+
from typing import NamedTuple, Optional, Tuple
66

77
import numba
88

@@ -1059,6 +1059,106 @@ def clip(a, a_min=None, a_max=None, out=None):
10591059
return a.clip(a_min, a_max)
10601060

10611061

1062+
# Array API set functions
1063+
1064+
1065+
class UniqueCountsResult(NamedTuple):
1066+
values: np.ndarray
1067+
counts: np.ndarray
1068+
1069+
1070+
def unique_counts(x, /):
1071+
"""
1072+
Returns the unique elements of an input array `x`, and the corresponding
1073+
counts for each unique element in `x`.
1074+
1075+
Parameters
1076+
----------
1077+
x : COO
1078+
Input COO array. It will be flattened if it is not already 1-D.
1079+
1080+
Returns
1081+
-------
1082+
out : namedtuple
1083+
The result containing:
1084+
* values - The unique elements of an input array.
1085+
* counts - The corresponding counts for each unique element.
1086+
1087+
Raises
1088+
------
1089+
ValueError
1090+
If the input array is in a different format than COO.
1091+
1092+
Examples
1093+
--------
1094+
>>> import sparse
1095+
>>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3])
1096+
>>> sparse.unique_counts(x)
1097+
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
1098+
"""
1099+
from .core import COO
1100+
1101+
if isinstance(x, scipy.sparse.spmatrix):
1102+
x = COO.from_scipy_sparse(x)
1103+
elif not isinstance(x, SparseArray):
1104+
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1105+
elif not isinstance(x, COO):
1106+
x = x.asformat(COO)
1107+
1108+
x = x.flatten()
1109+
values, counts = np.unique(x.data, return_counts=True)
1110+
if x.nnz < x.size:
1111+
values = np.concatenate([[x.fill_value], values])
1112+
counts = np.concatenate([[x.size - x.nnz], counts])
1113+
sorted_indices = np.argsort(values)
1114+
values[sorted_indices] = values.copy()
1115+
counts[sorted_indices] = counts.copy()
1116+
1117+
return UniqueCountsResult(values, counts)
1118+
1119+
1120+
def unique_values(x, /):
1121+
"""
1122+
Returns the unique elements of an input array `x`.
1123+
1124+
Parameters
1125+
----------
1126+
x : COO
1127+
Input COO array. It will be flattened if it is not already 1-D.
1128+
1129+
Returns
1130+
-------
1131+
out : ndarray
1132+
The unique elements of an input array.
1133+
1134+
Raises
1135+
------
1136+
ValueError
1137+
If the input array is in a different format than COO.
1138+
1139+
Examples
1140+
--------
1141+
>>> import sparse
1142+
>>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3])
1143+
>>> sparse.unique_values(x)
1144+
array([-3, 0, 1, 2])
1145+
"""
1146+
from .core import COO
1147+
1148+
if isinstance(x, scipy.sparse.spmatrix):
1149+
x = COO.from_scipy_sparse(x)
1150+
elif not isinstance(x, SparseArray):
1151+
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1152+
elif not isinstance(x, COO):
1153+
x = x.asformat(COO)
1154+
1155+
x = x.flatten()
1156+
values = np.unique(x.data)
1157+
if x.nnz < x.size:
1158+
values = np.sort(np.concatenate([[x.fill_value], values]))
1159+
return values
1160+
1161+
10621162
@numba.jit(nopython=True, nogil=True)
10631163
def _compute_minmax_args(
10641164
coords: np.ndarray,
@@ -1121,8 +1221,12 @@ def _arg_minmax_common(
11211221

11221222
from .core import COO
11231223

1124-
if not isinstance(x, COO):
1125-
raise ValueError(f"Only COO arrays are supported but {type(x)} was passed.")
1224+
if isinstance(x, scipy.sparse.spmatrix):
1225+
x = COO.from_scipy_sparse(x)
1226+
elif not isinstance(x, SparseArray):
1227+
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1228+
elif not isinstance(x, COO):
1229+
x = x.asformat(COO)
11261230

11271231
if not isinstance(axis, (int, type(None))):
11281232
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")

sparse/tests/test_coo.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,3 +1745,35 @@ def test_squeeze_validation(self):
17451745

17461746
with pytest.raises(ValueError, match="Specified axis `0` has a size greater than one: 3"):
17471747
s_arr.squeeze(0)
1748+
1749+
1750+
class TestUnique:
1751+
arr = np.array([[0, 0, 1, 5, 3, 0], [1, 0, 4, 0, 3, 0], [0, 1, 0, 1, 1, 0]], dtype=np.int64)
1752+
arr_empty = np.zeros((5, 5))
1753+
arr_full = np.arange(1, 10)
1754+
1755+
@pytest.mark.parametrize("arr", [arr, arr_empty, arr_full])
1756+
@pytest.mark.parametrize("fill_value", [-1, 0, 1])
1757+
def test_unique_counts(self, arr, fill_value):
1758+
s_arr = sparse.COO.from_numpy(arr, fill_value)
1759+
1760+
result_values, result_counts = sparse.unique_counts(s_arr)
1761+
expected_values, expected_counts = np.unique(arr, return_counts=True)
1762+
1763+
np.testing.assert_equal(result_values, expected_values)
1764+
np.testing.assert_equal(result_counts, expected_counts)
1765+
1766+
@pytest.mark.parametrize("arr", [arr, arr_empty, arr_full])
1767+
@pytest.mark.parametrize("fill_value", [-1, 0, 1])
1768+
def test_unique_values(self, arr, fill_value):
1769+
s_arr = sparse.COO.from_numpy(arr, fill_value)
1770+
1771+
result = sparse.unique_values(s_arr)
1772+
expected = np.unique(arr)
1773+
1774+
np.testing.assert_equal(result, expected)
1775+
1776+
@pytest.mark.parametrize("func", [sparse.unique_counts, sparse.unique_values])
1777+
def test_input_validation(self, func):
1778+
with pytest.raises(ValueError, match=r"Input must be an instance of SparseArray"):
1779+
func(self.arr)

0 commit comments

Comments
 (0)