Skip to content

Commit 38d87c6

Browse files
committed
Cleanup
1 parent 3b7b973 commit 38d87c6

File tree

4 files changed

+77
-203
lines changed

4 files changed

+77
-203
lines changed

pytensor/xtensor/rewriting/shape.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def lower_transpose(fgraph, node):
123123

124124
@register_lower_xtensor
125125
@node_rewriter([Squeeze])
126-
def local_squeeze_reshape(fgraph, node):
126+
def lower_squeeze(fgraph, node):
127127
"""Rewrite Squeeze to tensor.squeeze."""
128128
[x] = node.inputs
129129
x_tensor = tensor_from_xtensor(x)
@@ -138,7 +138,7 @@ def local_squeeze_reshape(fgraph, node):
138138

139139
@register_lower_xtensor
140140
@node_rewriter([ExpandDims])
141-
def local_expand_dims_reshape(fgraph, node):
141+
def lower_expand_dims(fgraph, node):
142142
"""Rewrite ExpandDims using tensor operations."""
143143
x, size = node.inputs
144144
out = node.outputs[0]
@@ -155,10 +155,8 @@ def local_expand_dims_reshape(fgraph, node):
155155
# Simple case: just expand with size 1
156156
result_tensor = expand_dims(x_tensor, new_axis)
157157
else:
158-
# First expand with size 1
159-
expanded = expand_dims(x_tensor, new_axis)
160-
# Then broadcast to the requested size
161-
result_tensor = broadcast_to(expanded, (size_tensor, *x_tensor.shape))
158+
# Otherwise broadcast to the requested size
159+
result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape))
162160

163161
# Preserve static shape information
164162
result_tensor = specify_shape(result_tensor, out.type.shape)

pytensor/xtensor/shape.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -442,32 +442,16 @@ def make_node(self, x, size):
442442

443443

444444
def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs):
445-
"""Add one or more new dimensions to an XTensorVariable.
446-
447-
Parameters
448-
----------
449-
x : XTensorVariable
450-
Input tensor.
451-
dim : str | Sequence[str] | dict[str, int | Sequence] | None
452-
If str or sequence of str, new dimensions with size 1.
453-
If dict, keys are dimension names and values are either:
454-
- int: the new size
455-
- sequence: coordinates (length determines size)
456-
create_index_for_new_dim : bool, default: True
457-
(Ignored for now) Matches xarray API, reserved for future use.
458-
**dim_kwargs : int | Sequence
459-
Alternative to `dim` dict. Only used if `dim` is None.
460-
461-
Returns
462-
-------
463-
XTensorVariable
464-
A tensor with additional dimensions inserted at the front.
465-
"""
445+
"""Add one or more new dimensions to an XTensorVariable."""
466446
x = as_xtensor(x)
467447

468448
# Extract size from dim_kwargs if present
469449
size = dim_kwargs.pop("size", 1) if dim_kwargs else 1
470450

451+
# xarray compatibility: error if a sequence (list/tuple) of dims and size are given
452+
if (isinstance(dim, list | tuple)) and ("size" in locals() and size != 1):
453+
raise ValueError("cannot specify both keyword and positional arguments")
454+
471455
if dim is None:
472456
dim = dim_kwargs
473457
elif dim_kwargs:

pytensor/xtensor/type.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,24 +483,32 @@ def squeeze(
483483

484484
def expand_dims(
485485
self,
486-
dim: str | None = None,
487-
size: int | Variable = 1,
486+
dim: str | Sequence[str] | dict[str, int | Sequence] | None = None,
487+
create_index_for_new_dim: bool = True,
488+
**dim_kwargs,
488489
):
489-
"""Add a new dimension to the tensor.
490+
"""Add one or more new dimensions to the tensor.
490491
491492
Parameters
492493
----------
493-
dim : str or None
494-
Name of new dimension. If None, returns self unchanged.
495-
size : int or symbolic, optional
496-
Size of the new dimension (default 1)
494+
dim : str | Sequence[str] | dict[str, int | Sequence] | None
495+
If str or sequence of str, new dimensions with size 1.
496+
If dict, keys are dimension names and values are either:
497+
- int: the new size
498+
- sequence: coordinates (length determines size)
499+
create_index_for_new_dim : bool, default: True
500+
(Ignored for now) Matches xarray API, reserved for future use.
501+
**dim_kwargs : int | Sequence
502+
Alternative to `dim` dict. Only used if `dim` is None.
497503
498504
Returns
499505
-------
500506
XTensorVariable
501-
Tensor with the new dimension inserted
507+
A tensor with additional dimensions inserted at the front.
502508
"""
503-
return px.shape.expand_dims(self, dim, size=size)
509+
return px.shape.expand_dims(
510+
self, dim, create_index_for_new_dim=create_index_for_new_dim, **dim_kwargs
511+
)
504512

505513
# ndarray methods
506514
# https://docs.xarray.dev/en/latest/api.html#id7

tests/xtensor/test_shape.py

Lines changed: 51 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
from itertools import chain, combinations
99

1010
import numpy as np
11-
import xarray as xr
1211
from xarray import DataArray
1312
from xarray import concat as xr_concat
1413

1514
from pytensor.tensor import scalar
1615
from pytensor.xtensor.shape import (
1716
concat,
18-
expand_dims,
1917
squeeze,
2018
stack,
2119
transpose,
@@ -373,211 +371,97 @@ def test_squeeze_errors():
373371
fn2(x2_test)
374372

375373

376-
def test_expand_dims_explicit():
377-
"""Test expand_dims with explicitly named dimensions and sizes."""
378-
379-
# 1D case
380-
x = xtensor("x", dims=("city",), shape=(3,))
381-
y = expand_dims(x, "country")
382-
fn = xr_function([x], y)
383-
x_xr = xr_arange_like(x)
384-
xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country"))
385-
386-
# 2D case
374+
def test_expand_dims():
375+
"""Test expand_dims."""
387376
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
388-
y = expand_dims(x, "country")
389-
fn = xr_function([x], y)
390-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))
391-
392-
# 3D case
393-
x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2))
394-
y = expand_dims(x, "country")
395-
fn = xr_function([x], y)
396-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))
397-
398-
# Prepending various dims
399-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
400-
for new_dim in ("x", "y", "z"):
401-
y = expand_dims(x, new_dim)
402-
assert y.type.dims == (new_dim, "a", "b")
403-
assert y.type.shape == (1, 2, 3)
404-
405-
# Explicit size=1 behaves like default
406-
y1 = expand_dims(x, "batch", size=1)
407-
y2 = expand_dims(x, "batch")
408-
fn1 = xr_function([x], y1)
409-
fn2 = xr_function([x], y2)
410377
x_test = xr_arange_like(x)
411-
xr_assert_allclose(fn1(x_test), fn2(x_test))
412378

413-
# Scalar expansion
414-
x = xtensor("x", dims=(), shape=())
415-
y = expand_dims(x, "batch")
416-
assert y.type.dims == ("batch",)
417-
assert y.type.shape == (1,)
379+
# Implicit size=1
380+
y = x.expand_dims("country")
418381
fn = xr_function([x], y)
419-
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch"))
382+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))
420383

421-
# Static size > 1: broadcast
422-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
423-
y = expand_dims(x, "batch", size=4)
424-
fn = xr_function([x], y)
425-
expected = xr.DataArray(
426-
np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)),
427-
dims=("batch", "a", "b"),
428-
coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]},
429-
)
430-
xr_assert_allclose(fn(xr_arange_like(x)), expected)
384+
# Explicit size=1
385+
y = x.expand_dims("country", size=1)
386+
xr_assert_allclose(fn(x_test), x_test.expand_dims("country"))
431387

432-
# Insert new dim between existing dims
433-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
434-
y = expand_dims(x, "new")
435-
# Insert new dim between a and b: ("a", "new", "b")
436-
y = transpose(y, "a", "new", "b")
388+
# Explicit size > 1
389+
y = x.expand_dims("country", size=4)
437390
fn = xr_function([x], y)
438-
x_test = xr_arange_like(x)
439-
expected = x_test.expand_dims("new").transpose("a", "new", "b")
440-
xr_assert_allclose(fn(x_test), expected)
391+
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 4}))
441392

442-
# Expand with multiple dims
443-
x = xtensor("x", dims=(), shape=())
444-
y = expand_dims(expand_dims(x, "a"), "b")
393+
# Test with multiple dimensions
394+
y = x.expand_dims(["country", "state"])
445395
fn = xr_function([x], y)
446-
expected = xr_arange_like(x).expand_dims("a").expand_dims("b")
447-
xr_assert_allclose(fn(xr_arange_like(x)), expected)
396+
xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"]))
448397

398+
# Test with a dict of sizes
399+
y = x.expand_dims({"country": 2, "state": 3})
400+
fn = xr_function([x], y)
401+
xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3}))
449402

450-
def test_expand_dims_symbolic_size():
451-
"""Test expand_dims with symbolic sizes."""
452-
453-
# Symbolic size=1: same as default
454-
size_sym_1 = scalar("size_sym_1", dtype="int64")
455-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
456-
y = expand_dims(x, "batch", size=size_sym_1)
457-
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
458-
x_test = xr_arange_like(x)
459-
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch"))
460-
461-
# Test using symbolic size from an existing dimension of the same tensor
462-
# This verifies that expand_dims can use the size of one dimension to create another
463-
x = xtensor(dims=("a", "b", "c"))
464-
y = expand_dims(x, "d", size=x.sizes["b"])
403+
# Test with kwargs (equivalent to dict)
404+
y = x.expand_dims(country=2, state=3)
465405
fn = xr_function([x], y)
466-
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
467-
res = fn(x_test)
468-
expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b"
469-
xr_assert_allclose(res, expected)
406+
xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3))
470407

471-
# Test broadcasting with symbolic size from a different tensor
472-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
473-
other = xtensor("other", dims=("c",), shape=(4,))
474-
y = expand_dims(x, "batch", size=other.sizes["c"])
475-
fn = xr_function([x, other], y)
476-
x_test = xr_arange_like(x)
477-
other_test = xr_arange_like(other)
478-
res = fn(x_test, other_test)
479-
expected = x_test.expand_dims(
480-
{"batch": 4}
481-
) # 4 is the size of dimension "c" in other
482-
xr_assert_allclose(res, expected)
408+
# Symbolic size=1
409+
size_sym_1 = scalar("size_sym_1", dtype="int64")
410+
y = x.expand_dims("country", size=size_sym_1)
411+
fn = xr_function([x, size_sym_1], y)
412+
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("country"))
483413

484414
# Test behavior with symbolic size > 1
485415
# NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size.
486416
# This differs from xarray's behavior where expand_dims always adds a size-1 dimension.
487417
size_sym_4 = scalar("size_sym_4", dtype="int64")
488-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
489-
y = expand_dims(x, "batch", size=size_sym_4)
490-
fn = xr_function([x, size_sym_4], y, on_unused_input="ignore")
491-
x_test = xr_arange_like(x)
418+
y = x.expand_dims("country", size=size_sym_4)
419+
fn = xr_function([x, size_sym_4], y)
492420
res = fn(x_test, 4)
493421
# Our current behavior: broadcasts to size 4
494-
expected = x_test.expand_dims({"batch": 4})
422+
expected = x_test.expand_dims({"country": 4})
495423
xr_assert_allclose(res, expected)
496424
# xarray's behavior would be:
497-
# expected = x_test.expand_dims("batch") # always size 1
425+
# expected = x_test.expand_dims("country") # always size 1
498426
# xr_assert_allclose(res, expected)
499427

500-
# Test using symbolic size from a reduction operation
501-
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
502-
reduced = x.sum("a") # shape: (b: 3)
503-
y = expand_dims(x, "batch", size=reduced.sizes["b"])
504-
fn = xr_function([x], y)
505-
x_test = xr_arange_like(x)
506-
res = fn(x_test)
507-
expected = x_test.expand_dims({"batch": 3}) # 3 is the size of dimension "b"
508-
xr_assert_allclose(res, expected)
509-
510-
# Test chaining expand_dims with symbolic sizes
511-
x = xtensor("x", dims=("a",), shape=(2,))
512-
y = expand_dims(x, "b", size=x.sizes["a"]) # shape: (a: 2, b: 2)
513-
z = expand_dims(y, "c", size=y.sizes["b"]) # shape: (a: 2, b: 2, c: 2)
514-
fn = xr_function([x], z)
515-
x_test = xr_arange_like(x)
516-
res = fn(x_test)
517-
expected = x_test.expand_dims({"b": 2}).expand_dims({"c": 2})
518-
xr_assert_allclose(res, expected)
428+
# Test with symbolic sizes in dict
429+
size_sym_1 = scalar("size_sym_1", dtype="int64")
430+
size_sym_2 = scalar("size_sym_2", dtype="int64")
431+
y = x.expand_dims({"country": size_sym_1, "state": size_sym_2})
432+
fn = xr_function([x, size_sym_1, size_sym_2], y)
433+
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
519434

520-
# Test bidirectional broadcasting with symbolic sizes
521-
x = xtensor("x", dims=("a",), shape=(2,))
522-
y = xtensor("y", dims=("b",), shape=(3,))
523-
# Expand x with size from y, then add y
524-
expanded = expand_dims(x, "b", size=y.sizes["b"])
525-
z = expanded + y # Should broadcast x to match y's size
526-
fn = xr_function([x, y], z)
527-
x_test = xr_arange_like(x)
528-
y_test = xr_arange_like(y)
529-
res = fn(x_test, y_test)
530-
expected = x_test.expand_dims({"b": 3}) + y_test
531-
xr_assert_allclose(res, expected)
435+
# Test with symbolic sizes in kwargs
436+
y = x.expand_dims(country=size_sym_1, state=size_sym_2)
437+
fn = xr_function([x, size_sym_1, size_sym_2], y)
438+
xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3}))
532439

533440

534441
def test_expand_dims_errors():
535442
"""Test error handling in expand_dims."""
536443

537444
# Expanding existing dim
538445
x = xtensor("x", dims=("city",), shape=(3,))
539-
y = expand_dims(x, "country")
446+
y = x.expand_dims("country")
540447
with pytest.raises(ValueError, match="already exists"):
541-
expand_dims(y, "city")
448+
y.expand_dims("city")
542449

543450
# Invalid dim type
544451
with pytest.raises(TypeError, match="Invalid type for `dim`"):
545-
expand_dims(x, 123)
452+
x.expand_dims(123)
546453

547454
# Invalid size type
548455
with pytest.raises(TypeError, match="size must be an int or scalar variable"):
549-
expand_dims(x, "new", size=[1])
456+
x.expand_dims("new", size=[1])
550457

551458
# Duplicate dimension creation
552-
y = expand_dims(x, "new")
459+
y = x.expand_dims("new")
553460
with pytest.raises(ValueError, match="already exists"):
554-
expand_dims(y, "new")
555-
461+
y.expand_dims("new")
556462

557-
def test_expand_dims_multiple():
558-
"""Test expanding multiple dimensions at once using a list of strings."""
559-
x = xtensor("x", dims=("city",), shape=(3,))
560-
y = expand_dims(x, ["country", "state"])
561-
fn = xr_function([x], y)
562-
x_xr = xr_arange_like(x)
563-
xr_assert_allclose(fn(x_xr), x_xr.expand_dims(["country", "state"]))
564-
565-
# Test with a dict of sizes
566-
y = expand_dims(x, {"country": 2, "state": 3})
567-
fn = xr_function([x], y)
568-
x_xr = xr_arange_like(x)
569-
xr_assert_allclose(fn(x_xr), x_xr.expand_dims({"country": 2, "state": 3}))
570-
571-
# Test with a mix of strings and dicts
572-
y = expand_dims(x, ["country", "state"], size=3)
573-
fn = xr_function([x], y)
574-
x_xr = xr_arange_like(x)
575-
xr_assert_allclose(fn(x_xr), x_xr.expand_dims(["country", "state"]))
576-
577-
# Test with symbolic sizes in dict
578-
size_sym_1 = scalar("size_sym_1", dtype="int64")
579-
size_sym_2 = scalar("size_sym_2", dtype="int64")
580-
y = expand_dims(x, {"country": size_sym_1, "state": size_sym_2})
581-
fn = xr_function([x, size_sym_1, size_sym_2], y, on_unused_input="ignore")
582-
x_xr = xr_arange_like(x)
583-
xr_assert_allclose(fn(x_xr, 2, 3), x_xr.expand_dims({"country": 2, "state": 3}))
463+
# Test for error when both positional and size are given
464+
with pytest.raises(
465+
ValueError, match="cannot specify both keyword and positional arguments"
466+
):
467+
x.expand_dims(["country", "state"], size=3)

0 commit comments

Comments
 (0)