|
15 | 15 | # limitations under the License.
|
16 | 16 |
|
17 | 17 |
|
18 |
| -# import numpy as np |
| 18 | +import numpy as np |
19 | 19 | import pytest
|
20 | 20 | from helper import get_queue_or_skip
|
21 | 21 |
|
@@ -174,6 +174,22 @@ def test_advanced_slice1():
|
174 | 174 | )
|
175 | 175 |
|
176 | 176 |
|
| 177 | +def test_advanced_slice1_negative_strides(): |
| 178 | + q = get_queue_or_skip() |
| 179 | + ii = dpt.asarray([0, 1], sycl_queue=q) |
| 180 | + x = dpt.flip(dpt.arange(5, dtype="i4", sycl_queue=q)) |
| 181 | + y = x[ii] |
| 182 | + assert isinstance(y, dpt.usm_ndarray) |
| 183 | + assert y.shape == ii.shape |
| 184 | + assert y.strides == (1,) |
| 185 | + # FIXME, once usm_ndarray.__equal__ is implemented, |
| 186 | + # use of asnumpy should be removed |
| 187 | + assert _all_equal( |
| 188 | + (x[ii[k]] for k in range(ii.shape[0])), |
| 189 | + (y[k] for k in range(ii.shape[0])), |
| 190 | + ) |
| 191 | + |
| 192 | + |
177 | 193 | def test_advanced_slice2():
|
178 | 194 | q = get_queue_or_skip()
|
179 | 195 | ii = dpt.asarray([1, 2], sycl_queue=q)
|
@@ -363,3 +379,59 @@ def test_advanced_slice13():
|
363 | 379 | assert isinstance(y, dpt.usm_ndarray)
|
364 | 380 | assert y.shape == expected.shape
|
365 | 381 | assert (dpt.asnumpy(y) == dpt.asnumpy(expected)).all()
|
| 382 | + |
| 383 | + |
| 384 | +def test_integer_indexing_1d(): |
| 385 | + get_queue_or_skip() |
| 386 | + x = dpt.arange(10, dtype="i4") |
| 387 | + ind_1d = dpt.asarray([7, 3, 1], dtype="u2") |
| 388 | + ind_2d = dpt.asarray([[2, 3, 4], [3, 4, 5], [5, 6, 7]], dtype="i4") |
| 389 | + |
| 390 | + y1 = x[ind_1d] |
| 391 | + assert y1.shape == ind_1d.shape |
| 392 | + y2 = x[ind_2d] |
| 393 | + assert y2.shape == ind_2d.shape |
| 394 | + assert (dpt.asnumpy(y1) == np.array([7, 3, 1], dtype="i4")).all() |
| 395 | + assert ( |
| 396 | + dpt.asnumpy(y2) |
| 397 | + == np.array([[2, 3, 4], [3, 4, 5], [5, 6, 7]], dtype="i4") |
| 398 | + ).all() |
| 399 | + |
| 400 | + |
| 401 | +def test_integer_indexing_2d(): |
| 402 | + get_queue_or_skip() |
| 403 | + n0, n1 = 5, 7 |
| 404 | + x = dpt.reshape( |
| 405 | + dpt.arange(n0 * n1, dtype="i4"), |
| 406 | + ( |
| 407 | + n0, |
| 408 | + n1, |
| 409 | + ), |
| 410 | + ) |
| 411 | + ind0 = dpt.arange(n0) |
| 412 | + ind1 = dpt.arange(n1) |
| 413 | + |
| 414 | + y = x[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]] |
| 415 | + assert y.dtype == x.dtype |
| 416 | + assert (dpt.asnumpy(y) == np.array([[5, 6], [12, 13]])).all() |
| 417 | + |
| 418 | + |
| 419 | +def test_integer_strided_indexing(): |
| 420 | + get_queue_or_skip() |
| 421 | + n0, n1 = 5, 7 |
| 422 | + x = dpt.reshape( |
| 423 | + dpt.arange(2 * n0 * n1, dtype="i4"), |
| 424 | + ( |
| 425 | + 2 * n0, |
| 426 | + n1, |
| 427 | + ), |
| 428 | + ) |
| 429 | + ind0 = dpt.arange(n0) |
| 430 | + ind1 = dpt.arange(n1) |
| 431 | + |
| 432 | + z = x[::-2, :] |
| 433 | + y = z[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]] |
| 434 | + assert y.dtype == x.dtype |
| 435 | + zc = dpt.copy(z, order="C") |
| 436 | + yc = zc[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]] |
| 437 | + assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all() |
0 commit comments