|
14 | 14 |
|
15 | 15 |
|
16 | 16 | @with_requires("scipy")
|
17 |
| -@pytest.mark.parametrize("func", ["erf", "erfc"]) |
| 17 | +@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"]) |
18 | 18 | class TestCommon:
|
19 | 19 | @pytest.mark.parametrize(
|
20 | 20 | "dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
|
@@ -65,18 +65,31 @@ def test_complex(self, func, dt):
|
65 | 65 |
|
66 | 66 | class TestConsistency:
|
67 | 67 |
|
68 |
| - def test_erfc(self): |
| 68 | + def _check_variant_func(self, func, other_func, rtol, atol=0): |
69 | 69 | # TODO: replace with dpnp.random.RandomState, once pareto is added
|
70 | 70 | rng = numpy.random.RandomState(1234)
|
71 | 71 | n = 10000
|
72 | 72 | a = rng.pareto(0.02, n) * (2 * rng.randint(0, 2, n) - 1)
|
73 | 73 | a = dpnp.array(a)
|
| 74 | + a = a[::-1] |
74 | 75 |
|
75 |
| - res = 1 - dpnp.special.erf(a) |
| 76 | + res = other_func(a) |
76 | 77 | mask = dpnp.isfinite(res)
|
77 | 78 | a = a[mask]
|
78 | 79 |
|
79 |
| - tol = 8 * dpnp.finfo(a).resolution |
80 |
| - assert dpnp.allclose( |
81 |
| - dpnp.special.erfc(a), res[mask], rtol=tol, atol=tol |
| 80 | + assert dpnp.allclose(func(a), res[mask], rtol=rtol, atol=atol) |
| 81 | + |
| 82 | + def test_erfc(self): |
| 83 | + self._check_variant_func( |
| 84 | + dpnp.special.erfc, |
| 85 | + lambda z: 1 - dpnp.special.erf(z), |
| 86 | + rtol=1e-12, |
| 87 | + atol=1e-14, |
| 88 | + ) |
| 89 | + |
| 90 | + def test_erfcx(self): |
| 91 | + self._check_variant_func( |
| 92 | + dpnp.special.erfcx, |
| 93 | + lambda z: dpnp.exp(z * z) * dpnp.special.erfc(z), |
| 94 | + rtol=1e-12, |
82 | 95 | )
|
0 commit comments