Skip to content

Commit c261a83

Browse files
committed
Skip erf tests when no scipy installed
1 parent 607c87f commit c261a83

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

dpnp/tests/helper.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import dpctl
44
import numpy
55
import pytest
6-
import scipy
76
from numpy.testing import assert_allclose, assert_array_equal
87

98
import dpnp
@@ -499,7 +498,3 @@ def requires_intel_mkl_version(version):
499498

500499
build_deps = numpy.show_config(mode="dicts")["Build Dependencies"]
501500
return build_deps["blas"]["version"] >= version
502-
503-
504-
def scipy_version():
505-
return scipy.__version__

dpnp/tests/test_special.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy
22
import pytest
3-
import scipy
43
from numpy.testing import assert_allclose
54

65
import dpnp
@@ -10,16 +9,18 @@
109
generate_random_numpy_array,
1110
get_all_dtypes,
1211
get_complex_dtypes,
13-
scipy_version,
1412
)
13+
from .third_party.cupy.testing import installed, with_requires
1514

1615

16+
@with_requires("scipy")
1717
class TestErf:
18-
1918
@pytest.mark.parametrize(
2019
"dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
2120
)
2221
def test_basic(self, dt):
22+
import scipy.special
23+
2324
a = generate_random_numpy_array((2, 5), dtype=dt)
2425
ia = dpnp.array(a)
2526

@@ -28,12 +29,14 @@ def test_basic(self, dt):
2829

2930
# scipy >= 0.16.0 returns float64, but dpnp returns float32
3031
to_float32 = dt in (dpnp.bool, dpnp.float16)
31-
only_type_kind = scipy_version() >= "0.16.0" and to_float32
32+
only_type_kind = installed("scipy>=0.16.0") and to_float32
3233
assert_dtype_allclose(
3334
result, expected, check_only_type_kind=only_type_kind
3435
)
3536

3637
def test_nan_inf(self):
38+
import scipy.special
39+
3740
a = numpy.array([numpy.nan, -numpy.inf, numpy.inf])
3841
ia = dpnp.array(a)
3942

dpnp/tests/test_strides.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy
44
import pytest
5-
import scipy
65
from numpy.testing import assert_array_equal
76

87
import dpnp
@@ -17,8 +16,8 @@
1716
get_integer_dtypes,
1817
get_integer_float_dtypes,
1918
numpy_version,
20-
scipy_version,
2119
)
20+
from .third_party.cupy.testing import installed, with_requires
2221

2322

2423
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
@@ -167,17 +166,20 @@ def test_reduce_hypot(dtype, stride):
167166
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
168167

169168

169+
@with_requires("scipy")
170170
@pytest.mark.parametrize("dtype", get_float_dtypes(no_float16=False))
171171
@pytest.mark.parametrize("stride", [2, -1, -3])
172172
def test_erf(dtype, stride):
173+
import scipy.special
174+
173175
x = generate_random_numpy_array(10, dtype=dtype)
174176
a, ia = x[::stride], dpnp.array(x)[::stride]
175177

176178
result = dpnp.special.erf(ia)
177179
expected = scipy.special.erf(a)
178180

179181
# scipy >= 0.16.0 returns float64, but dpnp returns float32
180-
only_type_kind = scipy_version() >= "0.16.0" and (dtype == dpnp.float16)
182+
only_type_kind = installed("scipy>=0.16.0") and (dtype == dpnp.float16)
181183
assert_dtype_allclose(result, expected, check_only_type_kind=only_type_kind)
182184

183185

0 commit comments

Comments
 (0)