Skip to content

Commit 018b098

Browse files
committed
Add testing for more use cases
1 parent a8daafd commit 018b098

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

dpnp/tests/test_special.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,42 @@
11
import math
22

33
import numpy
4-
from numpy.testing import assert_allclose
4+
import pytest
5+
import scipy
6+
from numpy.testing import assert_allclose, assert_almost_equal
57

68
import dpnp
79

10+
from .helper import (
11+
generate_random_numpy_array,
12+
get_all_dtypes,
13+
get_complex_dtypes,
14+
)
815

9-
def test_erf():
10-
a = numpy.linspace(2.0, 3.0, num=10)
11-
ia = dpnp.array(a)
1216

13-
expected = numpy.empty_like(a)
14-
for idx, val in enumerate(a):
15-
expected[idx] = math.erf(val)
17+
class TestErf:
1618

17-
result = dpnp.erf(ia)
19+
@pytest.mark.parametrize(
20+
"dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
21+
)
22+
def test_basic(self, dt):
23+
a = generate_random_numpy_array((2, 5), dtype=dt)
24+
ia = dpnp.array(a)
1825

19-
assert_allclose(result, expected)
26+
result = dpnp.special.erf(ia)
27+
expected = scipy.special.erf(a)
28+
assert_almost_equal(result, expected)
2029

30+
def test_nan_inf(self):
31+
a = numpy.array([numpy.nan, -numpy.inf, numpy.inf])
32+
ia = dpnp.array(a)
2133

22-
def test_erf_fallback():
23-
a = numpy.linspace(2.0, 3.0, num=10)
24-
dpa = dpnp.linspace(2.0, 3.0, num=10)
34+
result = dpnp.special.erf(ia)
35+
expected = scipy.special.erf(a)
36+
assert_allclose(result, expected)
2537

26-
expected = numpy.empty_like(a)
27-
for idx, val in enumerate(a):
28-
expected[idx] = math.erf(val)
29-
30-
result = dpnp.erf(dpa)
31-
32-
assert_allclose(result, expected)
38+
@pytest.mark.parametrize("dt", get_complex_dtypes())
39+
def test_complex(self, dt):
40+
x = dpnp.empty(5, dtype=dt)
41+
with pytest.raises(ValueError):
42+
dpnp.special.erf(x)

0 commit comments

Comments
 (0)