File tree Expand file tree Collapse file tree 3 files changed +20
-5
lines changed Expand file tree Collapse file tree 3 files changed +20
-5
lines changed Original file line number Diff line number Diff line change 3
3
import dpctl
4
4
import numpy
5
5
import pytest
6
+ import scipy
6
7
from numpy .testing import assert_allclose , assert_array_equal
7
8
8
9
import dpnp
@@ -498,3 +499,7 @@ def requires_intel_mkl_version(version):
498
499
499
500
build_deps = numpy .show_config (mode = "dicts" )["Build Dependencies" ]
500
501
return build_deps ["blas" ]["version" ] >= version
502
+
503
+
504
+ def scipy_version ():
505
+ return scipy .__version__
Original file line number Diff line number Diff line change 1
- import math
2
-
3
1
import numpy
4
2
import pytest
5
3
import scipy
6
- from numpy .testing import assert_allclose , assert_almost_equal
4
+ from numpy .testing import assert_allclose
7
5
8
6
import dpnp
9
7
10
8
from .helper import (
9
+ assert_dtype_allclose ,
11
10
generate_random_numpy_array ,
12
11
get_all_dtypes ,
13
12
get_complex_dtypes ,
13
+ scipy_version ,
14
14
)
15
15
16
16
@@ -25,7 +25,13 @@ def test_basic(self, dt):
25
25
26
26
result = dpnp .special .erf (ia )
27
27
expected = scipy .special .erf (a )
28
- assert_almost_equal (result , expected )
28
+
29
+ # scipy >= 0.16.0 returns float64, but dpnp returns float32
30
+ to_float32 = dt in (dpnp .bool , dpnp .float16 )
31
+ only_type_kind = scipy_version () >= "0.16.0" and to_float32
32
+ assert_dtype_allclose (
33
+ result , expected , check_only_type_kind = only_type_kind
34
+ )
29
35
30
36
def test_nan_inf (self ):
31
37
a = numpy .array ([numpy .nan , - numpy .inf , numpy .inf ])
Original file line number Diff line number Diff line change 17
17
get_integer_dtypes ,
18
18
get_integer_float_dtypes ,
19
19
numpy_version ,
20
+ scipy_version ,
20
21
)
21
22
22
23
@@ -174,7 +175,10 @@ def test_erf(dtype, stride):
174
175
175
176
result = dpnp .special .erf (ia )
176
177
expected = scipy .special .erf (a )
177
- assert_dtype_allclose (result , expected )
178
+
179
+ # scipy >= 0.16.0 returns float64, but dpnp returns float32
180
+ only_type_kind = scipy_version () >= "0.16.0" and (dtype == dpnp .float16 )
181
+ assert_dtype_allclose (result , expected , check_only_type_kind = only_type_kind )
178
182
179
183
180
184
@pytest .mark .parametrize ("dtype" , get_float_complex_dtypes ())
You can’t perform that action at this time.
0 commit comments