Skip to content

Commit 82fb0d5

Browse files
authored
API: Add expand_dims and flip functions for COO format (#629)
1 parent f367b76 commit 82fb0d5

File tree

6 files changed

+165
-17
lines changed

6 files changed

+165
-17
lines changed

docs/generated/sparse.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ API
7070

7171
equal
7272

73+
expand_dims
74+
7375
eye
7476

77+
flip
78+
7579
full
7680

7781
full_like
@@ -158,6 +162,8 @@ API
158162

159163
unique_values
160164

165+
var
166+
161167
where
162168

163169
zeros

sparse/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
std,
117117
sum,
118118
tensordot,
119+
var,
119120
zeros,
120121
zeros_like,
121122
)
@@ -129,6 +130,8 @@
129130
clip,
130131
diagonal,
131132
diagonalize,
133+
expand_dims,
134+
flip,
132135
isneginf,
133136
isposinf,
134137
kron,
@@ -206,9 +209,11 @@
206209
"empty_like",
207210
"equal",
208211
"exp",
212+
"expand_dims",
209213
"expm1",
210214
"eye",
211215
"finfo",
216+
"flip",
212217
"float16",
213218
"float32",
214219
"float64",
@@ -297,6 +302,7 @@
297302
"uint8",
298303
"unique_counts",
299304
"unique_values",
305+
"var",
300306
"where",
301307
"zeros",
302308
"zeros_like",

sparse/_coo/__init__.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
concatenate,
77
diagonal,
88
diagonalize,
9+
expand_dims,
10+
flip,
911
isneginf,
1012
isposinf,
1113
kron,
@@ -29,28 +31,30 @@
2931
__all__ = [
3032
"COO",
3133
"as_coo",
32-
"concatenate",
34+
"argmax",
35+
"argmin",
36+
"argwhere",
3337
"clip",
34-
"stack",
35-
"triu",
36-
"tril",
37-
"where",
38-
"nansum",
38+
"concatenate",
39+
"diagonal",
40+
"diagonalize",
41+
"expand_dims",
42+
"flip",
43+
"isneginf",
44+
"isposinf",
45+
"kron",
46+
"nanmax",
3947
"nanmean",
40-
"nanprod",
4148
"nanmin",
42-
"nanmax",
49+
"nanprod",
4350
"nanreduce",
44-
"roll",
45-
"kron",
46-
"argwhere",
47-
"argmax",
48-
"argmin",
49-
"isposinf",
50-
"isneginf",
51+
"nansum",
5152
"result_type",
52-
"diagonal",
53-
"diagonalize",
53+
"roll",
54+
"stack",
55+
"tril",
56+
"triu",
5457
"unique_counts",
5558
"unique_values",
59+
"where",
5660
]

sparse/_coo/common.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,113 @@ def clip(a, a_min=None, a_max=None, out=None):
10591059
return a.clip(a_min, a_max)
10601060

10611061

1062+
def expand_dims(x, /, *, axis=0):
1063+
"""
1064+
Expands the shape of an array by inserting a new axis (dimension) of size
1065+
one at the position specified by ``axis``.
1066+
1067+
Parameters
1068+
----------
1069+
a : COO
1070+
Input COO array.
1071+
axis : int
1072+
Position in the expanded axes where the new axis is placed.
1073+
1074+
Returns
1075+
-------
1076+
result : COO
1077+
An expanded output COO array having the same data type as ``x``.
1078+
1079+
Examples
1080+
--------
1081+
>>> import sparse
1082+
>>> x = sparse.COO.from_numpy([[1, 0, 0, 0, 2, -3]])
1083+
>>> x.shape
1084+
(1, 6)
1085+
>>> y1 = sparse.expand_dims(x, axis=1)
1086+
>>> y1.shape
1087+
(1, 1, 6)
1088+
>>> y2 = sparse.expand_dims(x, axis=2)
1089+
>>> y2.shape
1090+
(1, 6, 1)
1091+
1092+
"""
1093+
from .core import COO
1094+
1095+
if isinstance(x, scipy.sparse.spmatrix):
1096+
x = COO.from_scipy_sparse(x)
1097+
elif not isinstance(x, SparseArray):
1098+
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1099+
elif not isinstance(x, COO):
1100+
x = x.asformat(COO)
1101+
1102+
if not isinstance(axis, int):
1103+
raise IndexError(f"Invalid axis position: {axis}")
1104+
1105+
axis = normalize_axis(axis, x.ndim + 1)
1106+
1107+
new_coords = np.insert(x.coords, obj=axis, values=np.zeros(x.nnz, dtype=np.intp), axis=0)
1108+
new_shape = list(x.shape)
1109+
new_shape.insert(axis, 1)
1110+
new_shape = tuple(new_shape)
1111+
1112+
return COO(
1113+
new_coords,
1114+
x.data,
1115+
shape=new_shape,
1116+
fill_value=x.fill_value,
1117+
)
1118+
1119+
1120+
def flip(x, /, *, axis=None):
1121+
"""
1122+
Reverses the order of elements in an array along the given axis.
1123+
1124+
The shape of the array is preserved.
1125+
1126+
Parameters
1127+
----------
1128+
a : COO
1129+
Input COO array.
1130+
axis : int or tuple of ints, optional
1131+
Axis (or axes) along which to flip. If ``axis`` is ``None``, the function must
1132+
flip all input array axes. If ``axis`` is negative, the function must count from
1133+
the last dimension. If provided more than one axis, the function must flip only
1134+
the specified axes. Default: ``None``.
1135+
1136+
Returns
1137+
-------
1138+
result : COO
1139+
An output array having the same data type and shape as ``x`` and whose elements,
1140+
relative to ``x``, are reordered.
1141+
1142+
"""
1143+
from .core import COO
1144+
1145+
if isinstance(x, scipy.sparse.spmatrix):
1146+
x = COO.from_scipy_sparse(x)
1147+
elif not isinstance(x, SparseArray):
1148+
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
1149+
elif not isinstance(x, COO):
1150+
x = x.asformat(COO)
1151+
1152+
if axis is None:
1153+
axis = range(x.ndim)
1154+
if not isinstance(axis, Iterable):
1155+
axis = (axis,)
1156+
1157+
new_coords = x.coords.copy()
1158+
for ax in axis:
1159+
new_coords[ax, :] = x.shape[ax] - 1 - x.coords[ax, :]
1160+
1161+
return COO(
1162+
new_coords,
1163+
x.data,
1164+
shape=x.shape,
1165+
fill_value=x.fill_value,
1166+
)
1167+
1168+
10621169
# Array API set functions
10631170

10641171

sparse/tests/test_coo.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,3 +1777,25 @@ def test_unique_values(self, arr, fill_value):
17771777
def test_input_validation(self, func):
17781778
with pytest.raises(ValueError, match=r"Input must be an instance of SparseArray"):
17791779
func(self.arr)
1780+
1781+
1782+
@pytest.mark.parametrize("axis", [-1, 0, 1, 2, 3])
1783+
def test_expand_dims(axis):
1784+
arr = np.arange(24).reshape((2, 3, 4))
1785+
s_arr = sparse.COO.from_numpy(arr)
1786+
1787+
result = sparse.expand_dims(s_arr, axis=axis)
1788+
expected = np.expand_dims(arr, axis=axis)
1789+
1790+
np.testing.assert_equal(result.todense(), expected)
1791+
1792+
1793+
@pytest.mark.parametrize("axis", [None, -1, 0, 1, 2, (0, 1), (2, 0)])
1794+
def test_flip(axis):
1795+
arr = np.arange(24).reshape((2, 3, 4))
1796+
s_arr = sparse.COO.from_numpy(arr)
1797+
1798+
result = sparse.flip(s_arr, axis=axis)
1799+
expected = np.flip(arr, axis=axis)
1800+
1801+
np.testing.assert_equal(result.todense(), expected)

sparse/tests/test_namespace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ def test_namespace():
5656
"empty_like",
5757
"equal",
5858
"exp",
59+
"expand_dims",
5960
"expm1",
6061
"eye",
6162
"finfo",
63+
"flip",
6264
"float16",
6365
"float32",
6466
"float64",
@@ -147,6 +149,7 @@ def test_namespace():
147149
"uint8",
148150
"unique_counts",
149151
"unique_values",
152+
"var",
150153
"where",
151154
"zeros",
152155
"zeros_like",

0 commit comments

Comments
 (0)