@@ -41,14 +41,21 @@ def unique_all(x: Array, /) -> UniqueAllResult:
41
41
42
42
See its docstring for more information.
43
43
"""
44
- res = np .unique (
44
+ values , indices , inverse_indices , counts = np .unique (
45
45
x ._array ,
46
46
return_counts = True ,
47
47
return_index = True ,
48
48
return_inverse = True ,
49
49
)
50
-
51
- return UniqueAllResult (* [Array ._new (i ) for i in res ])
50
+ # np.unique() flattens inverse indices, but they need to share x's shape
51
+ # See https://github.com/numpy/numpy/issues/20638
52
+ inverse_indices = inverse_indices .reshape (x .shape )
53
+ return UniqueAllResult (
54
+ Array ._new (values ),
55
+ Array ._new (indices ),
56
+ Array ._new (inverse_indices ),
57
+ Array ._new (counts ),
58
+ )
52
59
53
60
54
61
def unique_counts (x : Array , / ) -> UniqueCountsResult :
@@ -68,13 +75,16 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult:
68
75
69
76
See its docstring for more information.
70
77
"""
71
- res = np .unique (
78
+ values , inverse_indices = np .unique (
72
79
x ._array ,
73
80
return_counts = False ,
74
81
return_index = False ,
75
82
return_inverse = True ,
76
83
)
77
- return UniqueInverseResult (* [Array ._new (i ) for i in res ])
84
+ # np.unique() flattens inverse indices, but they need to share x's shape
85
+ # See https://github.com/numpy/numpy/issues/20638
86
+ inverse_indices = inverse_indices .reshape (x .shape )
87
+ return UniqueInverseResult (Array ._new (values ), Array ._new (inverse_indices ))
78
88
79
89
80
90
def unique_values (x : Array , / ) -> Array :
0 commit comments