From 19361d6811fb1096d8d2265a7613a4dbc4bb4f59 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 6 Jun 2025 14:14:29 -0400 Subject: [PATCH 01/11] Adding expand_dims for xtensor --- pytensor/xtensor/rewriting/shape.py | 32 +++++ pytensor/xtensor/shape.py | 74 ++++++++++ tests/xtensor/test_shape.py | 209 +++++++++++++++++++++++++++- 3 files changed, 314 insertions(+), 1 deletion(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index b0ca4f3bd4..cbd5bf98b7 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,6 +1,11 @@ +import numpy as np + from pytensor.graph import node_rewriter +from pytensor.raise_op import Assert from pytensor.tensor import ( broadcast_to, + expand_dims, + gt, join, moveaxis, specify_shape, @@ -10,6 +15,7 @@ from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.shape import ( Concat, + ExpandDims, Squeeze, Stack, Transpose, @@ -132,3 +138,29 @@ def local_squeeze_reshape(fgraph, node): new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) return [new_out] + + +@register_lower_xtensor +@node_rewriter([ExpandDims]) +def local_expand_dims_reshape(fgraph, node): + """Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify_shape.""" + [x] = node.inputs + x_tensor = tensor_from_xtensor(x) + x_tensor_expanded = expand_dims(x_tensor, axis=0) + + target_shape = node.outputs[0].type.shape + + size = getattr(node.op, "size", 1) + if isinstance(size, int | np.integer): + if size != 1 and None not in target_shape: + x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape) + else: + # Symbolic size: enforce shape so broadcast happens downstream correctly + # Also validate that size is positive + x_tensor_expanded = Assert(msg="size must be positive")( + x_tensor_expanded, gt(size, 0) + ) + x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape) + + new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index cd0f024e56..ca8eb2308c 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -3,6 +3,8 @@ from types import EllipsisType from typing import Literal +import numpy as np + from pytensor.graph import Apply from pytensor.scalar import discrete_dtypes, upcast from pytensor.tensor import as_tensor, get_scalar_constant_value @@ -380,3 +382,75 @@ def squeeze(x, dim=None): return x # no-op if nothing to squeeze return Squeeze(dims=dims)(x) + + +class ExpandDims(XOp): + """Add a new dimension to an XTensorVariable.""" + + __props__ = ("dims", "size") + + def __init__(self, dim, size=1): + self.dims = dim + self.size = size + + def make_node(self, x): + x = as_xtensor(x) + + if self.dims is None: + # No-op: return same variable + return Apply(self, [x], [x]) + + # Insert new dim at front + new_dims = (self.dims, *x.type.dims) + + # Determine shape + if isinstance(self.size, int | np.integer): + new_shape = (self.size, *x.type.shape) + else: + new_shape = (None, *x.type.shape) # symbolic size + + out = xtensor( + dtype=x.type.dtype, + shape=new_shape, + dims=new_dims, + ) + return Apply(self, [x], [out]) + + +def expand_dims(x, dim: str | None, size=1): + """Add a new dimension to an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + Input tensor + dim : str or None + Name of new dimension. If None, returns x unchanged. + size : int or symbolic, optional + Size of the new dimension (default 1) + + Returns + ------- + XTensorVariable + Tensor with the new dimension inserted + """ + x = as_xtensor(x) + + if dim is None: + return x # No-op + + if not isinstance(dim, str): + raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") + + if dim in x.type.dims: + raise ValueError(f"Dimension {dim} already exists in {x.type.dims}") + + if isinstance(size, int | np.integer): + if size <= 0: + raise ValueError(f"size must be positive, got: {size}") + elif not ( + hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar + ): + raise TypeError(f"size must be an int or scalar variable, got: {type(size)}") + + return ExpandDims(dim=dim, size=size)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index f5db72bf1f..7206d40942 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,12 +8,14 @@ from itertools import chain, combinations import numpy as np -import pytest +import xarray as xr from xarray import DataArray from xarray import concat as xr_concat +from pytensor.tensor import scalar from pytensor.xtensor.shape import ( concat, + expand_dims, squeeze, stack, transpose, @@ -301,6 +303,15 @@ def test_squeeze_explicit_dims(): fn3d = xr_function([x3], y3d) xr_assert_allclose(fn3d(x3_test), x3_test) + # Reversibility with expand_dims + x6 = xtensor("x6", dims=("a", "b", "c"), shape=(2, 1, 3)) + y6 = squeeze(x6, "b") + # First expand_dims adds at front, then transpose puts it in the right place + z6 = transpose(expand_dims(y6, "b"), "a", "b", "c") + fn6 = xr_function([x6], z6) + x6_test = xr_arange_like(x6) + xr_assert_allclose(fn6(x6_test), x6_test) + def test_squeeze_implicit_dims(): """Test squeeze with implicit dim=None (all size-1 dimensions).""" @@ -369,3 +380,199 @@ def test_squeeze_errors(): fn2 = xr_function([x2], y2) with pytest.raises(Exception): fn2(x2_test) + + +def test_expand_dims_explicit(): + """Test expand_dims with explicitly named dimensions and sizes.""" + + # 1D case + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, "country") + fn = xr_function([x], y) + x_xr = xr_arange_like(x) + xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) + + # 2D case + x = xtensor("x", dims=("city", "year"), shape=(2, 2)) + y = expand_dims(x, "country") + fn = xr_function([x], y) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) + + # 3D case + x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) + y = expand_dims(x, "country") + fn = xr_function([x], y) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) + + # Prepending various dims + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + for new_dim in ("x", "y", "z"): + y = expand_dims(x, new_dim) + assert y.type.dims == (new_dim, "a", "b") + assert y.type.shape == (1, 2, 3) + + # Explicit size=1 behaves like default + y1 = expand_dims(x, "batch", size=1) + y2 = expand_dims(x, "batch") + fn1 = xr_function([x], y1) + fn2 = xr_function([x], y2) + x_test = xr_arange_like(x) + xr_assert_allclose(fn1(x_test), fn2(x_test)) + + # Scalar expansion + x = xtensor("x", dims=(), shape=()) + y = expand_dims(x, "batch") + assert y.type.dims == ("batch",) + assert y.type.shape == (1,) + fn = xr_function([x], y) + xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) + + # Static size > 1: broadcast + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=4) + fn = xr_function([x], y) + expected = xr.DataArray( + np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), + dims=("batch", "a", "b"), + coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, + ) + xr_assert_allclose(fn(xr_arange_like(x)), expected) + + # Insert new dim between existing dims + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "new") + # Insert new dim between a and b: ("a", "new", "b") + y = transpose(y, "a", "new", "b") + fn = xr_function([x], y) + x_test = xr_arange_like(x) + expected = x_test.expand_dims("new").transpose("a", "new", "b") + xr_assert_allclose(fn(x_test), expected) + + # Expand with multiple dims + x = xtensor("x", dims=(), shape=()) + y = expand_dims(expand_dims(x, "a"), "b") + fn = xr_function([x], y) + expected = xr_arange_like(x).expand_dims("a").expand_dims("b") + xr_assert_allclose(fn(xr_arange_like(x)), expected) + + +def test_expand_dims_implicit(): + """Test expand_dims with default or symbolic sizes and dim=None.""" + + # Symbolic size=1: same as default + size_sym_1 = scalar("size_sym_1", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym_1) + fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) + + # Symbolic size > 1 (but expand only adds dim=1) + size_sym_4 = scalar("size_sym_4", dtype="int64") + y = expand_dims(x, "batch", size=size_sym_4) + fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") + xr_assert_allclose(fn(x_test, 4), x_test.expand_dims("batch")) + + # Symbolic size > 1 with broadcasting + size_sym_4 = scalar("size_sym_4", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym_4) + z = y + y # This should broadcast along the batch dimension + fn = xr_function([x, size_sym_4], z, on_unused_input="ignore") + x_test = xr_arange_like(x) + out = fn(x_test, 4) + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") + xr_assert_allclose(out, expected) + + # Symbolic size with shape validation + size_sym = scalar("size_sym", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym) + z = y + y # This should validate the shape + fn = xr_function([x, size_sym], z, on_unused_input="ignore") + x_test = xr_arange_like(x) + out = fn(x_test, 4) + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") + xr_assert_allclose(out, expected) + + # Symbolic size with subsequent operations + size_sym = scalar("size_sym", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym) + z = y.sum("batch") # This should work with symbolic size + fn = xr_function([x, size_sym], z, on_unused_input="ignore") + x_test = xr_arange_like(x) + out = fn(x_test, 4) + expected = x_test.expand_dims("batch").sum("batch") + xr_assert_allclose(out, expected) + + # Symbolic size with transpose and broadcasting + size_sym = scalar("size_sym", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym) + z = transpose(y, "batch", "a", "b") # This should work with symbolic size + fn = xr_function([x, size_sym], z, on_unused_input="ignore") + x_test = xr_arange_like(x) + out = fn(x_test, 4) + expected = x_test.expand_dims("batch").transpose("batch", "a", "b") + xr_assert_allclose(out, expected) + + # Reversibility: expand then squeeze + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch") + z = squeeze(y, "batch") + fn = xr_function([x], z) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test) + + # expand_dims with dim=None = no-op + x = xtensor("x", dims=("a",), shape=(3,)) + y = expand_dims(x, None) + fn = xr_function([x], y) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test) + + # broadcast after symbolic size + size_sym = scalar("size_sym", dtype="int64") + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = expand_dims(x, "batch", size=size_sym) + z = y + y # triggers shape alignment + fn = xr_function([x, size_sym], z, on_unused_input="ignore") + x_test = xr_arange_like(x) + out = fn(x_test, 1) + expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") + xr_assert_allclose(out, expected) + + +def test_expand_dims_errors(): + """Test error handling in expand_dims.""" + + # Expanding existing dim + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, "country") + with pytest.raises(ValueError, match="already exists"): + expand_dims(y, "city") + + # Size = 0 is invalid + with pytest.raises(ValueError, match="size must be.*positive"): + expand_dims(x, "batch", size=0) + + # Invalid dim type + with pytest.raises(TypeError): + expand_dims(x, 123) + + # Invalid size type + with pytest.raises(TypeError): + expand_dims(x, "new", size=[1]) + + # Duplicate dimension creation + y = expand_dims(x, "new") + with pytest.raises(ValueError): + expand_dims(y, "new") + + # Symbolic size with invalid runtime value + size_sym = scalar("size_sym", dtype="int64") + y = expand_dims(x, "batch", size=size_sym) + fn = xr_function([x, size_sym], y, on_unused_input="ignore") + with pytest.raises(Exception): + fn(xr_arange_like(x), 0) From 95a71a6505a5fae4249cbd7f1eaa7d76cea7b613 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Sat, 7 Jun 2025 16:09:45 -0400 Subject: [PATCH 02/11] Handling symbolic size --- pytensor/xtensor/rewriting/shape.py | 50 ++++++------ pytensor/xtensor/shape.py | 26 +++--- tests/xtensor/test_shape.py | 118 ++++++++++++++-------------- 3 files changed, 97 insertions(+), 97 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index cbd5bf98b7..1c83329403 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,16 +1,17 @@ -import numpy as np - from pytensor.graph import node_rewriter from pytensor.raise_op import Assert from pytensor.tensor import ( broadcast_to, - expand_dims, + get_scalar_constant_value, gt, join, moveaxis, specify_shape, squeeze, ) +from pytensor.tensor import ( + shape as tensor_shape, +) from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.shape import ( @@ -143,24 +144,27 @@ def local_squeeze_reshape(fgraph, node): @register_lower_xtensor @node_rewriter([ExpandDims]) def local_expand_dims_reshape(fgraph, node): - """Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify_shape.""" - [x] = node.inputs - x_tensor = tensor_from_xtensor(x) - x_tensor_expanded = expand_dims(x_tensor, axis=0) - - target_shape = node.outputs[0].type.shape - - size = getattr(node.op, "size", 1) - if isinstance(size, int | np.integer): - if size != 1 and None not in target_shape: - x_tensor_expanded = broadcast_to(x_tensor_expanded, target_shape) + """Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify shape.""" + x, size = node.inputs + out = node.outputs[0] + # Lower to tensor.expand_dims(x, axis=0) + from pytensor.tensor import expand_dims as tensor_expand_dims + + expanded = tensor_expand_dims(tensor_from_xtensor(x), 0) + # Optionally broadcast to the correct shape if size is not 1 + from pytensor.tensor import broadcast_to + + # Ensure size is positive + expanded = Assert(msg="size must be positive")(expanded, gt(size, 0)) + # If size is not 1, broadcast + try: + static_size = get_scalar_constant_value(size) + except Exception: + static_size = None + if static_size is not None and static_size == 1: + result = expanded else: - # Symbolic size: enforce shape so broadcast happens downstream correctly - # Also validate that size is positive - x_tensor_expanded = Assert(msg="size must be positive")( - x_tensor_expanded, gt(size, 0) - ) - x_tensor_expanded = specify_shape(x_tensor_expanded, target_shape) - - new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims) - return [new_out] + # Broadcast to (size, ...) + new_shape = (size,) + tuple(tensor_shape(expanded))[1:] + result = broadcast_to(expanded, new_shape) + return [xtensor_from_tensor(result, out.type.dims)] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index ca8eb2308c..80e8e1405a 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -387,25 +387,23 @@ def squeeze(x, dim=None): class ExpandDims(XOp): """Add a new dimension to an XTensorVariable.""" - __props__ = ("dims", "size") + __props__ = ("dims",) - def __init__(self, dim, size=1): + def __init__(self, dim): self.dims = dim - self.size = size - def make_node(self, x): + def make_node(self, x, size): x = as_xtensor(x) - - if self.dims is None: - # No-op: return same variable - return Apply(self, [x], [x]) - # Insert new dim at front new_dims = (self.dims, *x.type.dims) # Determine shape - if isinstance(self.size, int | np.integer): - new_shape = (self.size, *x.type.shape) + try: + static_size = get_scalar_constant_value(size) + except NotScalarConstantError: + static_size = None + if static_size is not None: + new_shape = (int(static_size), *x.type.shape) else: new_shape = (None, *x.type.shape) # symbolic size @@ -414,7 +412,7 @@ def make_node(self, x): shape=new_shape, dims=new_dims, ) - return Apply(self, [x], [out]) + return Apply(self, [x, size], [out]) def expand_dims(x, dim: str | None, size=1): @@ -453,4 +451,6 @@ def expand_dims(x, dim: str | None, size=1): ): raise TypeError(f"size must be an int or scalar variable, got: {type(size)}") - return ExpandDims(dim=dim, size=size)(x) + # Always convert size to a PyTensor scalar variable + size_var = as_tensor(size, ndim=0) + return ExpandDims(dim)(x, size_var) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 7206d40942..d3972f5db2 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -467,81 +467,77 @@ def test_expand_dims_implicit(): x_test = xr_arange_like(x) xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) - # Symbolic size > 1 (but expand only adds dim=1) - size_sym_4 = scalar("size_sym_4", dtype="int64") - y = expand_dims(x, "batch", size=size_sym_4) - fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") - xr_assert_allclose(fn(x_test, 4), x_test.expand_dims("batch")) - - # Symbolic size > 1 with broadcasting - size_sym_4 = scalar("size_sym_4", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym_4) - z = y + y # This should broadcast along the batch dimension - fn = xr_function([x, size_sym_4], z, on_unused_input="ignore") - x_test = xr_arange_like(x) - out = fn(x_test, 4) - expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") - xr_assert_allclose(out, expected) + # Test using symbolic size from an existing dimension of the same tensor + # This verifies that expand_dims can use the size of one dimension to create another + x = xtensor(dims=("a", "b", "c")) + y = expand_dims(x, "d", size=x.sizes["b"]) + fn = xr_function([x], y) + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5))) + res = fn(x_test) + expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b" + xr_assert_allclose(res, expected) - # Symbolic size with shape validation - size_sym = scalar("size_sym", dtype="int64") + # Test broadcasting with symbolic size from a different tensor x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym) - z = y + y # This should validate the shape - fn = xr_function([x, size_sym], z, on_unused_input="ignore") + other = xtensor("other", dims=("c",), shape=(4,)) + y = expand_dims(x, "batch", size=other.sizes["c"]) + fn = xr_function([x, other], y) x_test = xr_arange_like(x) - out = fn(x_test, 4) - expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") - xr_assert_allclose(out, expected) + other_test = xr_arange_like(other) + res = fn(x_test, other_test) + expected = x_test.expand_dims( + {"batch": 4} + ) # 4 is the size of dimension "c" in other + xr_assert_allclose(res, expected) - # Symbolic size with subsequent operations - size_sym = scalar("size_sym", dtype="int64") + # Test behavior with symbolic size > 1 + # NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size. + # This differs from xarray's behavior where expand_dims always adds a size-1 dimension. + size_sym_4 = scalar("size_sym_4", dtype="int64") x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym) - z = y.sum("batch") # This should work with symbolic size - fn = xr_function([x, size_sym], z, on_unused_input="ignore") + y = expand_dims(x, "batch", size=size_sym_4) + fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") x_test = xr_arange_like(x) - out = fn(x_test, 4) - expected = x_test.expand_dims("batch").sum("batch") - xr_assert_allclose(out, expected) + res = fn(x_test, 4) + # Our current behavior: broadcasts to size 4 + expected = x_test.expand_dims({"batch": 4}) + xr_assert_allclose(res, expected) + # xarray's behavior would be: + # expected = x_test.expand_dims("batch") # always size 1 + # xr_assert_allclose(res, expected) - # Symbolic size with transpose and broadcasting - size_sym = scalar("size_sym", dtype="int64") + # Test using symbolic size from a reduction operation x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym) - z = transpose(y, "batch", "a", "b") # This should work with symbolic size - fn = xr_function([x, size_sym], z, on_unused_input="ignore") + reduced = x.sum("a") # shape: (b: 3) + y = expand_dims(x, "batch", size=reduced.sizes["b"]) + fn = xr_function([x], y) x_test = xr_arange_like(x) - out = fn(x_test, 4) - expected = x_test.expand_dims("batch").transpose("batch", "a", "b") - xr_assert_allclose(out, expected) + res = fn(x_test) + expected = x_test.expand_dims({"batch": 3}) # 3 is the size of dimension "b" + xr_assert_allclose(res, expected) - # Reversibility: expand then squeeze - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch") - z = squeeze(y, "batch") + # Test chaining expand_dims with symbolic sizes + x = xtensor("x", dims=("a",), shape=(2,)) + y = expand_dims(x, "b", size=x.sizes["a"]) # shape: (a: 2, b: 2) + z = expand_dims(y, "c", size=y.sizes["b"]) # shape: (a: 2, b: 2, c: 2) fn = xr_function([x], z) x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test) + res = fn(x_test) + expected = x_test.expand_dims({"b": 2}).expand_dims({"c": 2}) + xr_assert_allclose(res, expected) - # expand_dims with dim=None = no-op - x = xtensor("x", dims=("a",), shape=(3,)) - y = expand_dims(x, None) - fn = xr_function([x], y) + # Test bidirectional broadcasting with symbolic sizes + x = xtensor("x", dims=("a",), shape=(2,)) + y = xtensor("y", dims=("b",), shape=(3,)) + # Expand x with size from y, then add y + expanded = expand_dims(x, "b", size=y.sizes["b"]) + z = expanded + y # Should broadcast x to match y's size + fn = xr_function([x, y], z) x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test), x_test) - - # broadcast after symbolic size - size_sym = scalar("size_sym", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym) - z = y + y # triggers shape alignment - fn = xr_function([x, size_sym], z, on_unused_input="ignore") - x_test = xr_arange_like(x) - out = fn(x_test, 1) - expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") - xr_assert_allclose(out, expected) + y_test = xr_arange_like(y) + res = fn(x_test, y_test) + expected = x_test.expand_dims({"b": 3}) + y_test + xr_assert_allclose(res, expected) def test_expand_dims_errors(): From 802a536ef129982cc47e609c9cf1a4d7ca9abf05 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Mon, 9 Jun 2025 11:19:36 -0400 Subject: [PATCH 03/11] Adding support for expanding multiple dimensions --- pytensor/xtensor/shape.py | 95 ++++++++++++++++++++++++++----------- pytensor/xtensor/type.py | 21 ++++++++ tests/xtensor/test_shape.py | 48 +++++++++++++------ 3 files changed, 123 insertions(+), 41 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 80e8e1405a..737f125878 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -9,6 +9,7 @@ from pytensor.scalar import discrete_dtypes, upcast from pytensor.tensor import as_tensor, get_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.variable import TensorVariable from pytensor.xtensor.basic import XOp from pytensor.xtensor.type import as_xtensor, xtensor @@ -387,15 +388,35 @@ def squeeze(x, dim=None): class ExpandDims(XOp): """Add a new dimension to an XTensorVariable.""" - __props__ = ("dims",) + __props__ = ("dim",) def __init__(self, dim): - self.dims = dim + self.dim = dim def make_node(self, x, size): x = as_xtensor(x) + + if not isinstance(self.dim, str): + raise TypeError(f"`dim` must be a string or None, got: {type(self.dim)}") + + if self.dim in x.type.dims: + raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") + if isinstance(size, int | np.integer): + if size <= 0: + raise ValueError(f"size must be positive, got: {size}") + elif not ( + hasattr(size, "ndim") + and getattr(size, "ndim", None) == 0 # symbolic scalar + ): + raise TypeError( + f"size must be an int or scalar variable, got: {type(size)}" + ) + + # Convert size to tensor + size = as_tensor(size, ndim=0) + # Insert new dim at front - new_dims = (self.dims, *x.type.dims) + new_dims = (self.dim, *x.type.dims) # Determine shape try: @@ -415,42 +436,62 @@ def make_node(self, x, size): return Apply(self, [x, size], [out]) -def expand_dims(x, dim: str | None, size=1): - """Add a new dimension to an XTensorVariable. +def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs): + """Add one or more new dimensions to an XTensorVariable. Parameters ---------- x : XTensorVariable - Input tensor - dim : str or None - Name of new dimension. If None, returns x unchanged. - size : int or symbolic, optional - Size of the new dimension (default 1) + Input tensor. + dim : str | Sequence[str] | dict[str, int | Sequence] | None + If str or sequence of str, new dimensions with size 1. + If dict, keys are dimension names and values are either: + - int: the new size + - sequence: coordinates (length determines size) + create_index_for_new_dim : bool, default: True + (Ignored for now) Matches xarray API, reserved for future use. + **dim_kwargs : int | Sequence + Alternative to `dim` dict. Only used if `dim` is None. Returns ------- XTensorVariable - Tensor with the new dimension inserted + A tensor with additional dimensions inserted at the front. """ x = as_xtensor(x) - if dim is None: - return x # No-op + # Extract size from dim_kwargs if present + size = dim_kwargs.pop("size", 1) if dim_kwargs else 1 - if not isinstance(dim, str): - raise TypeError(f"`dim` must be a string or None, got: {type(dim)}") + if dim is None: + dim = dim_kwargs + elif dim_kwargs: + raise ValueError("Cannot specify both `dim` and `**dim_kwargs`") + + # Normalize to a dimension-size mapping + if isinstance(dim, str): + dims_dict = {dim: size} + elif isinstance(dim, Sequence) and not isinstance(dim, dict): + dims_dict = {d: 1 for d in dim} + elif isinstance(dim, dict): + dims_dict = {} + for name, val in dim.items(): + if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str): + dims_dict[name] = len(val) + elif isinstance(val, int): + dims_dict[name] = val + else: + dims_dict[name] = val # symbolic/int scalar allowed + else: + raise TypeError(f"Invalid type for `dim`: {type(dim)}") - if dim in x.type.dims: - raise ValueError(f"Dimension {dim} already exists in {x.type.dims}") + # Convert to canonical form: list of (dim_name, size) + canonical_dims: list[tuple[str, int | np.integer | TensorVariable]] = [] + for name, size in dims_dict.items(): + canonical_dims.append((name, size)) - if isinstance(size, int | np.integer): - if size <= 0: - raise ValueError(f"size must be positive, got: {size}") - elif not ( - hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar - ): - raise TypeError(f"size must be an int or scalar variable, got: {type(size)}") + # Insert each new dim at the front (reverse order preserves user intent) + for name, size in reversed(canonical_dims): + x = ExpandDims(dim=name)(x, size) - # Always convert size to a PyTensor scalar variable - size_var = as_tensor(size, ndim=0) - return ExpandDims(dim)(x, size_var) + return x diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index fd601df018..6d0ed0ab41 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -481,6 +481,27 @@ def squeeze( raise NotImplementedError("Squeeze with axis not Implemented") return px.shape.squeeze(self, dim) + def expand_dims( + self, + dim: str | None = None, + size: int | Variable = 1, + ): + """Add a new dimension to the tensor. + + Parameters + ---------- + dim : str or None + Name of new dimension. If None, returns self unchanged. + size : int or symbolic, optional + Size of the new dimension (default 1) + + Returns + ------- + XTensorVariable + Tensor with the new dimension inserted + """ + return px.shape.expand_dims(self, dim, size=size) + # ndarray methods # https://docs.xarray.dev/en/latest/api.html#id7 def clip(self, min, max): diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index d3972f5db2..d337b51096 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -303,15 +303,6 @@ def test_squeeze_explicit_dims(): fn3d = xr_function([x3], y3d) xr_assert_allclose(fn3d(x3_test), x3_test) - # Reversibility with expand_dims - x6 = xtensor("x6", dims=("a", "b", "c"), shape=(2, 1, 3)) - y6 = squeeze(x6, "b") - # First expand_dims adds at front, then transpose puts it in the right place - z6 = transpose(expand_dims(y6, "b"), "a", "b", "c") - fn6 = xr_function([x6], z6) - x6_test = xr_arange_like(x6) - xr_assert_allclose(fn6(x6_test), x6_test) - def test_squeeze_implicit_dims(): """Test squeeze with implicit dim=None (all size-1 dimensions).""" @@ -456,8 +447,8 @@ def test_expand_dims_explicit(): xr_assert_allclose(fn(xr_arange_like(x)), expected) -def test_expand_dims_implicit(): - """Test expand_dims with default or symbolic sizes and dim=None.""" +def test_expand_dims_symbolic_size(): + """Test expand_dims with symbolic sizes.""" # Symbolic size=1: same as default size_sym_1 = scalar("size_sym_1", dtype="int64") @@ -554,16 +545,16 @@ def test_expand_dims_errors(): expand_dims(x, "batch", size=0) # Invalid dim type - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="Invalid type for `dim`"): expand_dims(x, 123) # Invalid size type - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="size must be an int or scalar variable"): expand_dims(x, "new", size=[1]) # Duplicate dimension creation y = expand_dims(x, "new") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="already exists"): expand_dims(y, "new") # Symbolic size with invalid runtime value @@ -572,3 +563,32 @@ def test_expand_dims_errors(): fn = xr_function([x, size_sym], y, on_unused_input="ignore") with pytest.raises(Exception): fn(xr_arange_like(x), 0) + + +def test_expand_dims_multiple(): + """Test expanding multiple dimensions at once using a list of strings.""" + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, ["country", "state"]) + fn = xr_function([x], y) + x_xr = xr_arange_like(x) + xr_assert_allclose(fn(x_xr), x_xr.expand_dims(["country", "state"])) + + # Test with a dict of sizes + y = expand_dims(x, {"country": 2, "state": 3}) + fn = xr_function([x], y) + x_xr = xr_arange_like(x) + xr_assert_allclose(fn(x_xr), x_xr.expand_dims({"country": 2, "state": 3})) + + # Test with a mix of strings and dicts + y = expand_dims(x, ["country", "state"], size=3) + fn = xr_function([x], y) + x_xr = xr_arange_like(x) + xr_assert_allclose(fn(x_xr), x_xr.expand_dims(["country", "state"])) + + # Test with symbolic sizes in dict + size_sym_1 = scalar("size_sym_1", dtype="int64") + size_sym_2 = scalar("size_sym_2", dtype="int64") + y = expand_dims(x, {"country": size_sym_1, "state": size_sym_2}) + fn = xr_function([x, size_sym_1, size_sym_2], y, on_unused_input="ignore") + x_xr = xr_arange_like(x) + xr_assert_allclose(fn(x_xr, 2, 3), x_xr.expand_dims({"country": 2, "state": 3})) From 3b7b973776fcb1a48b84044b4637344ab77a5ff0 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 10 Jun 2025 12:21:35 -0400 Subject: [PATCH 04/11] Fixing ExpandDims rewrite --- pytensor/xtensor/rewriting/shape.py | 52 ++++++++++++++--------------- pytensor/xtensor/shape.py | 35 ++++++++++--------- tests/xtensor/test_shape.py | 11 ------ 3 files changed, 45 insertions(+), 53 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 1c83329403..a0c04c1906 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,17 +1,12 @@ from pytensor.graph import node_rewriter -from pytensor.raise_op import Assert from pytensor.tensor import ( broadcast_to, - get_scalar_constant_value, - gt, + expand_dims, join, moveaxis, specify_shape, squeeze, ) -from pytensor.tensor import ( - shape as tensor_shape, -) from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.shape import ( @@ -144,27 +139,30 @@ def local_squeeze_reshape(fgraph, node): @register_lower_xtensor @node_rewriter([ExpandDims]) def local_expand_dims_reshape(fgraph, node): - """Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify shape.""" + """Rewrite ExpandDims using tensor operations.""" x, size = node.inputs out = node.outputs[0] - # Lower to tensor.expand_dims(x, axis=0) - from pytensor.tensor import expand_dims as tensor_expand_dims - - expanded = tensor_expand_dims(tensor_from_xtensor(x), 0) - # Optionally broadcast to the correct shape if size is not 1 - from pytensor.tensor import broadcast_to - - # Ensure size is positive - expanded = Assert(msg="size must be positive")(expanded, gt(size, 0)) - # If size is not 1, broadcast - try: - static_size = get_scalar_constant_value(size) - except Exception: - static_size = None - if static_size is not None and static_size == 1: - result = expanded + + # Convert inputs to tensors + x_tensor = tensor_from_xtensor(x) + size_tensor = tensor_from_xtensor(size) + + # Get the new dimension name and position + new_axis = 0 # Always insert at front + + # Use tensor operations + if out.type.shape[0] == 1: + # Simple case: just expand with size 1 + result_tensor = expand_dims(x_tensor, new_axis) else: - # Broadcast to (size, ...) - new_shape = (size,) + tuple(tensor_shape(expanded))[1:] - result = broadcast_to(expanded, new_shape) - return [xtensor_from_tensor(result, out.type.dims)] + # First expand with size 1 + expanded = expand_dims(x_tensor, new_axis) + # Then broadcast to the requested size + result_tensor = broadcast_to(expanded, (size_tensor, *x_tensor.shape)) + + # Preserve static shape information + result_tensor = specify_shape(result_tensor, out.type.shape) + + # Convert result back to xtensor + result = xtensor_from_tensor(result_tensor, dims=out.type.dims) + return [result] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 737f125878..063077f653 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -6,6 +6,7 @@ import numpy as np from pytensor.graph import Apply +from pytensor.graph.basic import Constant from pytensor.scalar import discrete_dtypes, upcast from pytensor.tensor import as_tensor, get_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError @@ -391,43 +392,47 @@ class ExpandDims(XOp): __props__ = ("dim",) def __init__(self, dim): + if not isinstance(dim, str): + raise TypeError(f"`dim` must be a string, got: {type(self.dim)}") + self.dim = dim def make_node(self, x, size): x = as_xtensor(x) - if not isinstance(self.dim, str): - raise TypeError(f"`dim` must be a string or None, got: {type(self.dim)}") - if self.dim in x.type.dims: raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") - if isinstance(size, int | np.integer): - if size <= 0: - raise ValueError(f"size must be positive, got: {size}") - elif not ( - hasattr(size, "ndim") - and getattr(size, "ndim", None) == 0 # symbolic scalar + + # Check if size is a valid type before converting + if not ( + isinstance(size, int | np.integer) + or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0) ): raise TypeError( f"size must be an int or scalar variable, got: {type(size)}" ) - # Convert size to tensor - size = as_tensor(size, ndim=0) - - # Insert new dim at front - new_dims = (self.dim, *x.type.dims) - # Determine shape try: static_size = get_scalar_constant_value(size) except NotScalarConstantError: static_size = None + if static_size is not None: new_shape = (int(static_size), *x.type.shape) else: new_shape = (None, *x.type.shape) # symbolic size + # Convert size to tensor + size = as_xtensor(size, dims=()) + + # Check if size is a constant and validate it + if isinstance(size, Constant) and size.data < 0: + raise ValueError(f"size must be 0 or positive, got: {size.data}") + + # Insert new dim at front + new_dims = (self.dim, *x.type.dims) + out = xtensor( dtype=x.type.dtype, shape=new_shape, diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index d337b51096..64fb95ddfe 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -540,10 +540,6 @@ def test_expand_dims_errors(): with pytest.raises(ValueError, match="already exists"): expand_dims(y, "city") - # Size = 0 is invalid - with pytest.raises(ValueError, match="size must be.*positive"): - expand_dims(x, "batch", size=0) - # Invalid dim type with pytest.raises(TypeError, match="Invalid type for `dim`"): expand_dims(x, 123) @@ -557,13 +553,6 @@ def test_expand_dims_errors(): with pytest.raises(ValueError, match="already exists"): expand_dims(y, "new") - # Symbolic size with invalid runtime value - size_sym = scalar("size_sym", dtype="int64") - y = expand_dims(x, "batch", size=size_sym) - fn = xr_function([x, size_sym], y, on_unused_input="ignore") - with pytest.raises(Exception): - fn(xr_arange_like(x), 0) - def test_expand_dims_multiple(): """Test expanding multiple dimensions at once using a list of strings.""" From 38d87c68f6003d4920b4e08b653b176b3706d671 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 11 Jun 2025 14:44:42 -0400 Subject: [PATCH 05/11] Cleanup --- pytensor/xtensor/rewriting/shape.py | 10 +- pytensor/xtensor/shape.py | 26 +--- pytensor/xtensor/type.py | 26 ++-- tests/xtensor/test_shape.py | 218 +++++++--------------------- 4 files changed, 77 insertions(+), 203 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index a0c04c1906..c0b1a5fe88 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -123,7 +123,7 @@ def lower_transpose(fgraph, node): @register_lower_xtensor @node_rewriter([Squeeze]) -def local_squeeze_reshape(fgraph, node): +def lower_squeeze(fgraph, node): """Rewrite Squeeze to tensor.squeeze.""" [x] = node.inputs x_tensor = tensor_from_xtensor(x) @@ -138,7 +138,7 @@ def local_squeeze_reshape(fgraph, node): @register_lower_xtensor @node_rewriter([ExpandDims]) -def local_expand_dims_reshape(fgraph, node): +def lower_expand_dims(fgraph, node): """Rewrite ExpandDims using tensor operations.""" x, size = node.inputs out = node.outputs[0] @@ -155,10 +155,8 @@ def local_expand_dims_reshape(fgraph, node): # Simple case: just expand with size 1 result_tensor = expand_dims(x_tensor, new_axis) else: - # First expand with size 1 - expanded = expand_dims(x_tensor, new_axis) - # Then broadcast to the requested size - result_tensor = broadcast_to(expanded, (size_tensor, *x_tensor.shape)) + # Otherwise broadcast to the requested size + result_tensor = broadcast_to(x_tensor, (size_tensor, *x_tensor.shape)) # Preserve static shape information result_tensor = specify_shape(result_tensor, out.type.shape) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 063077f653..8a1e94daaf 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -442,32 +442,16 @@ def make_node(self, x, size): def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs): - """Add one or more new dimensions to an XTensorVariable. - - Parameters - ---------- - x : XTensorVariable - Input tensor. - dim : str | Sequence[str] | dict[str, int | Sequence] | None - If str or sequence of str, new dimensions with size 1. - If dict, keys are dimension names and values are either: - - int: the new size - - sequence: coordinates (length determines size) - create_index_for_new_dim : bool, default: True - (Ignored for now) Matches xarray API, reserved for future use. - **dim_kwargs : int | Sequence - Alternative to `dim` dict. Only used if `dim` is None. - - Returns - ------- - XTensorVariable - A tensor with additional dimensions inserted at the front. - """ + """Add one or more new dimensions to an XTensorVariable.""" x = as_xtensor(x) # Extract size from dim_kwargs if present size = dim_kwargs.pop("size", 1) if dim_kwargs else 1 + # xarray compatibility: error if a sequence (list/tuple) of dims and size are given + if (isinstance(dim, list | tuple)) and ("size" in locals() and size != 1): + raise ValueError("cannot specify both keyword and positional arguments") + if dim is None: dim = dim_kwargs elif dim_kwargs: diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 6d0ed0ab41..f9f12eb385 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -483,24 +483,32 @@ def squeeze( def expand_dims( self, - dim: str | None = None, - size: int | Variable = 1, + dim: str | Sequence[str] | dict[str, int | Sequence] | None = None, + create_index_for_new_dim: bool = True, + **dim_kwargs, ): - """Add a new dimension to the tensor. + """Add one or more new dimensions to the tensor. Parameters ---------- - dim : str or None - Name of new dimension. If None, returns self unchanged. - size : int or symbolic, optional - Size of the new dimension (default 1) + dim : str | Sequence[str] | dict[str, int | Sequence] | None + If str or sequence of str, new dimensions with size 1. + If dict, keys are dimension names and values are either: + - int: the new size + - sequence: coordinates (length determines size) + create_index_for_new_dim : bool, default: True + (Ignored for now) Matches xarray API, reserved for future use. + **dim_kwargs : int | Sequence + Alternative to `dim` dict. Only used if `dim` is None. Returns ------- XTensorVariable - Tensor with the new dimension inserted + A tensor with additional dimensions inserted at the front. """ - return px.shape.expand_dims(self, dim, size=size) + return px.shape.expand_dims( + self, dim, create_index_for_new_dim=create_index_for_new_dim, **dim_kwargs + ) # ndarray methods # https://docs.xarray.dev/en/latest/api.html#id7 diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 64fb95ddfe..8123de7003 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,14 +8,12 @@ from itertools import chain, combinations import numpy as np -import xarray as xr from xarray import DataArray from xarray import concat as xr_concat from pytensor.tensor import scalar from pytensor.xtensor.shape import ( concat, - expand_dims, squeeze, stack, transpose, @@ -373,162 +371,71 @@ def test_squeeze_errors(): fn2(x2_test) -def test_expand_dims_explicit(): - """Test expand_dims with explicitly named dimensions and sizes.""" - - # 1D case - x = xtensor("x", dims=("city",), shape=(3,)) - y = expand_dims(x, "country") - fn = xr_function([x], y) - x_xr = xr_arange_like(x) - xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) - - # 2D case +def test_expand_dims(): + """Test expand_dims.""" x = xtensor("x", dims=("city", "year"), shape=(2, 2)) - y = expand_dims(x, "country") - fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) - - # 3D case - x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) - y = expand_dims(x, "country") - fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) - - # Prepending various dims - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - for new_dim in ("x", "y", "z"): - y = expand_dims(x, new_dim) - assert y.type.dims == (new_dim, "a", "b") - assert y.type.shape == (1, 2, 3) - - # Explicit size=1 behaves like default - y1 = expand_dims(x, "batch", size=1) - y2 = expand_dims(x, "batch") - fn1 = xr_function([x], y1) - fn2 = xr_function([x], y2) x_test = xr_arange_like(x) - xr_assert_allclose(fn1(x_test), fn2(x_test)) - # Scalar expansion - x = xtensor("x", dims=(), shape=()) - y = expand_dims(x, "batch") - assert y.type.dims == ("batch",) - assert y.type.shape == (1,) + # Implicit size=1 + y = x.expand_dims("country") fn = xr_function([x], y) - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) - # Static size > 1: broadcast - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=4) - fn = xr_function([x], y) - expected = xr.DataArray( - np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), - dims=("batch", "a", "b"), - coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, - ) - xr_assert_allclose(fn(xr_arange_like(x)), expected) + # Explicit size=1 + y = x.expand_dims("country", size=1) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) - # Insert new dim between existing dims - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "new") - # Insert new dim between a and b: ("a", "new", "b") - y = transpose(y, "a", "new", "b") + # Explicit size > 1 + y = x.expand_dims("country", size=4) fn = xr_function([x], y) - x_test = xr_arange_like(x) - expected = x_test.expand_dims("new").transpose("a", "new", "b") - xr_assert_allclose(fn(x_test), expected) + xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 4})) - # Expand with multiple dims - x = xtensor("x", dims=(), shape=()) - y = expand_dims(expand_dims(x, "a"), "b") + # Test with multiple dimensions + y = x.expand_dims(["country", "state"]) fn = xr_function([x], y) - expected = xr_arange_like(x).expand_dims("a").expand_dims("b") - xr_assert_allclose(fn(xr_arange_like(x)), expected) + xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"])) + # Test with a dict of sizes + y = x.expand_dims({"country": 2, "state": 3}) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3})) -def test_expand_dims_symbolic_size(): - """Test expand_dims with symbolic sizes.""" - - # Symbolic size=1: same as default - size_sym_1 = scalar("size_sym_1", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym_1) - fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") - x_test = xr_arange_like(x) - xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch")) - - # Test using symbolic size from an existing dimension of the same tensor - # This verifies that expand_dims can use the size of one dimension to create another - x = xtensor(dims=("a", "b", "c")) - y = expand_dims(x, "d", size=x.sizes["b"]) + # Test with kwargs (equivalent to dict) + y = x.expand_dims(country=2, state=3) fn = xr_function([x], y) - x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5))) - res = fn(x_test) - expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b" - xr_assert_allclose(res, expected) + xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3)) - # Test broadcasting with symbolic size from a different tensor - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - other = xtensor("other", dims=("c",), shape=(4,)) - y = expand_dims(x, "batch", size=other.sizes["c"]) - fn = xr_function([x, other], y) - x_test = xr_arange_like(x) - other_test = xr_arange_like(other) - res = fn(x_test, other_test) - expected = x_test.expand_dims( - {"batch": 4} - ) # 4 is the size of dimension "c" in other - xr_assert_allclose(res, expected) + # Symbolic size=1 + size_sym_1 = scalar("size_sym_1", dtype="int64") + y = x.expand_dims("country", size=size_sym_1) + fn = xr_function([x, size_sym_1], y) + xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("country")) # Test behavior with symbolic size > 1 # NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size. # This differs from xarray's behavior where expand_dims always adds a size-1 dimension. size_sym_4 = scalar("size_sym_4", dtype="int64") - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - y = expand_dims(x, "batch", size=size_sym_4) - fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") - x_test = xr_arange_like(x) + y = x.expand_dims("country", size=size_sym_4) + fn = xr_function([x, size_sym_4], y) res = fn(x_test, 4) # Our current behavior: broadcasts to size 4 - expected = x_test.expand_dims({"batch": 4}) + expected = x_test.expand_dims({"country": 4}) xr_assert_allclose(res, expected) # xarray's behavior would be: - # expected = x_test.expand_dims("batch") # always size 1 + # expected = x_test.expand_dims("country") # always size 1 # xr_assert_allclose(res, expected) - # Test using symbolic size from a reduction operation - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) - reduced = x.sum("a") # shape: (b: 3) - y = expand_dims(x, "batch", size=reduced.sizes["b"]) - fn = xr_function([x], y) - x_test = xr_arange_like(x) - res = fn(x_test) - expected = x_test.expand_dims({"batch": 3}) # 3 is the size of dimension "b" - xr_assert_allclose(res, expected) - - # Test chaining expand_dims with symbolic sizes - x = xtensor("x", dims=("a",), shape=(2,)) - y = expand_dims(x, "b", size=x.sizes["a"]) # shape: (a: 2, b: 2) - z = expand_dims(y, "c", size=y.sizes["b"]) # shape: (a: 2, b: 2, c: 2) - fn = xr_function([x], z) - x_test = xr_arange_like(x) - res = fn(x_test) - expected = x_test.expand_dims({"b": 2}).expand_dims({"c": 2}) - xr_assert_allclose(res, expected) + # Test with symbolic sizes in dict + size_sym_1 = scalar("size_sym_1", dtype="int64") + size_sym_2 = scalar("size_sym_2", dtype="int64") + y = x.expand_dims({"country": size_sym_1, "state": size_sym_2}) + fn = xr_function([x, size_sym_1, size_sym_2], y) + xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3})) - # Test bidirectional broadcasting with symbolic sizes - x = xtensor("x", dims=("a",), shape=(2,)) - y = xtensor("y", dims=("b",), shape=(3,)) - # Expand x with size from y, then add y - expanded = expand_dims(x, "b", size=y.sizes["b"]) - z = expanded + y # Should broadcast x to match y's size - fn = xr_function([x, y], z) - x_test = xr_arange_like(x) - y_test = xr_arange_like(y) - res = fn(x_test, y_test) - expected = x_test.expand_dims({"b": 3}) + y_test - xr_assert_allclose(res, expected) + # Test with symbolic sizes in kwargs + y = x.expand_dims(country=size_sym_1, state=size_sym_2) + fn = xr_function([x, size_sym_1, size_sym_2], y) + xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3})) def test_expand_dims_errors(): @@ -536,48 +443,25 @@ def test_expand_dims_errors(): # Expanding existing dim x = xtensor("x", dims=("city",), shape=(3,)) - y = expand_dims(x, "country") + y = x.expand_dims("country") with pytest.raises(ValueError, match="already exists"): - expand_dims(y, "city") + y.expand_dims("city") # Invalid dim type with pytest.raises(TypeError, match="Invalid type for `dim`"): - expand_dims(x, 123) + x.expand_dims(123) # Invalid size type with pytest.raises(TypeError, match="size must be an int or scalar variable"): - expand_dims(x, "new", size=[1]) + x.expand_dims("new", size=[1]) # Duplicate dimension creation - y = expand_dims(x, "new") + y = x.expand_dims("new") with pytest.raises(ValueError, match="already exists"): - expand_dims(y, "new") - + y.expand_dims("new") -def test_expand_dims_multiple(): - """Test expanding multiple dimensions at once using a list of strings.""" - x = xtensor("x", dims=("city",), shape=(3,)) - y = expand_dims(x, ["country", "state"]) - fn = xr_function([x], y) - x_xr = xr_arange_like(x) - xr_assert_allclose(fn(x_xr), x_xr.expand_dims(["country", "state"])) - - # Test with a dict of sizes - y = expand_dims(x, {"country": 2, "state": 3}) - fn = xr_function([x], y) - x_xr = xr_arange_like(x) - xr_assert_allclose(fn(x_xr), x_xr.expand_dims({"country": 2, "state": 3})) - - # Test with a mix of strings and dicts - y = expand_dims(x, ["country", "state"], size=3) - fn = xr_function([x], y) - x_xr = xr_arange_like(x) - xr_assert_allclose(fn(x_xr), x_xr.expand_dims(["country", "state"])) - - # Test with symbolic sizes in dict - size_sym_1 = scalar("size_sym_1", dtype="int64") - size_sym_2 = scalar("size_sym_2", dtype="int64") - y = expand_dims(x, {"country": size_sym_1, "state": size_sym_2}) - fn = xr_function([x, size_sym_1, size_sym_2], y, on_unused_input="ignore") - x_xr = xr_arange_like(x) - xr_assert_allclose(fn(x_xr, 2, 3), x_xr.expand_dims({"country": 2, "state": 3})) + # Test for error when both positional and size are given + with pytest.raises( + ValueError, match="cannot specify both keyword and positional arguments" + ): + x.expand_dims(["country", "state"], size=3) From 30354cd16ed73ba671af2012321600a0794f2425 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 11 Jun 2025 15:07:34 -0400 Subject: [PATCH 06/11] Adding axis parameter --- pytensor/xtensor/shape.py | 26 +++++++++++++++++++++++++- pytensor/xtensor/type.py | 16 ++++++++++++++-- tests/xtensor/test_shape.py | 12 ++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 8a1e94daaf..b1eb15aefa 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -441,7 +441,7 @@ def make_node(self, x, size): return Apply(self, [x, size], [out]) -def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs): +def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwargs): """Add one or more new dimensions to an XTensorVariable.""" x = as_xtensor(x) @@ -479,8 +479,32 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs): for name, size in dims_dict.items(): canonical_dims.append((name, size)) + # Store original dimensions for later use with axis + original_dims = list(x.type.dims) + # Insert each new dim at the front (reverse order preserves user intent) for name, size in reversed(canonical_dims): x = ExpandDims(dim=name)(x, size) + # If axis is specified, transpose to put new dimensions in the right place + if axis is not None: + new_dim_names = [name for name, _ in canonical_dims] + # Wrap non-sequence axis in a list + if not isinstance(axis, Sequence): + axis = [axis] + + # xarray requires len(axis) == len(new_dim_names) + if len(axis) != len(new_dim_names): + raise ValueError("lengths of dim and axis should be identical.") + + # Insert each new dim at the specified axis position + # Start with original dims, then insert each new dim at its axis + target_dims = list(original_dims) + # axis values are relative to the result after each insertion + for insert_dim, insert_axis in sorted( + zip(new_dim_names, axis), key=lambda x: x[1] + ): + target_dims.insert(insert_axis, insert_dim) + x = transpose(x, *target_dims) + return x diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index f9f12eb385..dbddda79af 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -485,6 +485,7 @@ def expand_dims( self, dim: str | Sequence[str] | dict[str, int | Sequence] | None = None, create_index_for_new_dim: bool = True, + axis: int | None = None, **dim_kwargs, ): """Add one or more new dimensions to the tensor. @@ -497,7 +498,14 @@ def expand_dims( - int: the new size - sequence: coordinates (length determines size) create_index_for_new_dim : bool, default: True - (Ignored for now) Matches xarray API, reserved for future use. + Currently ignored. Reserved for future coordinate support. + In xarray, when True (default), creates a coordinate index for the new dimension + with values from 0 to size-1. When False, no coordinate index is created. + axis : int | None, default: None + Not implemented yet. In xarray, specifies where to insert the new dimension(s). + By default (None), new dimensions are inserted at the beginning (axis=0). + Symbolic axis is not supported yet. + Negative values count from the end. **dim_kwargs : int | Sequence Alternative to `dim` dict. Only used if `dim` is None. @@ -507,7 +515,11 @@ def expand_dims( A tensor with additional dimensions inserted at the front. """ return px.shape.expand_dims( - self, dim, create_index_for_new_dim=create_index_for_new_dim, **dim_kwargs + self, + dim, + create_index_for_new_dim=create_index_for_new_dim, + axis=axis, + **dim_kwargs, ) # ndarray methods diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 8123de7003..b07a541341 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -437,6 +437,18 @@ def test_expand_dims(): fn = xr_function([x, size_sym_1, size_sym_2], y) xr_assert_allclose(fn(x_test, 2, 3), x_test.expand_dims({"country": 2, "state": 3})) + # Test with axis parameter + y = x.expand_dims("country", axis=1) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1)) + + # Add two new dims at axis=[1, 2] + y = x.expand_dims(["country", "state"], axis=[1, 2]) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2]) + ) + def test_expand_dims_errors(): """Test error handling in expand_dims.""" From efc0fdb8f12eade28fafa333cd4fea43ad49d682 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 11 Jun 2025 15:20:46 -0400 Subject: [PATCH 07/11] Fixing axis parameter --- pytensor/xtensor/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index b1eb15aefa..456fe97fca 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -505,6 +505,6 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa zip(new_dim_names, axis), key=lambda x: x[1] ): target_dims.insert(insert_axis, insert_dim) - x = transpose(x, *target_dims) + x = Transpose(dims=tuple(target_dims))(x) return x From 3d15197d445f8b5afe8fa5789f11c5169278eabb Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 11 Jun 2025 15:25:27 -0400 Subject: [PATCH 08/11] Clean up make_node --- pytensor/xtensor/shape.py | 33 +++++++++------------------------ tests/xtensor/test_shape.py | 4 ---- 2 files changed, 9 insertions(+), 28 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 456fe97fca..6544a677ac 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -6,10 +6,10 @@ import numpy as np from pytensor.graph import Apply -from pytensor.graph.basic import Constant from pytensor.scalar import discrete_dtypes, upcast from pytensor.tensor import as_tensor, get_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.type import integer_dtypes from pytensor.tensor.variable import TensorVariable from pytensor.xtensor.basic import XOp from pytensor.xtensor.type import as_xtensor, xtensor @@ -403,32 +403,17 @@ def make_node(self, x, size): if self.dim in x.type.dims: raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}") - # Check if size is a valid type before converting - if not ( - isinstance(size, int | np.integer) - or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0) - ): - raise TypeError( - f"size must be an int or scalar variable, got: {type(size)}" - ) - - # Determine shape + size = as_xtensor(size, dims=()) + if not (size.dtype in integer_dtypes and size.ndim == 0): + raise ValueError(f"size should be an integer scalar, got {size.type}") try: - static_size = get_scalar_constant_value(size) + static_size = int(get_scalar_constant_value(size)) except NotScalarConstantError: static_size = None - - if static_size is not None: - new_shape = (int(static_size), *x.type.shape) - else: - new_shape = (None, *x.type.shape) # symbolic size - - # Convert size to tensor - size = as_xtensor(size, dims=()) - - # Check if size is a constant and validate it - if isinstance(size, Constant) and size.data < 0: - raise ValueError(f"size must be 0 or positive, got: {size.data}") + # If size is a constant, validate it + if static_size is not None and static_size < 0: + raise ValueError(f"size must be 0 or positive, got: {static_size}") + new_shape = (static_size, *x.type.shape) # Insert new dim at front new_dims = (self.dim, *x.type.dims) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index b07a541341..3d5213c3f9 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -463,10 +463,6 @@ def test_expand_dims_errors(): with pytest.raises(TypeError, match="Invalid type for `dim`"): x.expand_dims(123) - # Invalid size type - with pytest.raises(TypeError, match="size must be an int or scalar variable"): - x.expand_dims("new", size=[1]) - # Duplicate dimension creation y = x.expand_dims("new") with pytest.raises(ValueError, match="already exists"): From 6f810d3e89544084a8f089778ff7fdd059e7a510 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 12 Jun 2025 10:45:42 -0400 Subject: [PATCH 09/11] Cleanup --- pytensor/xtensor/shape.py | 42 ++++++++++++++++------------ pytensor/xtensor/type.py | 4 +-- tests/xtensor/test_shape.py | 55 +++++++++++++++---------------------- 3 files changed, 49 insertions(+), 52 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 6544a677ac..51e7c73c11 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -10,7 +10,6 @@ from pytensor.tensor import as_tensor, get_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.type import integer_dtypes -from pytensor.tensor.variable import TensorVariable from pytensor.xtensor.basic import XOp from pytensor.xtensor.type import as_xtensor, xtensor @@ -410,6 +409,7 @@ def make_node(self, x, size): static_size = int(get_scalar_constant_value(size)) except NotScalarConstantError: static_size = None + # If size is a constant, validate it if static_size is not None and static_size < 0: raise ValueError(f"size must be 0 or positive, got: {static_size}") @@ -430,6 +430,14 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa """Add one or more new dimensions to an XTensorVariable.""" x = as_xtensor(x) + # Warn if create_index_for_new_dim is used (not supported) + if not create_index_for_new_dim: + warnings.warn( + "create_index_for_new_dim=False has no effect in pytensor.xtensor", + UserWarning, + stacklevel=2, + ) + # Extract size from dim_kwargs if present size = dim_kwargs.pop("size", 1) if dim_kwargs else 1 @@ -451,6 +459,12 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa dims_dict = {} for name, val in dim.items(): if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str): + warnings.warn( + "When a sequence is provided as a dimension size, only its length is used. " + "The actual values (which would be coordinates in xarray) are ignored.", + UserWarning, + stacklevel=2, + ) dims_dict[name] = len(val) elif isinstance(val, int): dims_dict[name] = val @@ -459,21 +473,16 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa else: raise TypeError(f"Invalid type for `dim`: {type(dim)}") - # Convert to canonical form: list of (dim_name, size) - canonical_dims: list[tuple[str, int | np.integer | TensorVariable]] = [] - for name, size in dims_dict.items(): - canonical_dims.append((name, size)) - - # Store original dimensions for later use with axis + # Store original dimensions for axis handling original_dims = list(x.type.dims) # Insert each new dim at the front (reverse order preserves user intent) - for name, size in reversed(canonical_dims): + for name, size in reversed(dims_dict.items()): x = ExpandDims(dim=name)(x, size) # If axis is specified, transpose to put new dimensions in the right place if axis is not None: - new_dim_names = [name for name, _ in canonical_dims] + new_dim_names = list(dims_dict.keys()) # Wrap non-sequence axis in a list if not isinstance(axis, Sequence): axis = [axis] @@ -482,14 +491,13 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa if len(axis) != len(new_dim_names): raise ValueError("lengths of dim and axis should be identical.") - # Insert each new dim at the specified axis position - # Start with original dims, then insert each new dim at its axis - target_dims = list(original_dims) - # axis values are relative to the result after each insertion - for insert_dim, insert_axis in sorted( - zip(new_dim_names, axis), key=lambda x: x[1] - ): - target_dims.insert(insert_axis, insert_dim) + # Insert new dimensions at their specified positions + target_dims = original_dims.copy() + for name, pos in zip(new_dim_names, axis): + # Convert negative axis to positive position relative to current dims + if pos < 0: + pos = len(target_dims) + pos + 1 + target_dims.insert(pos, name) x = Transpose(dims=tuple(target_dims))(x) return x diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index dbddda79af..9fea411129 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -485,7 +485,7 @@ def expand_dims( self, dim: str | Sequence[str] | dict[str, int | Sequence] | None = None, create_index_for_new_dim: bool = True, - axis: int | None = None, + axis: int | Sequence[int] | None = None, **dim_kwargs, ): """Add one or more new dimensions to the tensor. @@ -501,7 +501,7 @@ def expand_dims( Currently ignored. Reserved for future coordinate support. In xarray, when True (default), creates a coordinate index for the new dimension with values from 0 to size-1. When False, no coordinate index is created. - axis : int | None, default: None + axis : int | Sequence[int] | None, default: None Not implemented yet. In xarray, specifies where to insert the new dimension(s). By default (None), new dimensions are inserted at the beginning (axis=0). Symbolic axis is not supported yet. diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 3d5213c3f9..e7a1e7d08d 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -381,15 +381,6 @@ def test_expand_dims(): fn = xr_function([x], y) xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) - # Explicit size=1 - y = x.expand_dims("country", size=1) - xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) - - # Explicit size > 1 - y = x.expand_dims("country", size=4) - fn = xr_function([x], y) - xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 4})) - # Test with multiple dimensions y = x.expand_dims(["country", "state"]) fn = xr_function([x], y) @@ -407,26 +398,11 @@ def test_expand_dims(): # Symbolic size=1 size_sym_1 = scalar("size_sym_1", dtype="int64") - y = x.expand_dims("country", size=size_sym_1) + y = x.expand_dims({"country": size_sym_1}) fn = xr_function([x, size_sym_1], y) - xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("country")) - - # Test behavior with symbolic size > 1 - # NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size. - # This differs from xarray's behavior where expand_dims always adds a size-1 dimension. - size_sym_4 = scalar("size_sym_4", dtype="int64") - y = x.expand_dims("country", size=size_sym_4) - fn = xr_function([x, size_sym_4], y) - res = fn(x_test, 4) - # Our current behavior: broadcasts to size 4 - expected = x_test.expand_dims({"country": 4}) - xr_assert_allclose(res, expected) - # xarray's behavior would be: - # expected = x_test.expand_dims("country") # always size 1 - # xr_assert_allclose(res, expected) + xr_assert_allclose(fn(x_test, 1), x_test.expand_dims({"country": 1})) # Test with symbolic sizes in dict - size_sym_1 = scalar("size_sym_1", dtype="int64") size_sym_2 = scalar("size_sym_2", dtype="int64") y = x.expand_dims({"country": size_sym_1, "state": size_sym_2}) fn = xr_function([x, size_sym_1, size_sym_2], y) @@ -442,13 +418,32 @@ def test_expand_dims(): fn = xr_function([x], y) xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=1)) - # Add two new dims at axis=[1, 2] + # Test with negative axis parameter + y = x.expand_dims("country", axis=-1) + fn = xr_function([x], y) + xr_assert_allclose(fn(x_test), x_test.expand_dims("country", axis=-1)) + + # Add two new dims with axis parameters y = x.expand_dims(["country", "state"], axis=[1, 2]) fn = xr_function([x], y) xr_assert_allclose( fn(x_test), x_test.expand_dims(["country", "state"], axis=[1, 2]) ) + # Add two dims with negative axis parameters + y = x.expand_dims(["country", "state"], axis=[-1, -2]) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), x_test.expand_dims(["country", "state"], axis=[-1, -2]) + ) + + # Add two dims with positive and negative axis parameters + y = x.expand_dims(["country", "state"], axis=[-2, 1]) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), x_test.expand_dims(["country", "state"], axis=[-2, 1]) + ) + def test_expand_dims_errors(): """Test error handling in expand_dims.""" @@ -467,9 +462,3 @@ def test_expand_dims_errors(): y = x.expand_dims("new") with pytest.raises(ValueError, match="already exists"): y.expand_dims("new") - - # Test for error when both positional and size are given - with pytest.raises( - ValueError, match="cannot specify both keyword and positional arguments" - ): - x.expand_dims(["country", "state"], size=3) From 841e52f9af50a37614059c020fc1af2e6483871b Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Thu, 12 Jun 2025 11:46:16 -0400 Subject: [PATCH 10/11] Cleanup --- pytensor/xtensor/shape.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 51e7c73c11..238b109610 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -430,6 +430,9 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa """Add one or more new dimensions to an XTensorVariable.""" x = as_xtensor(x) + # Store original dimensions for axis handling + original_dims = x.type.dims + # Warn if create_index_for_new_dim is used (not supported) if not create_index_for_new_dim: warnings.warn( @@ -473,27 +476,23 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa else: raise TypeError(f"Invalid type for `dim`: {type(dim)}") - # Store original dimensions for axis handling - original_dims = list(x.type.dims) - # Insert each new dim at the front (reverse order preserves user intent) for name, size in reversed(dims_dict.items()): x = ExpandDims(dim=name)(x, size) # If axis is specified, transpose to put new dimensions in the right place if axis is not None: - new_dim_names = list(dims_dict.keys()) # Wrap non-sequence axis in a list if not isinstance(axis, Sequence): axis = [axis] - # xarray requires len(axis) == len(new_dim_names) - if len(axis) != len(new_dim_names): + # require len(axis) == len(dims_dict) + if len(axis) != len(dims_dict): raise ValueError("lengths of dim and axis should be identical.") # Insert new dimensions at their specified positions - target_dims = original_dims.copy() - for name, pos in zip(new_dim_names, axis): + target_dims = list(original_dims) + for name, pos in zip(dims_dict, axis): # Convert negative axis to positive position relative to current dims if pos < 0: pos = len(target_dims) + pos + 1 From 7bb0d6e2a1c8233cfa80cb592d7424e7f3801335 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 13 Jun 2025 11:39:22 -0400 Subject: [PATCH 11/11] Improving test coverage, error handling --- pytensor/xtensor/shape.py | 31 ++++++++++++++++--------------- tests/xtensor/test_shape.py | 23 ++++++++++++++++++++--- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 238b109610..f604dc8188 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,5 +1,5 @@ import warnings -from collections.abc import Sequence +from collections.abc import Hashable, Sequence from types import EllipsisType from typing import Literal @@ -426,7 +426,7 @@ def make_node(self, x, size): return Apply(self, [x, size], [out]) -def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwargs): +def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwargs): """Add one or more new dimensions to an XTensorVariable.""" x = as_xtensor(x) @@ -434,34 +434,36 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa original_dims = x.type.dims # Warn if create_index_for_new_dim is used (not supported) - if not create_index_for_new_dim: + if create_index_for_new_dim is not None: warnings.warn( "create_index_for_new_dim=False has no effect in pytensor.xtensor", UserWarning, stacklevel=2, ) - # Extract size from dim_kwargs if present - size = dim_kwargs.pop("size", 1) if dim_kwargs else 1 - - # xarray compatibility: error if a sequence (list/tuple) of dims and size are given - if (isinstance(dim, list | tuple)) and ("size" in locals() and size != 1): - raise ValueError("cannot specify both keyword and positional arguments") - if dim is None: dim = dim_kwargs elif dim_kwargs: raise ValueError("Cannot specify both `dim` and `**dim_kwargs`") + # Check that dim is Hashable or a sequence of Hashable or dict + if not isinstance(dim, Hashable): + if not isinstance(dim, Sequence | dict): + raise TypeError(f"unhashable type: {type(dim).__name__}") + if not all(isinstance(d, Hashable) for d in dim): + raise TypeError(f"unhashable type in {type(dim).__name__}") + # Normalize to a dimension-size mapping if isinstance(dim, str): - dims_dict = {dim: size} + dims_dict = {dim: 1} elif isinstance(dim, Sequence) and not isinstance(dim, dict): dims_dict = {d: 1 for d in dim} elif isinstance(dim, dict): dims_dict = {} for name, val in dim.items(): - if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str): + if isinstance(val, str): + raise TypeError(f"Dimension size cannot be a string: {val}") + if isinstance(val, Sequence | np.ndarray): warnings.warn( "When a sequence is provided as a dimension size, only its length is used. " "The actual values (which would be coordinates in xarray) are ignored.", @@ -469,10 +471,9 @@ def expand_dims(x, dim=None, create_index_for_new_dim=True, axis=None, **dim_kwa stacklevel=2, ) dims_dict[name] = len(val) - elif isinstance(val, int): - dims_dict[name] = val else: - dims_dict[name] = val # symbolic/int scalar allowed + # should be int or symbolic scalar + dims_dict[name] = val else: raise TypeError(f"Invalid type for `dim`: {type(dim)}") diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index e7a1e7d08d..69802dcec0 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -376,7 +376,7 @@ def test_expand_dims(): x = xtensor("x", dims=("city", "year"), shape=(2, 2)) x_test = xr_arange_like(x) - # Implicit size=1 + # Implicit size 1 y = x.expand_dims("country") fn = xr_function([x], y) xr_assert_allclose(fn(x_test), x_test.expand_dims("country")) @@ -386,7 +386,7 @@ def test_expand_dims(): fn = xr_function([x], y) xr_assert_allclose(fn(x_test), x_test.expand_dims(["country", "state"])) - # Test with a dict of sizes + # Test with a dict of name-size pairs y = x.expand_dims({"country": 2, "state": 3}) fn = xr_function([x], y) xr_assert_allclose(fn(x_test), x_test.expand_dims({"country": 2, "state": 3})) @@ -396,7 +396,15 @@ def test_expand_dims(): fn = xr_function([x], y) xr_assert_allclose(fn(x_test), x_test.expand_dims(country=2, state=3)) - # Symbolic size=1 + # Test with a dict of name-coord array pairs + y = x.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}) + fn = xr_function([x], y) + xr_assert_allclose( + fn(x_test), + x_test.expand_dims({"country": np.array([1, 2]), "state": np.array([3, 4, 5])}), + ) + + # Symbolic size 1 size_sym_1 = scalar("size_sym_1", dtype="int64") y = x.expand_dims({"country": size_sym_1}) fn = xr_function([x, size_sym_1], y) @@ -462,3 +470,12 @@ def test_expand_dims_errors(): y = x.expand_dims("new") with pytest.raises(ValueError, match="already exists"): y.expand_dims("new") + + # Find out what xarray does with a numpy array as dim + # x_test = xr_arange_like(x) + # x_test.expand_dims(np.array([1, 2])) + # TypeError: unhashable type: 'numpy.ndarray' + + # Test with a numpy array as dim (not supported) + with pytest.raises(TypeError, match="unhashable type"): + y.expand_dims(np.array([1, 2]))