Skip to content

Commit 963acb4

Browse files
authored
Merge branch 'master' into enable-test-dcptl-2121
2 parents 351ede2 + c67cafb commit 963acb4

File tree

7 files changed

+42
-25
lines changed

7 files changed

+42
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3030
* Changed th order of individual FFTs over `axes` for `dpnp.fft.irfftn` to be in forward order [#2524](https://github.com/IntelPython/dpnp/pull/2524)
3131
* Replaced the use of `numpy.testing.suppress_warnings` with appropriate calls from the warnings module [#2529](https://github.com/IntelPython/dpnp/pull/2529)
3232
* Improved documentations of `dpnp.ndarray` class and added a page with description of supported constants [#2422](https://github.com/IntelPython/dpnp/pull/2422)
33+
* Updated `dpnp.size` to accept tuple of ints for `axes` argument [#2536](https://github.com/IntelPython/dpnp/pull/2536)
3334

3435
### Deprecated
3536

conda-recipe/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{% set max_compiler_and_mkl_version = environ.get("MAX_BUILD_CMPL_MKL_VERSION", "2026.0a0") %}
22
{% set required_compiler_and_mkl_version = "2025.0" %}
3-
{% set required_dpctl_version = "0.20.0*" %}
3+
{% set required_dpctl_version = "0.21.0*" %}
44

55
{% set pyproject = load_file_data('pyproject.toml') %}
66
{% set py_build_deps = pyproject.get('build-system', {}).get('requires', []) %}

dpnp/dpnp_iface_manipulation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
import dpctl
4747
import dpctl.tensor as dpt
4848
import numpy
49-
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_index
49+
from dpctl.tensor._numpy_helper import (
50+
AxisError,
51+
normalize_axis_index,
52+
normalize_axis_tuple,
53+
)
5054

5155
import dpnp
5256

@@ -3528,8 +3532,8 @@ def size(a, axis=None):
35283532
----------
35293533
a : array_like
35303534
Input data.
3531-
axis : {None, int}, optional
3532-
Axis along which the elements are counted.
3535+
axis : {None, int, tuple of ints}, optional
3536+
Axis or axes along which the elements are counted.
35333537
By default, give the total number of elements.
35343538
35353539
Default: ``None``.
@@ -3551,23 +3555,21 @@ def size(a, axis=None):
35513555
>>> a = [[1, 2, 3], [4, 5, 6]]
35523556
>>> np.size(a)
35533557
6
3554-
>>> np.size(a, 1)
3558+
>>> np.size(a, axis=1)
35553559
3
3556-
>>> np.size(a, 0)
3560+
>>> np.size(a, axis=0)
35573561
2
3558-
3559-
>>> a = np.asarray(a)
3560-
>>> np.size(a)
3562+
>>> np.size(a, axis=(0, 1))
35613563
6
3562-
>>> np.size(a, 1)
3563-
3
35643564
35653565
"""
35663566

35673567
if dpnp.is_supported_array_type(a):
35683568
if axis is None:
35693569
return a.size
3570-
return a.shape[axis]
3570+
_shape = a.shape
3571+
_axis = normalize_axis_tuple(axis, a.ndim)
3572+
return math.prod(_shape[ax] for ax in _axis)
35713573

35723574
return numpy.size(a, axis)
35733575

dpnp/tests/test_manipulation.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,32 @@ def test_ndim():
7474
assert dpnp.ndim(ia) == exp
7575

7676

77-
def test_size():
78-
a = [[1, 2, 3], [4, 5, 6]]
79-
ia = dpnp.array(a)
77+
class TestSize:
78+
def test_size(self):
79+
a = [[1, 2, 3], [4, 5, 6]]
80+
ia = dpnp.array(a)
81+
82+
exp = numpy.size(a)
83+
assert ia.size == exp
84+
assert dpnp.size(a) == exp
85+
assert dpnp.size(ia) == exp
8086

81-
exp = numpy.size(a)
82-
assert ia.size == exp
83-
assert dpnp.size(a) == exp
84-
assert dpnp.size(ia) == exp
87+
exp = numpy.size(a, 0)
88+
assert dpnp.size(a, 0) == exp
89+
assert dpnp.size(ia, 0) == exp
90+
91+
assert dpnp.size(ia, 1) == numpy.size(a, 1)
92+
93+
# TODO: include commented code in the test when numpy-2.4 is released
94+
# @testing.with_requires("numpy>=2.4")
95+
def test_size_tuple(self):
96+
a = [[1, 2, 3], [4, 5, 6]]
97+
ia = dpnp.array(a)
8598

86-
exp = numpy.size(a, 0)
87-
assert dpnp.size(a, 0) == exp
88-
assert dpnp.size(ia, 0) == exp
99+
assert dpnp.size(ia, ()) == 1 # numpy.size(a, ())
100+
assert dpnp.size(ia, (0,)) == 2 # numpy.size(a, (0,))
101+
assert dpnp.size(ia, (1,)) == 3 # numpy.size(a, (1,))
102+
assert dpnp.size(ia, (0, 1)) == 6 # numpy.size(a, (0, 1))
89103

90104

91105
class TestAppend:

environments/dpctl_pkg.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
--index-url https://pypi.anaconda.org/dppy/label/dev/simple
2-
dpctl>=0.20.0dev0
2+
dpctl>=0.21.0dev0

environments/dpctl_pkg.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ name: Install dpctl package
22
channels:
33
- dppy/label/dev
44
dependencies:
5-
- dpctl>=0.20.0dev0
5+
- dpctl>=0.21.0dev0

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ dependencies = [
5050
# "dpcpp-cpp-rt>=0.59.0",
5151
# "intel-cmplr-lib-rt>=0.59.0"
5252
# WARNING: use the latest dpctl dev version, otherwise stable w/f will fail
53-
"dpctl>=0.20.0dev0",
53+
"dpctl>=0.21.0dev0",
5454
"numpy>=1.25.0"
5555
]
5656
description = "Data Parallel Extension for NumPy"

0 commit comments

Comments
 (0)