Skip to content

Commit b4f2cdf

Browse files
committed
BUG: Return correctly shaped inverse indices in array_api
Specifically for `xp.unique_all()` and `xp.unique_inverse()` Original NumPy Commit: 8b967ff2e70afe3a1fd32e33b36f66f34c259139
1 parent c84e019 commit b4f2cdf

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

array_api_strict/_set_functions.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,21 @@ def unique_all(x: Array, /) -> UniqueAllResult:
4141
4242
See its docstring for more information.
4343
"""
44-
res = np.unique(
44+
values, indices, inverse_indices, counts = np.unique(
4545
x._array,
4646
return_counts=True,
4747
return_index=True,
4848
return_inverse=True,
4949
)
50-
51-
return UniqueAllResult(*[Array._new(i) for i in res])
50+
# np.unique() flattens inverse indices, but they need to share x's shape
51+
# See https://github.com/numpy/numpy/issues/20638
52+
inverse_indices = inverse_indices.reshape(x.shape)
53+
return UniqueAllResult(
54+
Array._new(values),
55+
Array._new(indices),
56+
Array._new(inverse_indices),
57+
Array._new(counts),
58+
)
5259

5360

5461
def unique_counts(x: Array, /) -> UniqueCountsResult:
@@ -68,13 +75,16 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
6875
6976
See its docstring for more information.
7077
"""
71-
res = np.unique(
78+
values, inverse_indices = np.unique(
7279
x._array,
7380
return_counts=False,
7481
return_index=False,
7582
return_inverse=True,
7683
)
77-
return UniqueInverseResult(*[Array._new(i) for i in res])
84+
# np.unique() flattens inverse indices, but they need to share x's shape
85+
# See https://github.com/numpy/numpy/issues/20638
86+
inverse_indices = inverse_indices.reshape(x.shape)
87+
return UniqueInverseResult(Array._new(values), Array._new(inverse_indices))
7888

7989

8090
def unique_values(x: Array, /) -> Array:

0 commit comments

Comments
 (0)