Skip to content

Commit 8cf1f7d

Browse files
ENH: Let numpy.size accept multiple axes. (numpy#29240)
* Let numpy.size accept multiple axes. * Apply suggestions from code review --------- Co-authored-by: Sebastian Berg <[email protected]>
1 parent 26f1c91 commit 8cf1f7d

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* Let ``np.size`` accept multiple axes.

numpy/_core/fromnumeric.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
"""
44
import functools
5+
import math
56
import types
67
import warnings
78

@@ -3569,10 +3570,13 @@ def size(a, axis=None):
35693570
----------
35703571
a : array_like
35713572
Input data.
3572-
axis : int, optional
3573-
Axis along which the elements are counted. By default, give
3573+
axis : None or int or tuple of ints, optional
3574+
Axis or axes along which the elements are counted. By default, give
35743575
the total number of elements.
35753576
3577+
.. versionchanged:: 2.4
3578+
Extended to accept multiple axes.
3579+
35763580
Returns
35773581
-------
35783582
element_count : int
@@ -3590,10 +3594,12 @@ def size(a, axis=None):
35903594
>>> a = np.array([[1,2,3],[4,5,6]])
35913595
>>> np.size(a)
35923596
6
3593-
>>> np.size(a,1)
3597+
>>> np.size(a,axis=1)
35943598
3
3595-
>>> np.size(a,0)
3599+
>>> np.size(a,axis=0)
35963600
2
3601+
>>> np.size(a,axis=(0,1))
3602+
6
35973603
35983604
"""
35993605
if axis is None:
@@ -3602,10 +3608,10 @@ def size(a, axis=None):
36023608
except AttributeError:
36033609
return asarray(a).size
36043610
else:
3605-
try:
3606-
return a.shape[axis]
3607-
except AttributeError:
3608-
return asarray(a).shape[axis]
3611+
_shape = shape(a)
3612+
from .numeric import normalize_axis_tuple
3613+
axis = normalize_axis_tuple(axis, len(_shape), allow_duplicate=False)
3614+
return math.prod(_shape[ax] for ax in axis)
36093615

36103616

36113617
def _round_dispatcher(a, decimals=None, out=None):

numpy/_core/fromnumeric.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ def cumulative_prod(
13971397

13981398
def ndim(a: ArrayLike) -> int: ...
13991399

1400-
def size(a: ArrayLike, axis: int | None = ...) -> int: ...
1400+
def size(a: ArrayLike, axis: int | tuple[int, ...] | None = ...) -> int: ...
14011401

14021402
@overload
14031403
def around(

numpy/_core/tests/test_numeric.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ def test_size(self):
291291
assert_(np.size(A) == 6)
292292
assert_(np.size(A, 0) == 2)
293293
assert_(np.size(A, 1) == 3)
294+
assert_(np.size(A, ()) == 1)
295+
assert_(np.size(A, (0,)) == 2)
296+
assert_(np.size(A, (1,)) == 3)
297+
assert_(np.size(A, (0, 1)) == 6)
294298

295299
def test_squeeze(self):
296300
A = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]

0 commit comments

Comments
 (0)