Skip to content

Commit 2b4482a

Browse files
committed
Add correction keyword to std/var functions
1 parent 63ce858 commit 2b4482a

File tree

7 files changed

+119
-7
lines changed

7 files changed

+119
-7
lines changed

.github/workflows/array-api-skips.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,3 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_clip
2929
# unexpected result is returned - unmute when dpctl-1986 is resolved
3030
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
3131
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
32-
33-
# missing 'correction' keyword argument
34-
array_api_tests/test_signatures.py::test_func_signature[std]
35-
array_api_tests/test_signatures.py::test_func_signature[var]

dpnp/dpnp_array.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,7 @@ def std(
17321732
*,
17331733
where=True,
17341734
mean=None,
1735+
correction=None,
17351736
):
17361737
"""
17371738
Returns the standard deviation of the array elements, along given axis.
@@ -1741,7 +1742,15 @@ def std(
17411742
"""
17421743

17431744
return dpnp.std(
1744-
self, axis, dtype, out, ddof, keepdims, where=where, mean=mean
1745+
self,
1746+
axis,
1747+
dtype,
1748+
out,
1749+
ddof,
1750+
keepdims,
1751+
where=where,
1752+
mean=mean,
1753+
correction=correction,
17451754
)
17461755

17471756
@property
@@ -1942,6 +1951,7 @@ def var(
19421951
*,
19431952
where=True,
19441953
mean=None,
1954+
correction=None,
19451955
):
19461956
"""
19471957
Returns the variance of the array elements, along given axis.
@@ -1951,7 +1961,15 @@ def var(
19511961
"""
19521962

19531963
return dpnp.var(
1954-
self, axis, dtype, out, ddof, keepdims, where=where, mean=mean
1964+
self,
1965+
axis,
1966+
dtype,
1967+
out,
1968+
ddof,
1969+
keepdims,
1970+
where=where,
1971+
mean=mean,
1972+
correction=correction,
19551973
)
19561974

19571975

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,7 @@ def nanstd(
966966
*,
967967
where=True,
968968
mean=None,
969+
correction=None,
969970
):
970971
"""
971972
Compute the standard deviation along the specified axis,
@@ -1018,6 +1019,12 @@ def nanstd(
10181019
10191020
Default: ``None``.
10201021
1022+
correction : {int, float}, optional
1023+
Array API compatible name for the `ddof` parameter. Only one of them
1024+
can be provided at the same time.
1025+
1026+
Default: ``None``.
1027+
10211028
Returns
10221029
-------
10231030
out : dpnp.ndarray
@@ -1094,6 +1101,7 @@ def nanstd(
10941101
keepdims=keepdims,
10951102
where=where,
10961103
mean=mean,
1104+
correction=correction,
10971105
)
10981106
return dpnp.sqrt(res, out=res)
10991107

@@ -1108,6 +1116,7 @@ def nanvar(
11081116
*,
11091117
where=True,
11101118
mean=None,
1119+
correction=None,
11111120
):
11121121
"""
11131122
Compute the variance along the specified axis, while ignoring NaNs.
@@ -1158,6 +1167,12 @@ def nanvar(
11581167
11591168
Default: ``None``.
11601169
1170+
correction : {int, float}, optional
1171+
Array API compatible name for the `ddof` parameter. Only one of them
1172+
can be provided at the same time.
1173+
1174+
Default: ``None``.
1175+
11611176
Returns
11621177
-------
11631178
out : dpnp.ndarray
@@ -1231,6 +1246,7 @@ def nanvar(
12311246
ddof=ddof,
12321247
keepdims=keepdims,
12331248
where=where,
1249+
correction=correction,
12341250
)
12351251

12361252
if dtype is not None:
@@ -1243,6 +1259,13 @@ def nanvar(
12431259
if not dpnp.issubdtype(out.dtype, dpnp.inexact):
12441260
raise TypeError("If input is inexact, then out must be inexact.")
12451261

1262+
if correction is not None:
1263+
if ddof != 0:
1264+
raise ValueError(
1265+
"ddof and correction can't be provided simultaneously."
1266+
)
1267+
ddof = correction
1268+
12461269
# Compute mean
12471270
cnt = dpnp.sum(
12481271
~mask, axis=axis, dtype=dpnp.intp, keepdims=True, where=where

dpnp/dpnp_iface_statistics.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,7 @@ def std(
12011201
*,
12021202
where=True,
12031203
mean=None,
1204+
correction=None,
12041205
):
12051206
r"""
12061207
Compute the standard deviation along the specified axis.
@@ -1253,6 +1254,12 @@ def std(
12531254
12541255
Default: ``None``.
12551256
1257+
correction : {int, float}, optional
1258+
Array API compatible name for the `ddof` parameter. Only one of them
1259+
can be provided at the same time.
1260+
1261+
Default: ``None``.
1262+
12561263
Returns
12571264
-------
12581265
out : dpnp.ndarray
@@ -1344,6 +1351,13 @@ def std(
13441351
dpnp.check_supported_arrays_type(a)
13451352
dpnp.check_limitations(where=where)
13461353

1354+
if correction is not None:
1355+
if ddof != 0:
1356+
raise ValueError(
1357+
"ddof and correction can't be provided simultaneously."
1358+
)
1359+
ddof = correction
1360+
13471361
if not isinstance(ddof, (int, float)):
13481362
raise TypeError(
13491363
f"An integer or float is required, but got {type(ddof)}"
@@ -1382,6 +1396,7 @@ def var(
13821396
*,
13831397
where=True,
13841398
mean=None,
1399+
correction=None,
13851400
):
13861401
r"""
13871402
Compute the variance along the specified axis.
@@ -1433,6 +1448,12 @@ def var(
14331448
14341449
Default: ``None``.
14351450
1451+
correction : {int, float}, optional
1452+
Array API compatible name for the `ddof` parameter. Only one of them
1453+
can be provided at the same time.
1454+
1455+
Default: ``None``.
1456+
14361457
Returns
14371458
-------
14381459
out : dpnp.ndarray
@@ -1518,6 +1539,13 @@ def var(
15181539
dpnp.check_supported_arrays_type(a)
15191540
dpnp.check_limitations(where=where)
15201541

1542+
if correction is not None:
1543+
if ddof != 0:
1544+
raise ValueError(
1545+
"ddof and correction can't be provided simultaneously."
1546+
)
1547+
ddof = correction
1548+
15211549
if not isinstance(ddof, (int, float)):
15221550
raise TypeError(
15231551
f"An integer or float is required, but got {type(ddof)}"

dpnp/tests/test_nanfunctions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
assert_array_equal,
1010
assert_equal,
1111
assert_raises,
12+
assert_raises_regex,
1213
)
1314

1415
import dpnp
@@ -24,6 +25,7 @@
2425
numpy_version,
2526
)
2627
from .third_party.cupy import testing
28+
from .third_party.cupy.testing import with_requires
2729

2830

2931
class TestNanArgmaxNanArgmin:
@@ -750,6 +752,28 @@ def test_mean_keyword(self, dtype, axis, keepdims):
750752
)
751753
assert_dtype_allclose(result, expected)
752754

755+
@with_requires("numpy>=2.0")
756+
def test_correction(self):
757+
a = numpy.array([127, 39, 93, 87, 46])
758+
ia = dpnp.array(a)
759+
760+
expected = getattr(numpy, self.func)(a, correction=0.5)
761+
result = getattr(dpnp, self.func)(ia, correction=0.5)
762+
assert_dtype_allclose(result, expected)
763+
764+
@with_requires("numpy>=2.0")
765+
@pytest.mark.parametrize("xp", [dpnp, numpy])
766+
def test_both_ddof_correction_are_set(self, xp):
767+
a = xp.array(5)
768+
769+
err_msg = "ddof and correction can't be provided simultaneously."
770+
771+
with assert_raises_regex(ValueError, err_msg):
772+
getattr(xp, self.func)(a, ddof=0.5, correction=0.5)
773+
774+
with assert_raises_regex(ValueError, err_msg):
775+
getattr(xp, self.func)(a, ddof=1, correction=0)
776+
753777
def test_error(self):
754778
ia = dpnp.arange(5, dtype=dpnp.float32)
755779
ia[0] = dpnp.nan

dpnp/tests/test_statistics.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import (
66
assert_allclose,
77
assert_array_equal,
8+
assert_raises_regex,
89
)
910

1011
import dpnp
@@ -763,6 +764,28 @@ def test_scalar(self):
763764
result = getattr(ia, self.func)()
764765
assert_dtype_allclose(result, expected)
765766

767+
@with_requires("numpy>=2.0")
768+
def test_correction(self):
769+
a = numpy.array([1, -1, 1, -1])
770+
ia = dpnp.array(a)
771+
772+
expected = getattr(a, self.func)(correction=1)
773+
result = getattr(ia, self.func)(correction=1)
774+
assert_dtype_allclose(result, expected)
775+
776+
@with_requires("numpy>=2.0")
777+
@pytest.mark.parametrize("xp", [dpnp, numpy])
778+
def test_both_ddof_correction_are_set(self, xp):
779+
a = xp.array([1, -1, 1, -1])
780+
781+
err_msg = "ddof and correction can't be provided simultaneously."
782+
783+
with assert_raises_regex(ValueError, err_msg):
784+
getattr(xp, self.func)(a, ddof=1, correction=0)
785+
786+
with assert_raises_regex(ValueError, err_msg):
787+
getattr(xp, self.func)(a, ddof=1, correction=1)
788+
766789
def test_error(self):
767790
ia = dpnp.arange(5)
768791
# where keyword is not implemented

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ exclude-protected = ["_create_from_usm_ndarray"]
1818
max-args = 11
1919
max-positional-arguments = 9
2020
max-locals = 30
21-
max-branches = 15
21+
max-branches = 16
2222
max-returns = 8
2323

2424
[tool.pylint.format]

0 commit comments

Comments
 (0)