|
11 | 11 |
|
12 | 12 | from pytensor.tensor import tensor
|
13 | 13 | from pytensor.xtensor import xtensor
|
14 |
| -from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function |
| 14 | +from tests.xtensor.util import ( |
| 15 | + xr_arange_like, |
| 16 | + xr_assert_allclose, |
| 17 | + xr_function, |
| 18 | + xr_random_like, |
| 19 | +) |
15 | 20 |
|
16 | 21 |
|
17 | 22 | @pytest.mark.parametrize(
|
@@ -351,3 +356,138 @@ def test_boolean_indexing():
|
351 | 356 | expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")]
|
352 | 357 | xr_assert_allclose(res1, expected_res1)
|
353 | 358 | xr_assert_allclose(res2, expected_res2)
|
| 359 | + |
| 360 | + |
| 361 | +@pytest.mark.parametrize("mode", ("set", "inc")) |
| 362 | +def test_basic_index_update(mode): |
| 363 | + x = xtensor("x", shape=(11, 7), dims=("a", "b")) |
| 364 | + y = xtensor("y", shape=(7, 5), dims=("a", "b")) |
| 365 | + x_indexed = x[2:-2, 2:] |
| 366 | + update_method = getattr(x_indexed, mode) |
| 367 | + |
| 368 | + x_updated = [ |
| 369 | + update_method(y), |
| 370 | + update_method(y.T), |
| 371 | + update_method(y.isel(a=-1)), |
| 372 | + update_method(y.isel(b=-1)), |
| 373 | + update_method(y.isel(a=-2, b=-2)), |
| 374 | + ] |
| 375 | + |
| 376 | + fn = xr_function([x, y], x_updated) |
| 377 | + x_test = xr_random_like(x) |
| 378 | + y_test = xr_random_like(y) |
| 379 | + results = fn(x_test, y_test) |
| 380 | + |
| 381 | + def update_fn(y): |
| 382 | + x = x_test.copy() |
| 383 | + if mode == "set": |
| 384 | + x[2:-2, 2:] = y |
| 385 | + elif mode == "inc": |
| 386 | + x[2:-2, 2:] += y |
| 387 | + return x |
| 388 | + |
| 389 | + expected_results = [ |
| 390 | + update_fn(y_test), |
| 391 | + update_fn(y_test.T), |
| 392 | + update_fn(y_test.isel(a=-1)), |
| 393 | + update_fn(y_test.isel(b=-1)), |
| 394 | + update_fn(y_test.isel(a=-2, b=-2)), |
| 395 | + ] |
| 396 | + for result, expected_result in zip(results, expected_results): |
| 397 | + xr_assert_allclose(result, expected_result) |
| 398 | + |
| 399 | + |
| 400 | +@pytest.mark.parametrize("mode", ("set", "inc")) |
| 401 | +@pytest.mark.parametrize("idx_dtype", (int, bool)) |
| 402 | +def test_adv_index_update(mode, idx_dtype): |
| 403 | + x = xtensor("x", shape=(5, 5), dims=("a", "b")) |
| 404 | + y = xtensor("y", shape=(3,), dims=("b",)) |
| 405 | + idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",)) |
| 406 | + |
| 407 | + orthogonal_update1 = getattr(x[idx, -3:], mode)(y) |
| 408 | + orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a")) |
| 409 | + if idx_dtype is not bool: |
| 410 | + # Vectorized booling indexing/update is not allowed |
| 411 | + vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y) |
| 412 | + else: |
| 413 | + with pytest.raises( |
| 414 | + IndexError, |
| 415 | + match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.", |
| 416 | + ): |
| 417 | + getattr(x[idx.rename(a="b"), :3], mode)(y) |
| 418 | + vectorized_update = x |
| 419 | + |
| 420 | + outs = [orthogonal_update1, orthogonal_update2, vectorized_update] |
| 421 | + |
| 422 | + fn = xr_function([x, idx, y], outs) |
| 423 | + x_test = xr_random_like(x) |
| 424 | + y_test = xr_random_like(y) |
| 425 | + if idx_dtype is int: |
| 426 | + idx_test = DataArray([0, 1, 2], dims=("a",)) |
| 427 | + else: |
| 428 | + idx_test = DataArray([True, False, True, True, False], dims=("a",)) |
| 429 | + results = fn(x_test, idx_test, y_test) |
| 430 | + |
| 431 | + def update_fn(x, idx, y): |
| 432 | + x = x.copy() |
| 433 | + if mode == "set": |
| 434 | + x[idx] = y |
| 435 | + else: |
| 436 | + x[idx] += y |
| 437 | + return x |
| 438 | + |
| 439 | + expected_results = [ |
| 440 | + update_fn(x_test, (idx_test, slice(-3, None)), y_test), |
| 441 | + update_fn( |
| 442 | + x_test, |
| 443 | + (idx_test, slice(-3, None)), |
| 444 | + y_test.rename(b="a"), |
| 445 | + ), |
| 446 | + update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test) |
| 447 | + if idx_dtype is not bool |
| 448 | + else x_test, |
| 449 | + ] |
| 450 | + for result, expected_result in zip(results, expected_results): |
| 451 | + xr_assert_allclose(result, expected_result) |
| 452 | + |
| 453 | + |
| 454 | +@pytest.mark.parametrize("mode", ("set", "inc")) |
| 455 | +def test_non_consecutive_idx_update(mode): |
| 456 | + x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d")) |
| 457 | + y = xtensor("y", shape=(5, 4), dims=("c", "b")) |
| 458 | + x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])] |
| 459 | + out = getattr(x_indexed, mode)(y) |
| 460 | + |
| 461 | + fn = xr_function([x, y], out) |
| 462 | + x_test = xr_random_like(x) |
| 463 | + y_test = xr_random_like(y) |
| 464 | + |
| 465 | + result = fn(x_test, y_test) |
| 466 | + expected_result = x_test.copy() |
| 467 | + # xarray fails inplace operation with the "tuple trick" |
| 468 | + # https://github.com/pydata/xarray/issues/10387 |
| 469 | + d_indexer = DataArray([0, 1, 1, 2], dims=("b",)) |
| 470 | + if mode == "set": |
| 471 | + expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test |
| 472 | + else: |
| 473 | + expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test |
| 474 | + xr_assert_allclose(result, expected_result) |
| 475 | + |
| 476 | + |
| 477 | +def test_indexing_renames_into_update_variable(): |
| 478 | + x = xtensor("x", shape=(5, 5), dims=("a", "b")) |
| 479 | + y = xtensor("y", shape=(3,), dims=("d",)) |
| 480 | + idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",)) |
| 481 | + |
| 482 | + # define "d" dimension by slicing the "a" dimension so we can set y into x |
| 483 | + orthogonal_update1 = x[idx].set(y) |
| 484 | + fn = xr_function([x, idx, y], orthogonal_update1) |
| 485 | + |
| 486 | + x_test = np.abs(xr_random_like(x)) |
| 487 | + y_test = -np.abs(xr_random_like(y)) |
| 488 | + idx_test = DataArray([0, 2, 3], dims=("d",)) |
| 489 | + |
| 490 | + result = fn(x_test, idx_test, y_test) |
| 491 | + expected_result = x_test.copy() |
| 492 | + expected_result[idx_test] = y_test |
| 493 | + xr_assert_allclose(result, expected_result) |
0 commit comments