|
8 | 8 | from itertools import chain, combinations |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | | -import pytest |
| 11 | +import xarray as xr |
12 | 12 | from xarray import DataArray |
13 | 13 | from xarray import concat as xr_concat |
14 | 14 |
|
| 15 | +from pytensor.tensor import scalar |
15 | 16 | from pytensor.xtensor.shape import ( |
16 | 17 | concat, |
| 18 | + expand_dims, |
17 | 19 | squeeze, |
18 | 20 | stack, |
19 | 21 | transpose, |
@@ -301,6 +303,15 @@ def test_squeeze_explicit_dims(): |
301 | 303 | fn3d = xr_function([x3], y3d) |
302 | 304 | xr_assert_allclose(fn3d(x3_test), x3_test) |
303 | 305 |
|
| 306 | + # Reversibility with expand_dims |
| 307 | + x6 = xtensor("x6", dims=("a", "b", "c"), shape=(2, 1, 3)) |
| 308 | + y6 = squeeze(x6, "b") |
| 309 | + # First expand_dims adds at front, then transpose puts it in the right place |
| 310 | + z6 = transpose(expand_dims(y6, "b"), "a", "b", "c") |
| 311 | + fn6 = xr_function([x6], z6) |
| 312 | + x6_test = xr_arange_like(x6) |
| 313 | + xr_assert_allclose(fn6(x6_test), x6_test) |
| 314 | + |
304 | 315 |
|
305 | 316 | def test_squeeze_implicit_dims(): |
306 | 317 | """Test squeeze with implicit dim=None (all size-1 dimensions).""" |
@@ -369,3 +380,199 @@ def test_squeeze_errors(): |
369 | 380 | fn2 = xr_function([x2], y2) |
370 | 381 | with pytest.raises(Exception): |
371 | 382 | fn2(x2_test) |
| 383 | + |
| 384 | + |
| 385 | +def test_expand_dims_explicit(): |
| 386 | + """Test expand_dims with explicitly named dimensions and sizes.""" |
| 387 | + |
| 388 | + # 1D case |
| 389 | + x = xtensor("x", dims=("city",), shape=(3,)) |
| 390 | + y = expand_dims(x, "country") |
| 391 | + fn = xr_function([x], y) |
| 392 | + x_xr = xr_arange_like(x) |
| 393 | + xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) |
| 394 | + |
| 395 | + # 2D case |
| 396 | + x = xtensor("x", dims=("city", "year"), shape=(2, 2)) |
| 397 | + y = expand_dims(x, "country") |
| 398 | + fn = xr_function([x], y) |
| 399 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) |
| 400 | + |
| 401 | + # 3D case |
| 402 | + x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) |
| 403 | + y = expand_dims(x, "country") |
| 404 | + fn = xr_function([x], y) |
| 405 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) |
| 406 | + |
| 407 | + # Prepending various dims |
| 408 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 409 | + for new_dim in ("x", "y", "z"): |
| 410 | + y = expand_dims(x, new_dim) |
| 411 | + assert y.type.dims == (new_dim, "a", "b") |
| 412 | + assert y.type.shape == (1, 2, 3) |
| 413 | + |
| 414 | + # Explicit size=1 behaves like default |
| 415 | + y1 = expand_dims(x, "batch", size=1) |
| 416 | + y2 = expand_dims(x, "batch") |
| 417 | + fn1 = xr_function([x], y1) |
| 418 | + fn2 = xr_function([x], y2) |
| 419 | + x_test = xr_arange_like(x) |
| 420 | + xr_assert_allclose(fn1(x_test), fn2(x_test)) |
| 421 | + |
| 422 | + # Scalar expansion |
| 423 | + x = xtensor("x", dims=(), shape=()) |
| 424 | + y = expand_dims(x, "batch") |
| 425 | + assert y.type.dims == ("batch",) |
| 426 | + assert y.type.shape == (1,) |
| 427 | + fn = xr_function([x], y) |
| 428 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) |
| 429 | + |
| 430 | + # Static size > 1: broadcast |
| 431 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 432 | + y = expand_dims(x, "batch", size=4) |
| 433 | + fn = xr_function([x], y) |
| 434 | + expected = xr.DataArray( |
| 435 | + np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), |
| 436 | + dims=("batch", "a", "b"), |
| 437 | + coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, |
| 438 | + ) |
| 439 | + xr_assert_allclose(fn(xr_arange_like(x)), expected) |
| 440 | + |
| 441 | + # Insert new dim between existing dims |
| 442 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 443 | + y = expand_dims(x, "new") |
| 444 | + # Insert new dim between a and b: ("a", "new", "b") |
| 445 | + y = transpose(y, "a", "new", "b") |
| 446 | + fn = xr_function([x], y) |
| 447 | + x_test = xr_arange_like(x) |
| 448 | + expected = x_test.expand_dims("new").transpose("a", "new", "b") |
| 449 | + xr_assert_allclose(fn(x_test), expected) |
| 450 | + |
| 451 | + # Expand with multiple dims |
| 452 | + x = xtensor("x", dims=(), shape=()) |
| 453 | + y = expand_dims(expand_dims(x, "a"), "b") |
| 454 | + fn = xr_function([x], y) |
| 455 | + expected = xr_arange_like(x).expand_dims("a").expand_dims("b") |
| 456 | + xr_assert_allclose(fn(xr_arange_like(x)), expected) |
| 457 | + |
| 458 | + |
| 459 | +def test_expand_dims_implicit(): |
| 460 | + """Test expand_dims with default or symbolic sizes and dim=None.""" |
| 461 | + |
| 462 | + # Symbolic size=1: same as default |
| 463 | + size_sym_1 = scalar("size_sym_1", dtype="int64") |
| 464 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 465 | + y = expand_dims(x, "batch", size=size_sym_1) |
| 466 | + fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") |
| 467 | + x_test = xr_arange_like(x) |
| 468 | + xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) |
| 469 | + |
| 470 | + # Symbolic size > 1 (but expand only adds dim=1) |
| 471 | + size_sym_4 = scalar("size_sym_4", dtype="int64") |
| 472 | + y = expand_dims(x, "batch", size=size_sym_4) |
| 473 | + fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") |
| 474 | + xr_assert_allclose(fn(x_test, 4), x_test.expand_dims("batch")) |
| 475 | + |
| 476 | + # Symbolic size > 1 with broadcasting |
| 477 | + size_sym_4 = scalar("size_sym_4", dtype="int64") |
| 478 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 479 | + y = expand_dims(x, "batch", size=size_sym_4) |
| 480 | + z = y + y # This should broadcast along the batch dimension |
| 481 | + fn = xr_function([x, size_sym_4], z, on_unused_input="ignore") |
| 482 | + x_test = xr_arange_like(x) |
| 483 | + out = fn(x_test, 4) |
| 484 | + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") |
| 485 | + xr_assert_allclose(out, expected) |
| 486 | + |
| 487 | + # Symbolic size with shape validation |
| 488 | + size_sym = scalar("size_sym", dtype="int64") |
| 489 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 490 | + y = expand_dims(x, "batch", size=size_sym) |
| 491 | + z = y + y # This should validate the shape |
| 492 | + fn = xr_function([x, size_sym], z, on_unused_input="ignore") |
| 493 | + x_test = xr_arange_like(x) |
| 494 | + out = fn(x_test, 4) |
| 495 | + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") |
| 496 | + xr_assert_allclose(out, expected) |
| 497 | + |
| 498 | + # Symbolic size with subsequent operations |
| 499 | + size_sym = scalar("size_sym", dtype="int64") |
| 500 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 501 | + y = expand_dims(x, "batch", size=size_sym) |
| 502 | + z = y.sum("batch") # This should work with symbolic size |
| 503 | + fn = xr_function([x, size_sym], z, on_unused_input="ignore") |
| 504 | + x_test = xr_arange_like(x) |
| 505 | + out = fn(x_test, 4) |
| 506 | + expected = x_test.expand_dims("batch").sum("batch") |
| 507 | + xr_assert_allclose(out, expected) |
| 508 | + |
| 509 | + # Symbolic size with transpose and broadcasting |
| 510 | + size_sym = scalar("size_sym", dtype="int64") |
| 511 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 512 | + y = expand_dims(x, "batch", size=size_sym) |
| 513 | + z = transpose(y, "batch", "a", "b") # This should work with symbolic size |
| 514 | + fn = xr_function([x, size_sym], z, on_unused_input="ignore") |
| 515 | + x_test = xr_arange_like(x) |
| 516 | + out = fn(x_test, 4) |
| 517 | + expected = x_test.expand_dims("batch").transpose("batch", "a", "b") |
| 518 | + xr_assert_allclose(out, expected) |
| 519 | + |
| 520 | + # Reversibility: expand then squeeze |
| 521 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 522 | + y = expand_dims(x, "batch") |
| 523 | + z = squeeze(y, "batch") |
| 524 | + fn = xr_function([x], z) |
| 525 | + x_test = xr_arange_like(x) |
| 526 | + xr_assert_allclose(fn(x_test), x_test) |
| 527 | + |
| 528 | + # expand_dims with dim=None = no-op |
| 529 | + x = xtensor("x", dims=("a",), shape=(3,)) |
| 530 | + y = expand_dims(x, None) |
| 531 | + fn = xr_function([x], y) |
| 532 | + x_test = xr_arange_like(x) |
| 533 | + xr_assert_allclose(fn(x_test), x_test) |
| 534 | + |
| 535 | + # broadcast after symbolic size |
| 536 | + size_sym = scalar("size_sym", dtype="int64") |
| 537 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 538 | + y = expand_dims(x, "batch", size=size_sym) |
| 539 | + z = y + y # triggers shape alignment |
| 540 | + fn = xr_function([x, size_sym], z, on_unused_input="ignore") |
| 541 | + x_test = xr_arange_like(x) |
| 542 | + out = fn(x_test, 1) |
| 543 | + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") |
| 544 | + xr_assert_allclose(out, expected) |
| 545 | + |
| 546 | + |
| 547 | +def test_expand_dims_errors(): |
| 548 | + """Test error handling in expand_dims.""" |
| 549 | + |
| 550 | + # Expanding existing dim |
| 551 | + x = xtensor("x", dims=("city",), shape=(3,)) |
| 552 | + y = expand_dims(x, "country") |
| 553 | + with pytest.raises(ValueError, match="already exists"): |
| 554 | + expand_dims(y, "city") |
| 555 | + |
| 556 | + # Size = 0 is invalid |
| 557 | + with pytest.raises(ValueError, match="size must be.*positive"): |
| 558 | + expand_dims(x, "batch", size=0) |
| 559 | + |
| 560 | + # Invalid dim type |
| 561 | + with pytest.raises(TypeError): |
| 562 | + expand_dims(x, 123) |
| 563 | + |
| 564 | + # Invalid size type |
| 565 | + with pytest.raises(TypeError): |
| 566 | + expand_dims(x, "new", size=[1]) |
| 567 | + |
| 568 | + # Duplicate dimension creation |
| 569 | + y = expand_dims(x, "new") |
| 570 | + with pytest.raises(ValueError): |
| 571 | + expand_dims(y, "new") |
| 572 | + |
| 573 | + # Symbolic size with invalid runtime value |
| 574 | + size_sym = scalar("size_sym", dtype="int64") |
| 575 | + y = expand_dims(x, "batch", size=size_sym) |
| 576 | + fn = xr_function([x, size_sym], y, on_unused_input="ignore") |
| 577 | + with pytest.raises(Exception): |
| 578 | + fn(xr_arange_like(x), 0) |
0 commit comments