|
8 | 8 | from itertools import chain, combinations |
9 | 9 |
|
10 | 10 | import numpy as np |
11 | | -import xarray as xr |
12 | 11 | from xarray import DataArray |
13 | 12 | from xarray import concat as xr_concat |
14 | 13 |
|
15 | 14 | from pytensor.tensor import scalar |
16 | 15 | from pytensor.xtensor.shape import ( |
17 | 16 | concat, |
18 | | - expand_dims, |
19 | 17 | squeeze, |
20 | 18 | stack, |
21 | 19 | transpose, |
@@ -373,211 +371,97 @@ def test_squeeze_errors(): |
373 | 371 | fn2(x2_test) |
374 | 372 |
|
375 | 373 |
|
def test_expand_dims():
    """Test the ``expand_dims`` method against xarray.

    Every case builds a graph from the same 2D input, compiles it with
    ``xr_function`` and compares the result with the equivalent
    ``DataArray.expand_dims`` call on the test data. Covered call styles:
    a single name, ``size=`` (static and symbolic), a list of names, a dict
    of sizes, and keyword arguments.
    """
    x = xtensor("x", dims=("city", "year"), shape=(2, 2))
    x_test = xr_arange_like(x)

    # Implicit size=1
    y = x.expand_dims("country")
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))

    # Explicit size=1
    y = x.expand_dims("country", size=1)
    # Recompile: without this the assertion would re-run the implicit-size
    # graph above and the explicit `size=1` path would never be exercised.
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))

    # Explicit size > 1
    y = x.expand_dims("country", size=4)
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 4}))

    # Multiple dimensions given as a list of names
    y = x.expand_dims(["country", "state"])
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))

    # Dict of sizes
    y = x.expand_dims({"country": 2, "state": 3})
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))

    # Keyword arguments (equivalent to the dict form)
    y = x.expand_dims(country=2, state=3)
    fn = xr_function([x], y)
    xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))

    # Symbolic size=1
    size_sym_1 = scalar("size_sym_1", dtype="int64")
    y = x.expand_dims("country", size=size_sym_1)
    fn = xr_function([x, size_sym_1], y)
    xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("country"))

    # Symbolic size > 1
    # NOTE: xarray's expand_dims has no `size` argument: a dimension passed by
    # name always gets length 1 there, and larger lengths are requested with a
    # dict (or keyword arguments). Here `size` broadcasts the new dimension to
    # the requested length, so the xarray equivalent is the dict form.
    size_sym_4 = scalar("size_sym_4", dtype="int64")
    y = x.expand_dims("country", size=size_sym_4)
    fn = xr_function([x, size_sym_4], y)
    res = fn(x_test, 4)
    expected = x_test.expand_dims({"country": 4})
    xr_assert_allclose(res, expected)

    # Symbolic sizes in a dict. The scalars are named after the values fed to
    # them at call time, like size_sym_1 and size_sym_4 above.
    size_sym_2 = scalar("size_sym_2", dtype="int64")
    size_sym_3 = scalar("size_sym_3", dtype="int64")
    y = x.expand_dims({"country": size_sym_2, "state": size_sym_3})
    fn = xr_function([x, size_sym_2, size_sym_3], y)
    xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))

    # Symbolic sizes as keyword arguments
    y = x.expand_dims(country=size_sym_2, state=size_sym_3)
    fn = xr_function([x, size_sym_2, size_sym_3], y)
    xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
|
533 | 440 |
|
def test_expand_dims_errors():
    """Check that invalid ``expand_dims`` calls raise the expected exceptions."""
    x = xtensor("x", dims=("city",), shape=(3,))

    # A dimension that is already present on the input cannot be added again
    with_country = x.expand_dims("country")
    with pytest.raises(ValueError, match="already exists"):
        with_country.expand_dims("city")

    # `dim` must not be an arbitrary object such as an int
    with pytest.raises(TypeError, match="Invalid type for `dim`"):
        x.expand_dims(123)

    # `size` must not be a list
    with pytest.raises(TypeError, match="size must be an int or scalar variable"):
        x.expand_dims("new", size=[1])

    # Adding the same new dimension twice in a row is rejected as well
    with_new = x.expand_dims("new")
    with pytest.raises(ValueError, match="already exists"):
        with_new.expand_dims("new")

    # A list of dims combined with the `size` keyword is rejected
    both_given_msg = "cannot specify both keyword and positional arguments"
    with pytest.raises(ValueError, match=both_given_msg):
        x.expand_dims(["country", "state"], size=3)
0 commit comments