16
16
17
17
18
18
# import numpy as np
19
- # import pytest
19
+ import pytest
20
20
from helper import get_queue_or_skip
21
21
22
22
# import dpctl
@@ -144,6 +144,10 @@ def test_basic_slice10():
144
144
assert y .strides == (0 , n1 * n2 , n2 , 1 )
145
145
146
146
147
+ def _all_equal (it1 , it2 ):
148
+ return all (dpt .asnumpy (x ) == dpt .asnumpy (y ) for x , y in zip (it1 , it2 ))
149
+
150
+
147
151
def test_advanced_slice1 ():
148
152
q = get_queue_or_skip ()
149
153
ii = dpt .asarray ([1 , 2 ], sycl_queue = q )
@@ -154,6 +158,208 @@ def test_advanced_slice1():
154
158
assert y .strides == (1 ,)
155
159
# FIXME, once usm_ndarray.__equal__ is implemented,
156
160
# use of asnumpy should be removed
157
- assert all (
158
- dpt .asnumpy (x [ii [k ]]) == dpt .asnumpy (y [k ]) for k in range (ii .shape [0 ])
161
+ assert _all_equal (
162
+ (x [ii [k ]] for k in range (ii .shape [0 ])),
163
+ (y [k ] for k in range (ii .shape [0 ])),
164
+ )
165
+ y = x [(ii ,)]
166
+ assert isinstance (y , dpt .usm_ndarray )
167
+ assert y .shape == ii .shape
168
+ assert y .strides == (1 ,)
169
+ # FIXME, once usm_ndarray.__equal__ is implemented,
170
+ # use of asnumpy should be removed
171
+ assert _all_equal (
172
+ (x [ii [k ]] for k in range (ii .shape [0 ])),
173
+ (y [k ] for k in range (ii .shape [0 ])),
159
174
)
175
+
176
+
177
+ def test_advanced_slice2 ():
178
+ q = get_queue_or_skip ()
179
+ ii = dpt .asarray ([1 , 2 ], sycl_queue = q )
180
+ x = dpt .arange (10 , dtype = "i4" , sycl_queue = q )
181
+ y = x [ii , dpt .newaxis ]
182
+ assert isinstance (y , dpt .usm_ndarray )
183
+ assert y .shape == ii .shape + (1 ,)
184
+ assert y .flags ["C" ]
185
+
186
+
187
+ def test_advanced_slice3 ():
188
+ q = get_queue_or_skip ()
189
+ ii = dpt .asarray ([1 , 2 ], sycl_queue = q )
190
+ x = dpt .arange (10 , dtype = "i4" , sycl_queue = q )
191
+ y = x [dpt .newaxis , ii ]
192
+ assert isinstance (y , dpt .usm_ndarray )
193
+ assert y .shape == (1 ,) + ii .shape
194
+ assert y .flags ["C" ]
195
+
196
+
197
+ def _make_3d (dt , q ):
198
+ return dpt .reshape (
199
+ dpt .arange (3 * 3 * 3 , dtype = dt , sycl_queue = q ),
200
+ (
201
+ 3 ,
202
+ 3 ,
203
+ 3 ,
204
+ ),
205
+ )
206
+
207
+
208
+ def test_advanced_slice4 ():
209
+ q = get_queue_or_skip ()
210
+ ii = dpt .asarray ([1 , 2 ], sycl_queue = q )
211
+ x = _make_3d ("i4" , q )
212
+ y = x [ii , ii , ii ]
213
+ assert isinstance (y , dpt .usm_ndarray )
214
+ assert y .shape == ii .shape
215
+ assert _all_equal (
216
+ (x [ii [k ], ii [k ], ii [k ]] for k in range (ii .shape [0 ])),
217
+ (y [k ] for k in range (ii .shape [0 ])),
218
+ )
219
+
220
+
221
+ def test_advanced_slice5 ():
222
+ q = get_queue_or_skip ()
223
+ ii = dpt .asarray ([1 , 2 ], sycl_queue = q )
224
+ x = _make_3d ("i4" , q )
225
+ with pytest .raises (IndexError ):
226
+ x [ii , 0 , ii ]
227
+
228
+
229
+ def test_advanced_slice6 ():
230
+ q = get_queue_or_skip ()
231
+ ii = dpt .asarray ([1 , 2 ], sycl_queue = q )
232
+ x = _make_3d ("i4" , q )
233
+ y = x [:, ii , ii ]
234
+ assert isinstance (y , dpt .usm_ndarray )
235
+ assert y .shape == (
236
+ x .shape [0 ],
237
+ ii .shape [0 ],
238
+ )
239
+ assert _all_equal (
240
+ (
241
+ x [i , ii [k ], ii [k ]]
242
+ for i in range (x .shape [0 ])
243
+ for k in range (ii .shape [0 ])
244
+ ),
245
+ (y [i , k ] for i in range (x .shape [0 ]) for k in range (ii .shape [0 ])),
246
+ )
247
+
248
+
249
+ def test_advanced_slice7 ():
250
+ q = get_queue_or_skip ()
251
+ mask = dpt .asarray (
252
+ [
253
+ [[True , True , False ], [False , True , True ], [True , False , True ]],
254
+ [[True , False , False ], [False , False , True ], [False , True , False ]],
255
+ [[True , True , True ], [False , False , False ], [False , False , True ]],
256
+ ],
257
+ sycl_queue = q ,
258
+ )
259
+ x = _make_3d ("i2" , q )
260
+ y = x [mask ]
261
+ expected = [0 , 1 , 4 , 5 , 6 , 8 , 9 , 14 , 16 , 18 , 19 , 20 , 26 ]
262
+ assert isinstance (y , dpt .usm_ndarray )
263
+ assert y .shape == (len (expected ),)
264
+ assert all (dpt .asnumpy (y [k ]) == expected [k ] for k in range (len (expected )))
265
+
266
+
267
+ def test_advanced_slice8 ():
268
+ q = get_queue_or_skip ()
269
+ mask = dpt .asarray (
270
+ [[True , False , False ], [False , True , False ], [False , True , False ]],
271
+ sycl_queue = q ,
272
+ )
273
+ x = _make_3d ("u2" , q )
274
+ y = x [mask ]
275
+ expected = dpt .asarray (
276
+ [[0 , 1 , 2 ], [12 , 13 , 14 ], [21 , 22 , 23 ]], sycl_queue = q
277
+ )
278
+ assert isinstance (y , dpt .usm_ndarray )
279
+ assert y .shape == expected .shape
280
+ assert (dpt .asnumpy (y ) == dpt .asnumpy (expected )).all ()
281
+
282
+
283
+ def test_advanced_slice9 ():
284
+ q = get_queue_or_skip ()
285
+ mask = dpt .asarray (
286
+ [[True , False , False ], [False , True , False ], [False , True , False ]],
287
+ sycl_queue = q ,
288
+ )
289
+ x = _make_3d ("u4" , q )
290
+ y = x [:, mask ]
291
+ expected = dpt .asarray ([[0 , 4 , 7 ], [9 , 13 , 16 ], [18 , 22 , 25 ]], sycl_queue = q )
292
+ assert isinstance (y , dpt .usm_ndarray )
293
+ assert y .shape == expected .shape
294
+ assert (dpt .asnumpy (y ) == dpt .asnumpy (expected )).all ()
295
+
296
+
297
+ def lin_id (i , j , k ):
298
+ """global_linear_id for (3,3,3) range traversed in C-contiguous order"""
299
+ return 9 * i + 3 * j + k
300
+
301
+
302
+ def test_advanced_slice10 ():
303
+ q = get_queue_or_skip ()
304
+ x = _make_3d ("u8" , q )
305
+ i0 = dpt .asarray ([0 , 1 , 1 ], device = x .device )
306
+ i1 = dpt .asarray ([1 , 1 , 2 ], device = x .device )
307
+ i2 = dpt .asarray ([2 , 0 , 1 ], device = x .device )
308
+ y = x [i0 , i1 , i2 ]
309
+ res_expected = dpt .asarray (
310
+ [
311
+ lin_id (0 , 1 , 2 ),
312
+ lin_id (1 , 1 , 0 ),
313
+ lin_id (1 , 2 , 1 ),
314
+ ],
315
+ sycl_queue = q ,
316
+ )
317
+ assert isinstance (y , dpt .usm_ndarray )
318
+ assert y .shape == res_expected .shape
319
+ assert (dpt .asnumpy (y ) == dpt .asnumpy (res_expected )).all ()
320
+
321
+
322
+ def test_advanced_slice11 ():
323
+ q = get_queue_or_skip ()
324
+ x = _make_3d ("u8" , q )
325
+ i0 = dpt .asarray ([0 , 1 , 1 ], device = x .device )
326
+ i2 = dpt .asarray ([2 , 0 , 1 ], device = x .device )
327
+ with pytest .raises (IndexError ):
328
+ x [i0 , :, i2 ]
329
+
330
+
331
+ def test_advanced_slice12 ():
332
+ q = get_queue_or_skip ()
333
+ x = _make_3d ("u8" , q )
334
+ i1 = dpt .asarray ([1 , 1 , 2 ], device = x .device )
335
+ i2 = dpt .asarray ([2 , 0 , 1 ], device = x .device )
336
+ y = x [:, dpt .newaxis , i1 , i2 , dpt .newaxis ]
337
+ res_expected = dpt .asarray (
338
+ [
339
+ [[[lin_id (0 , 1 , 2 )], [lin_id (0 , 1 , 0 )], [lin_id (0 , 2 , 1 )]]],
340
+ [[[lin_id (1 , 1 , 2 )], [lin_id (1 , 1 , 0 )], [lin_id (1 , 2 , 1 )]]],
341
+ [[[lin_id (2 , 1 , 2 )], [lin_id (2 , 1 , 0 )], [lin_id (2 , 2 , 1 )]]],
342
+ ],
343
+ sycl_queue = q ,
344
+ )
345
+ assert isinstance (y , dpt .usm_ndarray )
346
+ assert y .shape == res_expected .shape
347
+ assert (dpt .asnumpy (y ) == dpt .asnumpy (res_expected )).all ()
348
+
349
+
350
+ def test_advanced_slice13 ():
351
+ q = get_queue_or_skip ()
352
+ x = _make_3d ("u8" , q )
353
+ i1 = dpt .asarray ([[1 ], [2 ]], device = x .device )
354
+ i2 = dpt .asarray ([[0 , 1 ]], device = x .device )
355
+ y = x [i1 , i2 , 0 ]
356
+ expected = dpt .asarray (
357
+ [
358
+ [lin_id (1 , 0 , 0 ), lin_id (1 , 1 , 0 )],
359
+ [lin_id (2 , 0 , 0 ), lin_id (2 , 1 , 0 )],
360
+ ],
361
+ device = x .device ,
362
+ )
363
+ assert isinstance (y , dpt .usm_ndarray )
364
+ assert y .shape == expected .shape
365
+ assert (dpt .asnumpy (y ) == dpt .asnumpy (expected )).all ()
0 commit comments