File tree Expand file tree Collapse file tree 3 files changed +20
-5
lines changed Expand file tree Collapse file tree 3 files changed +20
-5
lines changed Original file line number Diff line number Diff line change 33import dpctl
44import numpy
55import pytest
6+ import scipy
67from numpy .testing import assert_allclose , assert_array_equal
78
89import dpnp
@@ -498,3 +499,7 @@ def requires_intel_mkl_version(version):
498499
499500 build_deps = numpy .show_config (mode = "dicts" )["Build Dependencies" ]
500501 return build_deps ["blas" ]["version" ] >= version
502+
503+
504+ def scipy_version ():
505+ return scipy .__version__
Original file line number Diff line number Diff line change 1- import math
2-
31import numpy
42import pytest
53import scipy
6- from numpy .testing import assert_allclose , assert_almost_equal
4+ from numpy .testing import assert_allclose
75
86import dpnp
97
108from .helper import (
9+ assert_dtype_allclose ,
1110 generate_random_numpy_array ,
1211 get_all_dtypes ,
1312 get_complex_dtypes ,
13+ scipy_version ,
1414)
1515
1616
@@ -25,7 +25,13 @@ def test_basic(self, dt):
2525
2626 result = dpnp .special .erf (ia )
2727 expected = scipy .special .erf (a )
28- assert_almost_equal (result , expected )
28+
29+ # scipy >= 0.16.0 returns float64, but dpnp returns float32
30+ to_float32 = dt in (dpnp .bool , dpnp .float16 )
31+ only_type_kind = scipy_version () >= "0.16.0" and to_float32
32+ assert_dtype_allclose (
33+ result , expected , check_only_type_kind = only_type_kind
34+ )
2935
3036 def test_nan_inf (self ):
3137 a = numpy .array ([numpy .nan , - numpy .inf , numpy .inf ])
Original file line number Diff line number Diff line change 1717 get_integer_dtypes ,
1818 get_integer_float_dtypes ,
1919 numpy_version ,
20+ scipy_version ,
2021)
2122
2223
@@ -174,7 +175,10 @@ def test_erf(dtype, stride):
174175
175176 result = dpnp .special .erf (ia )
176177 expected = scipy .special .erf (a )
177- assert_dtype_allclose (result , expected )
178+
179+ # scipy >= 0.16.0 returns float64, but dpnp returns float32
180+ only_type_kind = scipy_version () >= "0.16.0" and (dtype == dpnp .float16 )
181+ assert_dtype_allclose (result , expected , check_only_type_kind = only_type_kind )
178182
179183
180184@pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
You can’t perform that action at this time.
0 commit comments