Skip to content

Commit 3972771

Browse files
committed
Add tests
1 parent afda756 commit 3972771

File tree

6 files changed

+23
-11
lines changed

6 files changed

+23
-11
lines changed

dpnp/backend/extensions/vm/erf_funcs.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,6 @@ void init_erf_funcs(py::module_ m)
190190
m, "_erfcx",
191191
"Call `erfcx` function from OneMKL VM library to compute the scaled "
192192
"complementary error function value of vector elements",
193-
impl::erfc_contig_dispatch_vector);
193+
impl::erfcx_contig_dispatch_vector);
194194
}
195195
} // namespace dpnp::extensions::vm

dpnp/tests/test_special.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@with_requires("scipy")
17-
@pytest.mark.parametrize("func", ["erf", "erfc"])
17+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
1818
class TestCommon:
1919
@pytest.mark.parametrize(
2020
"dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
@@ -65,18 +65,31 @@ def test_complex(self, func, dt):
6565

6666
class TestConsistency:
6767

68-
def test_erfc(self):
68+
def _check_variant_func(self, func, other_func, rtol, atol=0):
6969
# TODO: replace with dpnp.random.RandomState, once pareto is added
7070
rng = numpy.random.RandomState(1234)
7171
n = 10000
7272
a = rng.pareto(0.02, n) * (2 * rng.randint(0, 2, n) - 1)
7373
a = dpnp.array(a)
74+
a = a[::-1]
7475

75-
res = 1 - dpnp.special.erf(a)
76+
res = other_func(a)
7677
mask = dpnp.isfinite(res)
7778
a = a[mask]
7879

79-
tol = 8 * dpnp.finfo(a).resolution
80-
assert dpnp.allclose(
81-
dpnp.special.erfc(a), res[mask], rtol=tol, atol=tol
80+
assert dpnp.allclose(func(a), res[mask], rtol=rtol, atol=atol)
81+
82+
def test_erfc(self):
83+
self._check_variant_func(
84+
dpnp.special.erfc,
85+
lambda z: 1 - dpnp.special.erf(z),
86+
rtol=1e-12,
87+
atol=1e-14,
88+
)
89+
90+
def test_erfcx(self):
91+
self._check_variant_func(
92+
dpnp.special.erfcx,
93+
lambda z: dpnp.exp(z * z) * dpnp.special.erfc(z),
94+
rtol=1e-12,
8295
)

dpnp/tests/test_strides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_reduce_hypot(dtype, stride):
167167

168168

169169
@with_requires("scipy")
170-
@pytest.mark.parametrize("func", ["erf", "erfc"])
170+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
171171
@pytest.mark.parametrize("stride", [2, -1, -3])
172172
def test_erf_funcs(func, stride):
173173
import scipy.special

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,7 @@ def test_interp(device, left, right, period):
14891489
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14901490

14911491

1492-
@pytest.mark.parametrize("func", ["erf", "erfc"])
1492+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
14931493
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
14941494
def test_erf_funcs(func, device):
14951495
x = dpnp.linspace(-3, 3, num=5, device=device)

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ def test_choose(usm_type_x, usm_type_ind):
12961296
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind])
12971297

12981298

1299-
@pytest.mark.parametrize("func", ["erf", "erfc"])
1299+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
13001300
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13011301
def test_erf_funcs(func, usm_type):
13021302
x = dpnp.linspace(-3, 3, num=5, usm_type=usm_type)

dpnp/tests/third_party/cupyx/scipy_tests/special_tests/test_erf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def test_erf(self):
2727
def test_erfc(self):
2828
self.check_unary("erfc")
2929

30-
@pytest.mark.skip("erfcx() is not supported yet")
3130
@testing.with_requires("scipy>=1.16.0")
3231
def test_erfcx(self):
3332
self.check_unary("erfcx")

0 commit comments

Comments
 (0)