|
2 | 2 | import warnings
|
3 | 3 | from collections.abc import Iterable
|
4 | 4 | from functools import reduce
|
5 |
| -from typing import Optional, Tuple |
| 5 | +from typing import NamedTuple, Optional, Tuple |
6 | 6 |
|
7 | 7 | import numba
|
8 | 8 |
|
@@ -1059,6 +1059,106 @@ def clip(a, a_min=None, a_max=None, out=None):
|
1059 | 1059 | return a.clip(a_min, a_max)
|
1060 | 1060 |
|
1061 | 1061 |
|
| 1062 | +# Array API set functions |
| 1063 | + |
| 1064 | + |
| 1065 | +class UniqueCountsResult(NamedTuple): |
| 1066 | + values: np.ndarray |
| 1067 | + counts: np.ndarray |
| 1068 | + |
| 1069 | + |
| 1070 | +def unique_counts(x, /): |
| 1071 | + """ |
| 1072 | + Returns the unique elements of an input array `x`, and the corresponding |
| 1073 | + counts for each unique element in `x`. |
| 1074 | +
|
| 1075 | + Parameters |
| 1076 | + ---------- |
| 1077 | + x : COO |
| 1078 | + Input COO array. It will be flattened if it is not already 1-D. |
| 1079 | +
|
| 1080 | + Returns |
| 1081 | + ------- |
| 1082 | + out : namedtuple |
| 1083 | + The result containing: |
| 1084 | + * values - The unique elements of an input array. |
| 1085 | + * counts - The corresponding counts for each unique element. |
| 1086 | +
|
| 1087 | + Raises |
| 1088 | + ------ |
| 1089 | + ValueError |
| 1090 | + If the input array is in a different format than COO. |
| 1091 | +
|
| 1092 | + Examples |
| 1093 | + -------- |
| 1094 | + >>> import sparse |
| 1095 | + >>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3]) |
| 1096 | + >>> sparse.unique_counts(x) |
| 1097 | + UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2])) |
| 1098 | + """ |
| 1099 | + from .core import COO |
| 1100 | + |
| 1101 | + if isinstance(x, scipy.sparse.spmatrix): |
| 1102 | + x = COO.from_scipy_sparse(x) |
| 1103 | + elif not isinstance(x, SparseArray): |
| 1104 | + raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.") |
| 1105 | + elif not isinstance(x, COO): |
| 1106 | + x = x.asformat(COO) |
| 1107 | + |
| 1108 | + x = x.flatten() |
| 1109 | + values, counts = np.unique(x.data, return_counts=True) |
| 1110 | + if x.nnz < x.size: |
| 1111 | + values = np.concatenate([[x.fill_value], values]) |
| 1112 | + counts = np.concatenate([[x.size - x.nnz], counts]) |
| 1113 | + sorted_indices = np.argsort(values) |
| 1114 | + values[sorted_indices] = values.copy() |
| 1115 | + counts[sorted_indices] = counts.copy() |
| 1116 | + |
| 1117 | + return UniqueCountsResult(values, counts) |
| 1118 | + |
| 1119 | + |
| 1120 | +def unique_values(x, /): |
| 1121 | + """ |
| 1122 | + Returns the unique elements of an input array `x`. |
| 1123 | +
|
| 1124 | + Parameters |
| 1125 | + ---------- |
| 1126 | + x : COO |
| 1127 | + Input COO array. It will be flattened if it is not already 1-D. |
| 1128 | +
|
| 1129 | + Returns |
| 1130 | + ------- |
| 1131 | + out : ndarray |
| 1132 | + The unique elements of an input array. |
| 1133 | +
|
| 1134 | + Raises |
| 1135 | + ------ |
| 1136 | + ValueError |
| 1137 | + If the input array is in a different format than COO. |
| 1138 | +
|
| 1139 | + Examples |
| 1140 | + -------- |
| 1141 | + >>> import sparse |
| 1142 | + >>> x = sparse.COO.from_numpy([1, 0, 2, 1, 2, -3]) |
| 1143 | + >>> sparse.unique_values(x) |
| 1144 | + array([-3, 0, 1, 2]) |
| 1145 | + """ |
| 1146 | + from .core import COO |
| 1147 | + |
| 1148 | + if isinstance(x, scipy.sparse.spmatrix): |
| 1149 | + x = COO.from_scipy_sparse(x) |
| 1150 | + elif not isinstance(x, SparseArray): |
| 1151 | + raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.") |
| 1152 | + elif not isinstance(x, COO): |
| 1153 | + x = x.asformat(COO) |
| 1154 | + |
| 1155 | + x = x.flatten() |
| 1156 | + values = np.unique(x.data) |
| 1157 | + if x.nnz < x.size: |
| 1158 | + values = np.sort(np.concatenate([[x.fill_value], values])) |
| 1159 | + return values |
| 1160 | + |
| 1161 | + |
1062 | 1162 | @numba.jit(nopython=True, nogil=True)
|
1063 | 1163 | def _compute_minmax_args(
|
1064 | 1164 | coords: np.ndarray,
|
@@ -1121,8 +1221,12 @@ def _arg_minmax_common(
|
1121 | 1221 |
|
1122 | 1222 | from .core import COO
|
1123 | 1223 |
|
1124 |
| - if not isinstance(x, COO): |
1125 |
| - raise ValueError(f"Only COO arrays are supported but {type(x)} was passed.") |
| 1224 | + if isinstance(x, scipy.sparse.spmatrix): |
| 1225 | + x = COO.from_scipy_sparse(x) |
| 1226 | + elif not isinstance(x, SparseArray): |
| 1227 | + raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.") |
| 1228 | + elif not isinstance(x, COO): |
| 1229 | + x = x.asformat(COO) |
1126 | 1230 |
|
1127 | 1231 | if not isinstance(axis, (int, type(None))):
|
1128 | 1232 | raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")
|
|
0 commit comments