Skip to content

Commit 6568c6b

Browse files
czgdp1807sebergeric-wieser
authored
ENH: Adding keepdims to np.argmin,np.argmax (numpy#19211)
* keepdims added to np.argmin,np.argmax * Added release notes entry * tested for axis=None,keepdims=True * Apply suggestions from code review * updated interface * updated interface * API changed, implementation to be done * Added reshape approach to C implementation * buggy implementation without reshape * TestArgMax, TestArgMin fixed, comments added * Fixed for matrix * removed unrequired changes * fixed CI failure * fixed linting issue * PyArray_ArgMaxKeepdims now only modifies shape and strides * Comments added to PyArray_ArgMaxKeepdims * Updated implementation of PyArray_ArgMinKeepdims to match with PyArray_ArgMaxKeepdims * Testing complete for PyArray_ArgMinKeepdims and PyArray_ArgMaxKeepdims * PyArray_ArgMinWithKeepdims both keepdims=True and keepdims=False * matched implementation of PyArray_ArgMaxKeepdims and PyArray_ArgMinKeepdims * simplified implementation * Added missing comment * removed unwanted header * addressed all the reviews * Removing unwanted changes * fixed docs * Added new lines * restored annotations * parametrized test * Apply suggestions from code review Co-authored-by: Sebastian Berg <[email protected]> * keyword handling now done in np.argmin/np.argmax * corrected indendation * used with pytest.riases(ValueError) * fixed release notes * removed PyArray_ArgMaxWithKeepdims and PyArray_ArgMinWithKeepdims from public C-API * Apply suggestions from code review Co-authored-by: Eric Wieser <[email protected]> * Apply suggestions from code review Co-authored-by: Eric Wieser <[email protected]> Co-authored-by: Sebastian Berg <[email protected]> Co-authored-by: Eric Wieser <[email protected]>
1 parent d785aa3 commit 6568c6b

File tree

10 files changed

+254
-33
lines changed

10 files changed

+254
-33
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
``keepdims`` optional argument added to `numpy.argmin`, `numpy.argmax`
2+
----------------------------------------------------------------------
3+
4+
``keepdims`` argument is added to `numpy.argmin`, `numpy.argmax`.
5+
If set to ``True``, the axes which are reduced are left in the result as dimensions with size one.
6+
The resulting array has the same number of dimensions and will broadcast with the
7+
input array.

numpy/__init__.pyi

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1289,37 +1289,49 @@ class _ArrayOrScalarCommon:
12891289
self,
12901290
axis: None = ...,
12911291
out: None = ...,
1292+
*,
1293+
keepdims: L[False] = ...,
12921294
) -> intp: ...
12931295
@overload
12941296
def argmax(
12951297
self,
12961298
axis: _ShapeLike = ...,
12971299
out: None = ...,
1300+
*,
1301+
keepdims: bool = ...,
12981302
) -> Any: ...
12991303
@overload
13001304
def argmax(
13011305
self,
13021306
axis: Optional[_ShapeLike] = ...,
13031307
out: _NdArraySubClass = ...,
1308+
*,
1309+
keepdims: bool = ...,
13041310
) -> _NdArraySubClass: ...
13051311

13061312
@overload
13071313
def argmin(
13081314
self,
13091315
axis: None = ...,
13101316
out: None = ...,
1317+
*,
1318+
keepdims: L[False] = ...,
13111319
) -> intp: ...
13121320
@overload
13131321
def argmin(
13141322
self,
13151323
axis: _ShapeLike = ...,
1316-
out: None = ...,
1324+
out: None = ...,
1325+
*,
1326+
keepdims: bool = ...,
13171327
) -> Any: ...
13181328
@overload
13191329
def argmin(
13201330
self,
13211331
axis: Optional[_ShapeLike] = ...,
13221332
out: _NdArraySubClass = ...,
1333+
*,
1334+
keepdims: bool = ...,
13231335
) -> _NdArraySubClass: ...
13241336

13251337
def argsort(

numpy/core/fromnumeric.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,12 +1114,12 @@ def argsort(a, axis=-1, kind=None, order=None):
11141114
return _wrapfunc(a, 'argsort', axis=axis, kind=kind, order=order)
11151115

11161116

1117-
def _argmax_dispatcher(a, axis=None, out=None):
1117+
def _argmax_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):
11181118
return (a, out)
11191119

11201120

11211121
@array_function_dispatch(_argmax_dispatcher)
1122-
def argmax(a, axis=None, out=None):
1122+
def argmax(a, axis=None, out=None, *, keepdims=np._NoValue):
11231123
"""
11241124
Returns the indices of the maximum values along an axis.
11251125
@@ -1133,12 +1133,18 @@ def argmax(a, axis=None, out=None):
11331133
out : array, optional
11341134
If provided, the result will be inserted into this array. It should
11351135
be of the appropriate shape and dtype.
1136+
keepdims : bool, optional
1137+
If this is set to True, the axes which are reduced are left
1138+
in the result as dimensions with size one. With this option,
1139+
the result will broadcast correctly against the array.
11361140
11371141
Returns
11381142
-------
11391143
index_array : ndarray of ints
11401144
Array of indices into the array. It has the same shape as `a.shape`
1141-
with the dimension along `axis` removed.
1145+
with the dimension along `axis` removed. If `keepdims` is set to True,
1146+
then the size of `axis` will be 1 with the resulting array having same
1147+
shape as `a.shape`.
11421148
11431149
See Also
11441150
--------
@@ -1191,16 +1197,23 @@ def argmax(a, axis=None, out=None):
11911197
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
11921198
array([4, 3])
11931199
1200+
Setting `keepdims` to `True`,
1201+
1202+
>>> x = np.arange(24).reshape((2, 3, 4))
1203+
>>> res = np.argmax(x, axis=1, keepdims=True)
1204+
>>> res.shape
1205+
(2, 1, 4)
11941206
"""
1195-
return _wrapfunc(a, 'argmax', axis=axis, out=out)
1207+
kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
1208+
return _wrapfunc(a, 'argmax', axis=axis, out=out, **kwds)
11961209

11971210

1198-
def _argmin_dispatcher(a, axis=None, out=None):
1211+
def _argmin_dispatcher(a, axis=None, out=None, *, keepdims=np._NoValue):
11991212
return (a, out)
12001213

12011214

12021215
@array_function_dispatch(_argmin_dispatcher)
1203-
def argmin(a, axis=None, out=None):
1216+
def argmin(a, axis=None, out=None, *, keepdims=np._NoValue):
12041217
"""
12051218
Returns the indices of the minimum values along an axis.
12061219
@@ -1214,12 +1227,18 @@ def argmin(a, axis=None, out=None):
12141227
out : array, optional
12151228
If provided, the result will be inserted into this array. It should
12161229
be of the appropriate shape and dtype.
1230+
keepdims : bool, optional
1231+
If this is set to True, the axes which are reduced are left
1232+
in the result as dimensions with size one. With this option,
1233+
the result will broadcast correctly against the array.
12171234
12181235
Returns
12191236
-------
12201237
index_array : ndarray of ints
12211238
Array of indices into the array. It has the same shape as `a.shape`
1222-
with the dimension along `axis` removed.
1239+
with the dimension along `axis` removed. If `keepdims` is set to True,
1240+
then the size of `axis` will be 1 with the resulting array having same
1241+
shape as `a.shape`.
12231242
12241243
See Also
12251244
--------
@@ -1272,8 +1291,15 @@ def argmin(a, axis=None, out=None):
12721291
>>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
12731292
array([2, 0])
12741293
1294+
Setting `keepdims` to `True`,
1295+
1296+
>>> x = np.arange(24).reshape((2, 3, 4))
1297+
>>> res = np.argmin(x, axis=1, keepdims=True)
1298+
>>> res.shape
1299+
(2, 1, 4)
12751300
"""
1276-
return _wrapfunc(a, 'argmin', axis=axis, out=out)
1301+
kwds = {'keepdims': keepdims} if keepdims is not np._NoValue else {}
1302+
return _wrapfunc(a, 'argmin', axis=axis, out=out, **kwds)
12771303

12781304

12791305
def _searchsorted_dispatcher(a, v, side=None, sorter=None):

numpy/core/fromnumeric.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,25 +130,33 @@ def argmax(
130130
a: ArrayLike,
131131
axis: None = ...,
132132
out: Optional[ndarray] = ...,
133+
*,
134+
keepdims: Literal[False] = ...,
133135
) -> intp: ...
134136
@overload
135137
def argmax(
136138
a: ArrayLike,
137139
axis: Optional[int] = ...,
138140
out: Optional[ndarray] = ...,
141+
*,
142+
keepdims: bool = ...,
139143
) -> Any: ...
140144

141145
@overload
142146
def argmin(
143147
a: ArrayLike,
144148
axis: None = ...,
145149
out: Optional[ndarray] = ...,
150+
*,
151+
keepdims: Literal[False] = ...,
146152
) -> intp: ...
147153
@overload
148154
def argmin(
149155
a: ArrayLike,
150156
axis: Optional[int] = ...,
151157
out: Optional[ndarray] = ...,
158+
*,
159+
keepdims: bool = ...,
152160
) -> Any: ...
153161

154162
@overload

numpy/core/src/multiarray/calculation.c

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,24 @@ power_of_ten(int n)
3434
return ret;
3535
}
3636

37-
/*NUMPY_API
38-
* ArgMax
39-
*/
4037
NPY_NO_EXPORT PyObject *
41-
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
38+
_PyArray_ArgMaxWithKeepdims(PyArrayObject *op,
39+
int axis, PyArrayObject *out, int keepdims)
4240
{
4341
PyArrayObject *ap = NULL, *rp = NULL;
4442
PyArray_ArgFunc* arg_func;
4543
char *ip;
4644
npy_intp *rptr;
4745
npy_intp i, n, m;
4846
int elsize;
47+
// Keep a copy because axis changes via call to PyArray_CheckAxis
48+
int axis_copy = axis;
49+
npy_intp _shape_buf[NPY_MAXDIMS];
50+
npy_intp *out_shape;
51+
// Keep the number of dimensions and shape of
52+
// original array. Helps when `keepdims` is True.
53+
npy_intp* original_op_shape = PyArray_DIMS(op);
54+
int out_ndim = PyArray_NDIM(op);
4955
NPY_BEGIN_THREADS_DEF;
5056

5157
if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
@@ -86,6 +92,29 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
8692
if (ap == NULL) {
8793
return NULL;
8894
}
95+
96+
// Decides the shape of the output array.
97+
if (!keepdims) {
98+
out_ndim = PyArray_NDIM(ap) - 1;
99+
out_shape = PyArray_DIMS(ap);
100+
}
101+
else {
102+
out_shape = _shape_buf;
103+
if (axis_copy == NPY_MAXDIMS) {
104+
for (int i = 0; i < out_ndim; i++) {
105+
out_shape[i] = 1;
106+
}
107+
}
108+
else {
109+
/*
110+
* While `ap` may be transposed, we can ignore this for `out` because the
111+
* transpose only reorders the size 1 `axis` (not changing memory layout).
112+
*/
113+
memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
114+
out_shape[axis] = 1;
115+
}
116+
}
117+
89118
arg_func = PyArray_DESCR(ap)->f->argmax;
90119
if (arg_func == NULL) {
91120
PyErr_SetString(PyExc_TypeError,
@@ -103,16 +132,16 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
103132
if (!out) {
104133
rp = (PyArrayObject *)PyArray_NewFromDescr(
105134
Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
106-
PyArray_NDIM(ap) - 1, PyArray_DIMS(ap), NULL, NULL,
135+
out_ndim, out_shape, NULL, NULL,
107136
0, (PyObject *)ap);
108137
if (rp == NULL) {
109138
goto fail;
110139
}
111140
}
112141
else {
113-
if ((PyArray_NDIM(out) != PyArray_NDIM(ap) - 1) ||
114-
!PyArray_CompareLists(PyArray_DIMS(out), PyArray_DIMS(ap),
115-
PyArray_NDIM(out))) {
142+
if ((PyArray_NDIM(out) != out_ndim) ||
143+
!PyArray_CompareLists(PyArray_DIMS(out), out_shape,
144+
out_ndim)) {
116145
PyErr_SetString(PyExc_ValueError,
117146
"output array does not match result of np.argmax.");
118147
goto fail;
@@ -135,7 +164,7 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
135164
NPY_END_THREADS_DESCR(PyArray_DESCR(ap));
136165

137166
Py_DECREF(ap);
138-
/* Trigger the UPDATEIFCOPY/WRTIEBACKIFCOPY if necessary */
167+
/* Trigger the UPDATEIFCOPY/WRITEBACKIFCOPY if necessary */
139168
if (out != NULL && out != rp) {
140169
PyArray_ResolveWritebackIfCopy(rp);
141170
Py_DECREF(rp);
@@ -151,17 +180,32 @@ PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
151180
}
152181

153182
/*NUMPY_API
154-
* ArgMin
183+
* ArgMax
155184
*/
156185
NPY_NO_EXPORT PyObject *
157-
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
186+
PyArray_ArgMax(PyArrayObject *op, int axis, PyArrayObject *out)
187+
{
188+
return _PyArray_ArgMaxWithKeepdims(op, axis, out, 0);
189+
}
190+
191+
NPY_NO_EXPORT PyObject *
192+
_PyArray_ArgMinWithKeepdims(PyArrayObject *op,
193+
int axis, PyArrayObject *out, int keepdims)
158194
{
159195
PyArrayObject *ap = NULL, *rp = NULL;
160196
PyArray_ArgFunc* arg_func;
161197
char *ip;
162198
npy_intp *rptr;
163199
npy_intp i, n, m;
164200
int elsize;
201+
// Keep a copy because axis changes via call to PyArray_CheckAxis
202+
int axis_copy = axis;
203+
npy_intp _shape_buf[NPY_MAXDIMS];
204+
npy_intp *out_shape;
205+
// Keep the number of dimensions and shape of
206+
// original array. Helps when `keepdims` is True.
207+
npy_intp* original_op_shape = PyArray_DIMS(op);
208+
int out_ndim = PyArray_NDIM(op);
165209
NPY_BEGIN_THREADS_DEF;
166210

167211
if ((ap = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0)) == NULL) {
@@ -202,6 +246,27 @@ PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
202246
if (ap == NULL) {
203247
return NULL;
204248
}
249+
250+
// Decides the shape of the output array.
251+
if (!keepdims) {
252+
out_ndim = PyArray_NDIM(ap) - 1;
253+
out_shape = PyArray_DIMS(ap);
254+
} else {
255+
out_shape = _shape_buf;
256+
if (axis_copy == NPY_MAXDIMS) {
257+
for (int i = 0; i < out_ndim; i++) {
258+
out_shape[i] = 1;
259+
}
260+
} else {
261+
/*
262+
* While `ap` may be transposed, we can ignore this for `out` because the
263+
* transpose only reorders the size 1 `axis` (not changing memory layout).
264+
*/
265+
memcpy(out_shape, original_op_shape, out_ndim * sizeof(npy_intp));
266+
out_shape[axis] = 1;
267+
}
268+
}
269+
205270
arg_func = PyArray_DESCR(ap)->f->argmin;
206271
if (arg_func == NULL) {
207272
PyErr_SetString(PyExc_TypeError,
@@ -219,16 +284,15 @@ PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
219284
if (!out) {
220285
rp = (PyArrayObject *)PyArray_NewFromDescr(
221286
Py_TYPE(ap), PyArray_DescrFromType(NPY_INTP),
222-
PyArray_NDIM(ap) - 1, PyArray_DIMS(ap), NULL, NULL,
287+
out_ndim, out_shape, NULL, NULL,
223288
0, (PyObject *)ap);
224289
if (rp == NULL) {
225290
goto fail;
226291
}
227292
}
228293
else {
229-
if ((PyArray_NDIM(out) != PyArray_NDIM(ap) - 1) ||
230-
!PyArray_CompareLists(PyArray_DIMS(out), PyArray_DIMS(ap),
231-
PyArray_NDIM(out))) {
294+
if ((PyArray_NDIM(out) != out_ndim) ||
295+
!PyArray_CompareLists(PyArray_DIMS(out), out_shape, out_ndim)) {
232296
PyErr_SetString(PyExc_ValueError,
233297
"output array does not match result of np.argmin.");
234298
goto fail;
@@ -266,6 +330,15 @@ PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
266330
return NULL;
267331
}
268332

333+
/*NUMPY_API
334+
* ArgMin
335+
*/
336+
NPY_NO_EXPORT PyObject *
337+
PyArray_ArgMin(PyArrayObject *op, int axis, PyArrayObject *out)
338+
{
339+
return _PyArray_ArgMinWithKeepdims(op, axis, out, 0);
340+
}
341+
269342
/*NUMPY_API
270343
* Max
271344
*/

numpy/core/src/multiarray/calculation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
NPY_NO_EXPORT PyObject*
55
PyArray_ArgMax(PyArrayObject* self, int axis, PyArrayObject *out);
66

7+
NPY_NO_EXPORT PyObject*
8+
_PyArray_ArgMaxWithKeepdims(PyArrayObject* self, int axis, PyArrayObject *out, int keepdims);
9+
710
NPY_NO_EXPORT PyObject*
811
PyArray_ArgMin(PyArrayObject* self, int axis, PyArrayObject *out);
912

13+
NPY_NO_EXPORT PyObject*
14+
_PyArray_ArgMinWithKeepdims(PyArrayObject* self, int axis, PyArrayObject *out, int keepdims);
15+
1016
NPY_NO_EXPORT PyObject*
1117
PyArray_Max(PyArrayObject* self, int axis, PyArrayObject* out);
1218

0 commit comments

Comments
 (0)