Skip to content

Commit f5f5841

Browse files
committed
Add tests to cover
1 parent dee6047 commit f5f5841

File tree

5 files changed

+46
-28
lines changed

5 files changed

+46
-28
lines changed

dpnp/tests/test_special.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414

1515

1616
@with_requires("scipy")
17-
class TestErf:
17+
@pytest.mark.parametrize("func", ["erf", "erfc"])
18+
class TestCommon:
1819
@pytest.mark.parametrize(
1920
"dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
2021
)
21-
def test_basic(self, dt):
22+
def test_basic(self, func, dt):
2223
import scipy.special
2324

2425
a = generate_random_numpy_array((2, 5), dtype=dt)
2526
ia = dpnp.array(a)
2627

27-
result = dpnp.special.erf(ia)
28-
expected = scipy.special.erf(a)
28+
result = getattr(dpnp.special, func)(ia)
29+
expected = getattr(scipy.special, func)(a)
2930

3031
# scipy >= 0.16.0 returns float64, but dpnp returns float32
3132
to_float32 = dt in (dpnp.bool, dpnp.float16)
@@ -34,29 +35,48 @@ def test_basic(self, dt):
3435
result, expected, check_only_type_kind=only_type_kind
3536
)
3637

37-
def test_nan_inf(self):
38+
def test_nan_inf(self, func):
3839
import scipy.special
3940

4041
a = numpy.array([numpy.nan, -numpy.inf, numpy.inf])
4142
ia = dpnp.array(a)
4243

43-
result = dpnp.special.erf(ia)
44-
expected = scipy.special.erf(a)
44+
result = getattr(dpnp.special, func)(ia)
45+
expected = getattr(scipy.special, func)(a)
4546
assert_allclose(result, expected)
4647

47-
def test_zeros(self):
48+
def test_zeros(self, func):
4849
import scipy.special
4950

5051
a = numpy.array([0.0, -0.0])
5152
ia = dpnp.array(a)
5253

53-
result = dpnp.special.erf(ia)
54-
expected = scipy.special.erf(a)
54+
result = getattr(dpnp.special, func)(ia)
55+
expected = getattr(scipy.special, func)(a)
5556
assert_allclose(result, expected)
5657
assert_equal(dpnp.signbit(result), numpy.signbit(expected))
5758

5859
@pytest.mark.parametrize("dt", get_complex_dtypes())
59-
def test_complex(self, dt):
60+
def test_complex(self, func, dt):
6061
x = dpnp.empty(5, dtype=dt)
6162
with pytest.raises(ValueError):
62-
dpnp.special.erf(x)
63+
getattr(dpnp.special, func)(x)
64+
65+
66+
class TestConsistency:
67+
68+
def test_erfc(self):
69+
# TODO: replace with dpnp.random.RandomState, once pareto is added
70+
rng = numpy.random.RandomState(1234)
71+
n = 10000
72+
a = rng.pareto(0.02, n) * (2 * rng.randint(0, 2, n) - 1)
73+
a = dpnp.array(a)
74+
75+
res = 1 - dpnp.special.erf(a)
76+
mask = dpnp.isfinite(res)
77+
a = a[mask]
78+
79+
tol = 8 * dpnp.finfo(a).resolution
80+
assert dpnp.allclose(
81+
dpnp.special.erfc(a), res[mask], rtol=tol, atol=tol
82+
)

dpnp/tests/test_strides.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
get_integer_float_dtypes,
1818
numpy_version,
1919
)
20-
from .third_party.cupy.testing import installed, with_requires
20+
from .third_party.cupy.testing import with_requires
2121

2222

2323
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
@@ -167,20 +167,17 @@ def test_reduce_hypot(dtype, stride):
167167

168168

169169
@with_requires("scipy")
170-
@pytest.mark.parametrize("dtype", get_float_dtypes(no_float16=False))
170+
@pytest.mark.parametrize("func", ["erf", "erfc"])
171171
@pytest.mark.parametrize("stride", [2, -1, -3])
172-
def test_erf(dtype, stride):
172+
def test_erf_funcs(func, stride):
173173
import scipy.special
174174

175-
x = generate_random_numpy_array(10, dtype=dtype)
175+
x = generate_random_numpy_array(10)
176176
a, ia = x[::stride], dpnp.array(x)[::stride]
177177

178-
result = dpnp.special.erf(ia)
179-
expected = scipy.special.erf(a)
180-
181-
# scipy >= 0.16.0 returns float64, but dpnp returns float32
182-
only_type_kind = installed("scipy>=0.16.0") and (dtype == dpnp.float16)
183-
assert_dtype_allclose(result, expected, check_only_type_kind=only_type_kind)
178+
result = getattr(dpnp.special, func)(ia)
179+
expected = getattr(scipy.special, func)(a)
180+
assert_dtype_allclose(result, expected)
184181

185182

186183
@pytest.mark.filterwarnings("ignore::RuntimeWarning")

dpnp/tests/test_sycl_queue.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,11 +1488,12 @@ def test_interp(device, left, right, period):
14881488
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14891489

14901490

1491+
@pytest.mark.parametrize("func", ["erf", "erfc"])
14911492
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1492-
def test_erf(device):
1493+
def test_erf_funcs(func, device):
14931494
x = dpnp.linspace(-3, 3, num=5, device=device)
14941495

1495-
result = dpnp.special.erf(x)
1496+
result = getattr(dpnp.special, func)(x)
14961497
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14971498

14981499

dpnp/tests/test_usm_type.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,10 +1296,11 @@ def test_choose(usm_type_x, usm_type_ind):
12961296
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
12971297

12981298

1299+
@pytest.mark.parametrize("func", ["erf", "erfc"])
12991300
@pytest.mark.parametrize("usm_type", list_of_usm_types)
1300-
def test_erf(usm_type):
1301+
def test_erf_funcs(func, usm_type):
13011302
x = dpnp.linspace(-3, 3, num=5, usm_type=usm_type)
1302-
y = dpnp.special.erf(x)
1303+
y = getattr(dpnp.special, func)(x)
13031304
assert x.usm_type == y.usm_type == usm_type
13041305

13051306

dpnp/tests/third_party/cupyx/scipy_tests/special_tests/test_erf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ def _boundary_inputs(boundary, rtol, atol):
1919
@testing.with_requires("scipy")
2020
class _TestBase:
2121

22-
# @testing.with_requires('scipy>=1.16.0')
22+
@testing.with_requires("scipy>=1.16.0")
2323
def test_erf(self):
2424
self.check_unary("erf")
2525

26-
@pytest.mark.skip("erfc() is not supported yet")
2726
@testing.with_requires("scipy>=1.16.0")
2827
def test_erfc(self):
2928
self.check_unary("erfc")

0 commit comments

Comments
 (0)