|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import pytest
|
20 |
| -from helper import get_queue_or_skip |
| 20 | +from helper import get_queue_or_skip, skip_if_dtype_not_supported |
21 | 21 |
|
22 | 22 | # import dpctl
|
23 | 23 | import dpctl.tensor as dpt
|
24 | 24 |
|
25 |
| -# from helper import skip_if_dtype_not_supported |
26 |
| - |
27 | 25 |
|
28 | 26 | def test_basic_slice1():
|
29 | 27 | q = get_queue_or_skip()
|
@@ -435,3 +433,82 @@ def test_integer_strided_indexing():
|
435 | 433 | zc = dpt.copy(z, order="C")
|
436 | 434 | yc = zc[ind0[:2, dpt.newaxis], ind1[dpt.newaxis, -2:]]
|
437 | 435 | assert (dpt.asnumpy(y) == dpt.asnumpy(yc)).all()
|
| 436 | + |
| 437 | + |
| 438 | +@pytest.mark.parametrize( |
| 439 | + "data_dt", |
| 440 | + ["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8", "e", "f", "d", "F", "D"], |
| 441 | +) |
| 442 | +@pytest.mark.parametrize( |
| 443 | + "ind_dt", ["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8"] |
| 444 | +) |
| 445 | +def test_take_basic(data_dt, ind_dt): |
| 446 | + q = get_queue_or_skip() |
| 447 | + skip_if_dtype_not_supported(data_dt, q) |
| 448 | + |
| 449 | + x = dpt.arange(10, dtype=data_dt) |
| 450 | + ind = dpt.arange(2, 5, dtype=ind_dt) |
| 451 | + y = dpt.take(x, ind) |
| 452 | + assert y.dtype == x.dtype |
| 453 | + assert (dpt.asnumpy(y) == np.arange(2, 5, dtype=data_dt)).all() |
| 454 | + |
| 455 | + |
| 456 | +@pytest.mark.parametrize( |
| 457 | + "data_dt", |
| 458 | + ["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8", "e", "f", "d", "F", "D"], |
| 459 | +) |
| 460 | +@pytest.mark.parametrize( |
| 461 | + "ind_dt", ["u1", "i1", "u2", "i2", "u4", "i4", "u8", "i8"] |
| 462 | +) |
| 463 | +def test_put_basic(data_dt, ind_dt): |
| 464 | + q = get_queue_or_skip() |
| 465 | + skip_if_dtype_not_supported(data_dt, q) |
| 466 | + |
| 467 | + x = dpt.arange(10, dtype=data_dt) |
| 468 | + ind = dpt.arange(2, 5, dtype=ind_dt) |
| 469 | + val = dpt.ones(3, dtype=data_dt) |
| 470 | + dpt.put(x, ind, val) |
| 471 | + assert ( |
| 472 | + dpt.asnumpy(x) |
| 473 | + == np.array([0, 1, 1, 1, 1, 5, 6, 7, 8, 9], dtype=data_dt) |
| 474 | + ).all() |
| 475 | + |
| 476 | + |
| 477 | +def test_take_basic_axis(): |
| 478 | + get_queue_or_skip() |
| 479 | + |
| 480 | + n0, n1 = 5, 7 |
| 481 | + x = dpt.reshape( |
| 482 | + dpt.arange(n0 * n1, dtype="i4"), |
| 483 | + ( |
| 484 | + n0, |
| 485 | + n1, |
| 486 | + ), |
| 487 | + ) |
| 488 | + ind = dpt.arange(2, 4) |
| 489 | + y0 = dpt.take(x, ind, axis=0) |
| 490 | + y1 = dpt.take(x, ind, axis=1) |
| 491 | + assert y0.shape == (2, n1) |
| 492 | + assert y1.shape == (n0, 2) |
| 493 | + |
| 494 | + |
| 495 | +def test_put_basic_axis(): |
| 496 | + get_queue_or_skip() |
| 497 | + |
| 498 | + n0, n1 = 5, 7 |
| 499 | + x = dpt.reshape( |
| 500 | + dpt.arange(n0 * n1, dtype="i4"), |
| 501 | + ( |
| 502 | + n0, |
| 503 | + n1, |
| 504 | + ), |
| 505 | + ) |
| 506 | + ind = dpt.arange(2, 4) |
| 507 | + v0 = dpt.zeros((2, n1), dtype=x.dtype) |
| 508 | + v1 = dpt.zeros((n0, 2), dtype=x.dtype) |
| 509 | + dpt.put(x, ind, v0, axis=0) |
| 510 | + dpt.put(x, ind, v1, axis=1) |
| 511 | + expected = np.arange(n0 * n1, dtype="i4").reshape((n0, n1)) |
| 512 | + expected[[2, 3], :] = 0 |
| 513 | + expected[:, [2, 3]] = 0 |
| 514 | + assert (expected == dpt.asnumpy(x)).all() |
0 commit comments