Skip to content

Commit e6cc664

Browse files
committed
Add test with strides
1 parent db96676 commit e6cc664

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

dpnp/tests/test_strides.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy
44
import pytest
5+
import scipy
56
from numpy.testing import assert_array_equal
67

78
import dpnp
@@ -12,6 +13,7 @@
1213
get_all_dtypes,
1314
get_complex_dtypes,
1415
get_float_complex_dtypes,
16+
get_float_dtypes,
1517
get_integer_dtypes,
1618
get_integer_float_dtypes,
1719
numpy_version,
@@ -164,21 +166,14 @@ def test_reduce_hypot(dtype, stride):
164166
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
165167

166168

167-
@pytest.mark.parametrize(
168-
"dtype",
169-
get_integer_float_dtypes(
170-
no_unsigned=True, xfail_dtypes=[dpnp.int8, dpnp.int16]
171-
),
172-
)
173-
def test_erf(dtype):
174-
a = dpnp.linspace(-1, 1, num=10, dtype=dtype)
175-
b = a[::2]
176-
result = dpnp.erf(b)
177-
178-
expected = numpy.empty_like(b.asnumpy())
179-
for idx, val in enumerate(b):
180-
expected[idx] = math.erf(val)
169+
@pytest.mark.parametrize("dtype", get_float_dtypes(no_float16=False))
170+
@pytest.mark.parametrize("stride", [2, -1, -3])
171+
def test_erf(dtype, stride):
172+
x = generate_random_numpy_array(10, dtype=dtype)
173+
a, ia = x[::stride], dpnp.array(x)[::stride]
181174

175+
result = dpnp.special.erf(ia)
176+
expected = scipy.special.erf(a)
182177
assert_dtype_allclose(result, expected)
183178

184179

0 commit comments

Comments
 (0)