Skip to content

Commit 3505842

Browse files
committed
Update third party tests
1 parent 26cbbdb commit 3505842

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

dpnp/tests/third_party/cupy/statistics_tests/test_correlation.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,7 @@ def check(
8282
fweights = name.asarray(fweights)
8383
if aweights is not None:
8484
aweights = name.asarray(aweights)
85-
# print(type(fweights))
86-
# return xp.cov(a, y, rowvar, bias, ddof,
87-
# fweights, aweights, dtype=dtype)
88-
return xp.cov(a, y, rowvar, bias, ddof, fweights, aweights)
85+
return xp.cov(a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype)
8986

9087
@testing.for_all_dtypes()
9188
@testing.numpy_cupy_allclose(accept_error=True)
@@ -103,9 +100,15 @@ def check_warns(
103100
):
104101
with testing.assert_warns(RuntimeWarning):
105102
a, y = self.generate_input(a_shape, y_shape, xp, dtype)
106-
return xp.cov(
107-
a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
108-
)
103+
try:
104+
res = xp.cov(
105+
a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
106+
)
107+
except ValueError as e:
108+
if xp is cupy: # dpnp raises ValueError(...)
109+
raise TypeError(e)
110+
raise
111+
return res
109112

110113
@testing.for_all_dtypes()
111114
def check_raises(
@@ -126,15 +129,11 @@ def check_raises(
126129
a, y, rowvar, bias, ddof, fweights, aweights, dtype=dtype
127130
)
128131

129-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
130-
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
132+
@testing.with_requires("numpy>=2.2")
131133
def test_cov(self):
132134
self.check((2, 3))
133135
self.check((2,), (2,))
134-
if numpy_version() >= "2.2.0":
135-
# TODO: enable once numpy 2.2 resolves ValueError
136-
# self.check((1, 3), (1, 3), rowvar=False)
137-
self.check((1, 3), (1, 1), rowvar=False) # TODO: remove
136+
self.check((1, 3), (1, 3), rowvar=False)
138137
self.check((2, 3), (2, 3), rowvar=False)
139138
self.check((2, 3), bias=True)
140139
self.check((2, 3), ddof=2)
@@ -144,12 +143,10 @@ def test_cov(self):
144143
self.check((1, 3), bias=True, aweights=(1.0, 4.0, 1.0))
145144
self.check((1, 3), fweights=(1, 4, 1), aweights=(1.0, 4.0, 1.0))
146145

147-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
148146
def test_cov_warns(self):
149147
self.check_warns((2, 3), ddof=3)
150148
self.check_warns((2, 3), ddof=4)
151149

152-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
153150
def test_cov_raises(self):
154151
self.check_raises((2, 3), ddof=1.2)
155152
self.check_raises((3, 4, 2))

0 commit comments

Comments
 (0)