@@ -114,6 +114,64 @@ def test_properties():
114
114
assert isinstance (X .ndim , numbers .Integral )
115
115
116
116
117
+ @pytest .mark .parametrize ("func" , [bool , float , int , complex ])
118
+ @pytest .mark .parametrize ("shape" , [tuple (), (1 ,), (1 , 1 ), (1 , 1 , 1 )])
119
+ @pytest .mark .parametrize ("dtype" , ["|b1" , "|u2" , "|f4" , "|i8" ])
120
+ def test_copy_scalar_with_func (func , shape , dtype ):
121
+ X = dpt .usm_ndarray (shape , dtype = dtype )
122
+ Y = np .arange (1 , X .size + 1 , dtype = dtype ).reshape (shape )
123
+ X .usm_data .copy_from_host (Y .reshape (- 1 ).view ("|u1" ))
124
+ assert func (X ) == func (Y )
125
+
126
+
127
+ @pytest .mark .parametrize (
128
+ "method" , ["__bool__" , "__float__" , "__int__" , "__complex__" ]
129
+ )
130
+ @pytest .mark .parametrize ("shape" , [tuple (), (1 ,), (1 , 1 ), (1 , 1 , 1 )])
131
+ @pytest .mark .parametrize ("dtype" , ["|b1" , "|u2" , "|f4" , "|i8" ])
132
+ def test_copy_scalar_with_method (method , shape , dtype ):
133
+ X = dpt .usm_ndarray (shape , dtype = dtype )
134
+ Y = np .arange (1 , X .size + 1 , dtype = dtype ).reshape (shape )
135
+ X .usm_data .copy_from_host (Y .reshape (- 1 ).view ("|u1" ))
136
+ assert getattr (X , method )() == getattr (Y , method )()
137
+
138
+
139
+ @pytest .mark .parametrize ("func" , [bool , float , int , complex ])
140
+ @pytest .mark .parametrize ("shape" , [(2 ,), (1 , 2 ), (3 , 4 , 5 ), (0 ,)])
141
+ def test_copy_scalar_invalid_shape (func , shape ):
142
+ X = dpt .usm_ndarray (shape )
143
+ with pytest .raises (ValueError ):
144
+ func (X )
145
+
146
+
147
+ @pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
148
+ @pytest .mark .parametrize ("index_dtype" , ["|i8" ])
149
+ def test_usm_ndarray_as_index (shape , index_dtype ):
150
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
151
+ Xnp = np .arange (1 , X .size + 1 , dtype = index_dtype ).reshape (shape )
152
+ X .usm_data .copy_from_host (Xnp .reshape (- 1 ).view ("|u1" ))
153
+ Y = np .arange (X .size + 1 )
154
+ assert Y [X ] == Y [1 ]
155
+
156
+
157
+ @pytest .mark .parametrize ("shape" , [(2 ,), (1 , 2 ), (3 , 4 , 5 ), (0 ,)])
158
+ @pytest .mark .parametrize ("index_dtype" , ["|i8" ])
159
+ def test_usm_ndarray_as_index_invalid_shape (shape , index_dtype ):
160
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
161
+ Y = np .arange (X .size + 1 )
162
+ with pytest .raises (IndexError ):
163
+ Y [X ]
164
+
165
+
166
+ @pytest .mark .parametrize ("shape" , [(1 ,), (1 , 1 ), (1 , 1 , 1 )])
167
+ @pytest .mark .parametrize ("index_dtype" , ["|f8" ])
168
+ def test_usm_ndarray_as_index_invalid_dtype (shape , index_dtype ):
169
+ X = dpt .usm_ndarray (shape , dtype = index_dtype )
170
+ Y = np .arange (X .size + 1 )
171
+ with pytest .raises (IndexError ):
172
+ Y [X ]
173
+
174
+
117
175
@pytest .mark .parametrize (
118
176
"ind" ,
119
177
[
@@ -251,6 +309,14 @@ def test_slicing_basic():
251
309
Xusm [:, - 128 ]
252
310
with pytest .raises (TypeError ):
253
311
Xusm [{1 , 2 , 3 , 4 , 5 , 6 , 7 }]
312
+ X = dpt .usm_ndarray (10 , "u1" )
313
+ X .usm_data .copy_from_host (b"\x00 \x01 \x02 \x03 \x04 \x05 \x06 \x07 \x08 \x09 " )
314
+ int (
315
+ X [X [2 ]]
316
+ ) # check that objects with __index__ method can be used as indices
317
+ Xh = dpm .as_usm_memory (X [X [2 ] : X [5 ]]).copy_to_host ()
318
+ Xnp = np .arange (0 , 10 , dtype = "u1" )
319
+ assert np .array_equal (Xh , Xnp [Xnp [2 ] : Xnp [5 ]])
254
320
255
321
256
322
def test_ctor_invalid_shape ():
@@ -291,3 +357,19 @@ def test_usm_ndarray_props():
291
357
except dpctl .SyclQueueCreationError :
292
358
pytest .skip ("Sycl device CPU was not detected" )
293
359
Xusm .to_device ("cpu" )
360
+
361
+
362
+ def test_datapi_device ():
363
+ X = dpt .usm_ndarray (1 )
364
+ dev_t = type (X .device )
365
+ with pytest .raises (TypeError ):
366
+ dev_t ()
367
+ dev_t .create_device (X .device )
368
+ dev_t .create_device (X .sycl_queue )
369
+ dev_t .create_device (X .sycl_device )
370
+ dev_t .create_device (X .sycl_device .filter_string )
371
+ dev_t .create_device (None )
372
+ X .device .sycl_context
373
+ X .device .sycl_queue
374
+ X .device .sycl_device
375
+ repr (X .device )
0 commit comments