Skip to content

Commit f6512c5

Browse files
Update erf tests after migration
1 parent 13bdcfd commit f6512c5

File tree

5 files changed

+29
-10
lines changed

5 files changed

+29
-10
lines changed

dpnp/tests/test_special.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_basic(self, func, dt):
2525
a = generate_random_numpy_array((2, 5), dtype=dt)
2626
ia = dpnp.array(a)
2727

28-
result = getattr(dpnp.special, func)(ia)
28+
result = getattr(dpnp.scipy.special, func)(ia)
2929
expected = getattr(scipy.special, func)(a)
3030

3131
# scipy >= 0.16.0 returns float64, but dpnp returns float32
@@ -41,7 +41,7 @@ def test_nan_inf(self, func):
4141
a = numpy.array([numpy.nan, -numpy.inf, numpy.inf])
4242
ia = dpnp.array(a)
4343

44-
result = getattr(dpnp.special, func)(ia)
44+
result = getattr(dpnp.scipy.special, func)(ia)
4545
expected = getattr(scipy.special, func)(a)
4646
assert_allclose(result, expected)
4747

@@ -51,7 +51,7 @@ def test_zeros(self, func):
5151
a = numpy.array([0.0, -0.0])
5252
ia = dpnp.array(a)
5353

54-
result = getattr(dpnp.special, func)(ia)
54+
result = getattr(dpnp.scipy.special, func)(ia)
5555
expected = getattr(scipy.special, func)(a)
5656
assert_allclose(result, expected)
5757
assert_equal(dpnp.signbit(result), numpy.signbit(expected))
@@ -60,7 +60,7 @@ def test_zeros(self, func):
6060
def test_complex(self, func, dt):
6161
x = dpnp.empty(5, dtype=dt)
6262
with pytest.raises(ValueError):
63-
getattr(dpnp.special, func)(x)
63+
getattr(dpnp.scipy.special, func)(x)
6464

6565

6666
class TestConsistency:
@@ -72,11 +72,11 @@ def test_erfc(self):
7272
a = rng.pareto(0.02, n) * (2 * rng.randint(0, 2, n) - 1)
7373
a = dpnp.array(a)
7474

75-
res = 1 - dpnp.special.erf(a)
75+
res = 1 - dpnp.scipy.special.erf(a)
7676
mask = dpnp.isfinite(res)
7777
a = a[mask]
7878

7979
tol = 8 * dpnp.finfo(a).resolution
8080
assert dpnp.allclose(
81-
dpnp.special.erfc(a), res[mask], rtol=tol, atol=tol
81+
dpnp.scipy.special.erfc(a), res[mask], rtol=tol, atol=tol
8282
)

dpnp/tests/test_strides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def test_erf_funcs(func, stride):
175175
x = generate_random_numpy_array(10)
176176
a, ia = x[::stride], dpnp.array(x)[::stride]
177177

178-
result = getattr(dpnp.special, func)(ia)
178+
result = getattr(dpnp.scipy.special, func)(ia)
179179
expected = getattr(scipy.special, func)(a)
180180
assert_dtype_allclose(result, expected)
181181

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1493,7 +1493,7 @@ def test_interp(device, left, right, period):
14931493
def test_erf_funcs(func, device):
14941494
x = dpnp.linspace(-3, 3, num=5, device=device)
14951495

1496-
result = getattr(dpnp.special, func)(x)
1496+
result = getattr(dpnp.scipy.special, func)(x)
14971497
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14981498

14991499

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,7 @@ def test_choose(usm_type_x, usm_type_ind):
13001300
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13011301
def test_erf_funcs(func, usm_type):
13021302
x = dpnp.linspace(-3, 3, num=5, usm_type=usm_type)
1303-
y = getattr(dpnp.special, func)(x)
1303+
y = getattr(dpnp.scipy.special, func)(x)
13041304
assert x.usm_type == y.usm_type == usm_type
13051305

13061306

dpnp/tests/third_party/cupyx/scipy_tests/special_tests/test_erf.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

33
import unittest
4+
from functools import wraps
45

56
import numpy
67
import pytest
78

89
import dpnp as cupy
9-
import dpnp.special
10+
import dpnp.scipy.special
1011
from dpnp.tests.third_party.cupy import testing
1112

1213

@@ -16,6 +17,19 @@ def _boundary_inputs(boundary, rtol, atol):
1617
return [left, boundary, right]
1718

1819

20+
# Ensure `scp` exposes `.special` submodule:
21+
# use `dpnp.scipy` for DPNP, `scipy` for SciPy
22+
def resolve_special_scp(func):
23+
@wraps(func)
24+
def wrapper(*args, **kwargs):
25+
scp = kwargs.get("scp")
26+
if scp is not None and not hasattr(scp, "special"):
27+
kwargs["scp"] = getattr(scp, "scipy", scp) # dpnp -> dpnp.scipy
28+
return func(*args, **kwargs)
29+
30+
return wrapper
31+
32+
1933
@testing.with_requires("scipy")
2034
class _TestBase:
2135

@@ -53,14 +67,18 @@ class TestSpecial(unittest.TestCase, _TestBase):
5367
# scipy>=1.16: 'e -> d', which causes type_check=False
5468
@testing.for_dtypes(["e", "f", "d"])
5569
@testing.numpy_cupy_allclose(atol=1e-5, scipy_name="scp", type_check=False)
70+
@resolve_special_scp
5671
def check_unary(self, name, xp, scp, dtype):
5772
import scipy.special
5873

5974
a = testing.shaped_arange((2, 3), xp, dtype)
75+
# _scp = getattr(scp, "scipy", scp)
76+
# return getattr(_scp.special, name)(a)
6077
return getattr(scp.special, name)(a)
6178

6279
@testing.for_dtypes(["f", "d"])
6380
@testing.numpy_cupy_allclose(atol=1e-5, scipy_name="scp")
81+
@resolve_special_scp
6482
def check_unary_random(self, name, xp, scp, dtype, scale, offset):
6583
import scipy.special
6684

@@ -69,6 +87,7 @@ def check_unary_random(self, name, xp, scp, dtype, scale, offset):
6987

7088
@testing.for_dtypes(["f", "d"])
7189
@testing.numpy_cupy_allclose(atol=1e-5, scipy_name="scp")
90+
@resolve_special_scp
7291
def check_unary_boundary(self, name, xp, scp, dtype, boundary):
7392
import scipy.special
7493

0 commit comments

Comments
 (0)