|
8 | 8 | from itertools import chain, combinations |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | | -import pytest |
| 11 | +import xarray as xr |
12 | 12 | from xarray import DataArray |
13 | 13 | from xarray import concat as xr_concat |
14 | 14 |
|
| 15 | +from pytensor.tensor import scalar |
15 | 16 | from pytensor.xtensor.shape import ( |
16 | 17 | concat, |
| 18 | + expand_dims, |
17 | 19 | squeeze, |
18 | 20 | stack, |
19 | 21 | transpose, |
@@ -369,3 +371,153 @@ def test_squeeze_errors(): |
369 | 371 | fn2 = xr_function([x2], y2) |
370 | 372 | with pytest.raises(Exception): |
371 | 373 | fn2(x2_test) |
| 374 | + |
| 375 | + |
| 376 | +def test_expand_dims_explicit(): |
| 377 | + """Test expand_dims with explicitly named dimensions and sizes.""" |
| 378 | + |
| 379 | + # 1D case |
| 380 | + x = xtensor("x", dims=("city",), shape=(3,)) |
| 381 | + y = expand_dims(x, "country") |
| 382 | + fn = xr_function([x], y) |
| 383 | + x_xr = xr_arange_like(x) |
| 384 | + xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) |
| 385 | + |
| 386 | + # 2D case |
| 387 | + x = xtensor("x", dims=("city", "year"), shape=(2, 2)) |
| 388 | + y = expand_dims(x, "country") |
| 389 | + fn = xr_function([x], y) |
| 390 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) |
| 391 | + |
| 392 | + # 3D case |
| 393 | + x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) |
| 394 | + y = expand_dims(x, "country") |
| 395 | + fn = xr_function([x], y) |
| 396 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) |
| 397 | + |
| 398 | + # Prepending various dims |
| 399 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 400 | + for new_dim in ("x", "y", "z"): |
| 401 | + y = expand_dims(x, new_dim) |
| 402 | + assert y.type.dims == (new_dim, "a", "b") |
| 403 | + assert y.type.shape == (1, 2, 3) |
| 404 | + |
| 405 | + # Explicit size=1 behaves like default |
| 406 | + y1 = expand_dims(x, "batch", size=1) |
| 407 | + y2 = expand_dims(x, "batch") |
| 408 | + fn1 = xr_function([x], y1) |
| 409 | + fn2 = xr_function([x], y2) |
| 410 | + x_test = xr_arange_like(x) |
| 411 | + xr_assert_allclose(fn1(x_test), fn2(x_test)) |
| 412 | + |
| 413 | + # Scalar expansion |
| 414 | + x = xtensor("x", dims=(), shape=()) |
| 415 | + y = expand_dims(x, "batch") |
| 416 | + assert y.type.dims == ("batch",) |
| 417 | + assert y.type.shape == (1,) |
| 418 | + fn = xr_function([x], y) |
| 419 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) |
| 420 | + |
| 421 | + # Static size > 1: broadcast |
| 422 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 423 | + y = expand_dims(x, "batch", size=4) |
| 424 | + fn = xr_function([x], y) |
| 425 | + expected = xr.DataArray( |
| 426 | + np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), |
| 427 | + dims=("batch", "a", "b"), |
| 428 | + coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, |
| 429 | + ) |
| 430 | + xr_assert_allclose(fn(xr_arange_like(x)), expected) |
| 431 | + |
| 432 | + # Insert new dim between existing dims |
| 433 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 434 | + y = expand_dims(x, "new") |
| 435 | + # Insert new dim between a and b: ("a", "new", "b") |
| 436 | + y = transpose(y, "a", "new", "b") |
| 437 | + fn = xr_function([x], y) |
| 438 | + x_test = xr_arange_like(x) |
| 439 | + expected = x_test.expand_dims("new").transpose("a", "new", "b") |
| 440 | + xr_assert_allclose(fn(x_test), expected) |
| 441 | + |
| 442 | + # Expand with multiple dims |
| 443 | + x = xtensor("x", dims=(), shape=()) |
| 444 | + y = expand_dims(expand_dims(x, "a"), "b") |
| 445 | + fn = xr_function([x], y) |
| 446 | + expected = xr_arange_like(x).expand_dims("a").expand_dims("b") |
| 447 | + xr_assert_allclose(fn(xr_arange_like(x)), expected) |
| 448 | + |
| 449 | + |
| 450 | +def test_expand_dims_implicit(): |
| 451 | + """Test expand_dims with default or symbolic sizes and dim=None.""" |
| 452 | + |
| 453 | + # Symbolic size=1: same as default |
| 454 | + size_sym_1 = scalar("size_sym_1", dtype="int64") |
| 455 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 456 | + y = expand_dims(x, "batch", size=size_sym_1) |
| 457 | + fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") |
| 458 | + expected = xr_arange_like(x).expand_dims("batch") |
| 459 | + xr_assert_allclose(fn(xr_arange_like(x), 1), expected) |
| 460 | + |
| 461 | + # Symbolic size > 1 (but expand only adds dim=1) |
| 462 | + size_sym_4 = scalar("size_sym_4", dtype="int64") |
| 463 | + y = expand_dims(x, "batch", size=size_sym_4) |
| 464 | + fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") |
| 465 | + xr_assert_allclose(fn(xr_arange_like(x), 4), expected) |
| 466 | + |
| 467 | + # Reversibility: expand then squeeze |
| 468 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 469 | + y = expand_dims(x, "batch") |
| 470 | + z = squeeze(y, "batch") |
| 471 | + fn = xr_function([x], z) |
| 472 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) |
| 473 | + |
| 474 | + # expand_dims with dim=None = no-op |
| 475 | + x = xtensor("x", dims=("a",), shape=(3,)) |
| 476 | + y = expand_dims(x, None) |
| 477 | + fn = xr_function([x], y) |
| 478 | + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) |
| 479 | + |
| 480 | + # broadcast after symbolic size |
| 481 | + size_sym = scalar("size_sym", dtype="int64") |
| 482 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
| 483 | + y = expand_dims(x, "batch", size=size_sym) |
| 484 | + z = y + y # triggers shape alignment |
| 485 | + fn = xr_function([x, size_sym], z, on_unused_input="ignore") |
| 486 | + x_test = xr_arange_like(x) |
| 487 | + out = fn(x_test, 1) |
| 488 | + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") |
| 489 | + xr_assert_allclose(out, expected) |
| 490 | + |
| 491 | + |
| 492 | +def test_expand_dims_errors(): |
| 493 | + """Test error handling in expand_dims.""" |
| 494 | + |
| 495 | + # Expanding existing dim |
| 496 | + x = xtensor("x", dims=("city",), shape=(3,)) |
| 497 | + y = expand_dims(x, "country") |
| 498 | + with pytest.raises(ValueError, match="already exists"): |
| 499 | + expand_dims(y, "city") |
| 500 | + |
| 501 | + # Size = 0 is invalid |
| 502 | + with pytest.raises(ValueError, match="size must be.*positive"): |
| 503 | + expand_dims(x, "batch", size=0) |
| 504 | + |
| 505 | + # Invalid dim type |
| 506 | + with pytest.raises(TypeError): |
| 507 | + expand_dims(x, 123) |
| 508 | + |
| 509 | + # Invalid size type |
| 510 | + with pytest.raises(TypeError): |
| 511 | + expand_dims(x, "new", size=[1]) |
| 512 | + |
| 513 | + # Duplicate dimension creation |
| 514 | + y = expand_dims(x, "new") |
| 515 | + with pytest.raises(ValueError): |
| 516 | + expand_dims(y, "new") |
| 517 | + |
| 518 | + # Symbolic size with invalid runtime value |
| 519 | + size_sym = scalar("size_sym", dtype="int64") |
| 520 | + y = expand_dims(x, "batch", size=size_sym) |
| 521 | + fn = xr_function([x, size_sym], y, on_unused_input="ignore") |
| 522 | + with pytest.raises(Exception): |
| 523 | + fn(xr_arange_like(x), 0) |
0 commit comments