Skip to content

Commit 8aa0176

Browse files
committed
Add special dtype check in tests for scipy>=0.16
1 parent 6ef574a commit 8aa0176

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

dpnp/tests/helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dpctl
44
import numpy
55
import pytest
6+
import scipy
67
from numpy.testing import assert_allclose, assert_array_equal
78

89
import dpnp
@@ -498,3 +499,7 @@ def requires_intel_mkl_version(version):
498499

499500
build_deps = numpy.show_config(mode="dicts")["Build Dependencies"]
500501
return build_deps["blas"]["version"] >= version
502+
503+
504+
def scipy_version():
505+
return scipy.__version__

dpnp/tests/test_special.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import math
2-
31
import numpy
42
import pytest
53
import scipy
6-
from numpy.testing import assert_allclose, assert_almost_equal
4+
from numpy.testing import assert_allclose
75

86
import dpnp
97

108
from .helper import (
9+
assert_dtype_allclose,
1110
generate_random_numpy_array,
1211
get_all_dtypes,
1312
get_complex_dtypes,
13+
scipy_version,
1414
)
1515

1616

@@ -25,7 +25,13 @@ def test_basic(self, dt):
2525

2626
result = dpnp.special.erf(ia)
2727
expected = scipy.special.erf(a)
28-
assert_almost_equal(result, expected)
28+
29+
# scipy >= 0.16.0 returns float64, but dpnp returns float32
30+
to_float32 = dt in (dpnp.bool, dpnp.float16)
31+
only_type_kind = scipy_version() >= "0.16.0" and to_float32
32+
assert_dtype_allclose(
33+
result, expected, check_only_type_kind=only_type_kind
34+
)
2935

3036
def test_nan_inf(self):
3137
a = numpy.array([numpy.nan, -numpy.inf, numpy.inf])

dpnp/tests/test_strides.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_integer_dtypes,
1818
get_integer_float_dtypes,
1919
numpy_version,
20+
scipy_version,
2021
)
2122

2223

@@ -174,7 +175,10 @@ def test_erf(dtype, stride):
174175

175176
result = dpnp.special.erf(ia)
176177
expected = scipy.special.erf(a)
177-
assert_dtype_allclose(result, expected)
178+
179+
# scipy >= 0.16.0 returns float64, but dpnp returns float32
180+
only_type_kind = scipy_version() >= "0.16.0" and (dtype == dpnp.float16)
181+
assert_dtype_allclose(result, expected, check_only_type_kind=only_type_kind)
178182

179183

180184
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())

0 commit comments

Comments
 (0)