Skip to content

Commit 5b3beb3

Browse files
committed
Wrap cov call in third party test to prorerly handle exception raised
1 parent 3505842 commit 5b3beb3

File tree

1 file changed

+33
-10
lines changed

1 file changed

+33
-10
lines changed

dpnp/tests/third_party/cupy/statistics_tests/test_correlation.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import unittest
23

34
import numpy
@@ -7,6 +8,11 @@
78
from dpnp.tests.helper import has_support_aspect64, numpy_version
89
from dpnp.tests.third_party.cupy import testing
910

11+
if numpy_version() >= "2.0.0":
12+
from numpy._core._exceptions import _UFuncOutputCastingError
13+
else:
14+
from numpy.core._exceptions import _UFuncOutputCastingError
15+
1016

1117
class TestCorrcoef(unittest.TestCase):
1218

@@ -60,6 +66,26 @@ def generate_input(self, a_shape, y_shape, xp, dtype):
6066
y = testing.shaped_arange(y_shape, xp, dtype)
6167
return a, y
6268

69+
def call_cov(self, xp, a, y, rowvar, bias, ddof, fweights, aweights, dtype):
70+
try:
71+
return xp.cov(
72+
a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
73+
)
74+
except ValueError as e:
75+
if (
76+
xp is cupy
77+
and "function 'subtract' does not support input types" in str(e)
78+
):
79+
# numpy raises _UFuncOutputCastingError
80+
raise _UFuncOutputCastingError(
81+
numpy.subtract,
82+
"same_kind",
83+
numpy.dtype("f8"),
84+
numpy.dtype(dtype),
85+
0,
86+
)
87+
raise
88+
6389
@testing.for_all_dtypes()
6490
@testing.numpy_cupy_allclose(
6591
type_check=has_support_aspect64(), accept_error=True
@@ -82,7 +108,9 @@ def check(
82108
fweights = name.asarray(fweights)
83109
if aweights is not None:
84110
aweights = name.asarray(aweights)
85-
return xp.cov(a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype)
111+
return self.call_cov(
112+
xp, a, y, rowvar, bias, ddof, fweights, aweights, dtype
113+
)
86114

87115
@testing.for_all_dtypes()
88116
@testing.numpy_cupy_allclose(accept_error=True)
@@ -100,15 +128,9 @@ def check_warns(
100128
):
101129
with testing.assert_warns(RuntimeWarning):
102130
a, y = self.generate_input(a_shape, y_shape, xp, dtype)
103-
try:
104-
res = xp.cov(
105-
a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
106-
)
107-
except ValueError as e:
108-
if xp is cupy: # dpnp raises ValueError(...)
109-
raise TypeError(e)
110-
raise
111-
return res
131+
return self.call_cov(
132+
xp, a, y, rowvar, bias, ddof, fweights, aweights, dtype
133+
)
112134

113135
@testing.for_all_dtypes()
114136
def check_raises(
@@ -129,6 +151,7 @@ def check_raises(
129151
a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
130152
)
131153

154+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
132155
@testing.with_requires("numpy>=2.2")
133156
def test_cov(self):
134157
self.check((2, 3))

0 commit comments

Comments
 (0)