18
18
19
19
import numpy as np
20
20
import pytest
21
- from numpy .testing import assert_equal
21
+ from numpy .testing import assert_allclose , assert_equal
22
22
23
23
import dpctl .tensor as dpt
24
24
from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
@@ -50,7 +50,7 @@ def test_sqrt_output_contig(dtype):
50
50
Y = dpt .sqrt (X )
51
51
tol = 8 * dpt .finfo (Y .dtype ).resolution
52
52
53
- np . testing . assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
53
+ assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
54
54
55
55
56
56
@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
@@ -66,7 +66,7 @@ def test_sqrt_output_strided(dtype):
66
66
Y = dpt .sqrt (X )
67
67
tol = 8 * dpt .finfo (Y .dtype ).resolution
68
68
69
- np . testing . assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
69
+ assert_allclose (dpt .asnumpy (Y ), np .sqrt (Xnp ), atol = tol , rtol = tol )
70
70
71
71
72
72
@pytest .mark .parametrize ("usm_type" , _usm_types )
@@ -89,7 +89,7 @@ def test_sqrt_usm_type(usm_type):
89
89
expected_Y [..., 1 ::2 ] = np .sqrt (np .float32 (23.0 ))
90
90
tol = 8 * dpt .finfo (Y .dtype ).resolution
91
91
92
- np . testing . assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
92
+ assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
93
93
94
94
95
95
@pytest .mark .parametrize ("dtype" , _all_dtypes )
@@ -112,11 +112,10 @@ def test_sqrt_order(dtype):
112
112
dpt .finfo (Y .dtype ).resolution ,
113
113
np .finfo (expected_Y .dtype ).resolution ,
114
114
)
115
- np .testing .assert_allclose (
116
- dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol
117
- )
115
+ assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
118
116
119
117
118
+ @pytest .mark .usefixtures ("suppress_invalid_numpy_warnings" )
120
119
def test_sqrt_special_cases ():
121
120
q = get_queue_or_skip ()
122
121
@@ -126,3 +125,27 @@ def test_sqrt_special_cases():
126
125
Xnp = dpt .asnumpy (X )
127
126
128
127
assert_equal (dpt .asnumpy (dpt .sqrt (X )), np .sqrt (Xnp ))
128
+
129
+
130
+ @pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" , "c8" , "c16" ])
131
+ def test_sqrt_out_overlap (dtype ):
132
+ q = get_queue_or_skip ()
133
+ skip_if_dtype_not_supported (dtype , q )
134
+
135
+ X = dpt .linspace (0 , 35 , 60 , dtype = dtype , sycl_queue = q )
136
+ X = dpt .reshape (X , (3 , 5 , 4 ))
137
+
138
+ Xnp = dpt .asnumpy (X )
139
+ Ynp = np .sqrt (Xnp , out = Xnp )
140
+
141
+ Y = dpt .sqrt (X , out = X )
142
+ assert Y is X
143
+
144
+ tol = 8 * dpt .finfo (Y .dtype ).resolution
145
+ assert_allclose (dpt .asnumpy (X ), Xnp , atol = tol , rtol = tol )
146
+
147
+ Ynp = np .sqrt (Xnp , out = Xnp [::- 1 ])
148
+ Y = dpt .sqrt (X , out = X [::- 1 ])
149
+ assert Y is not X
150
+ assert_allclose (dpt .asnumpy (X ), Xnp , atol = tol , rtol = tol )
151
+ assert_allclose (dpt .asnumpy (Y ), Ynp , atol = tol , rtol = tol )
0 commit comments