Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
expand_dims,
join,
moveaxis,
specify_shape,
Expand All @@ -10,6 +11,7 @@
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
Expand Down Expand Up @@ -132,3 +134,35 @@ def local_squeeze_reshape(fgraph, node):

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
"""Rewrite ExpandDims using tensor operations."""
x, size = node.inputs
out = node.outputs[0]

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
size_tensor = tensor_from_xtensor(size)

# Get the new dimension name and position
new_axis = 0 # Always insert at front

# Use tensor operations
if out.type.shape[0] == 1:
# Simple case: just expand with size 1
result_tensor = expand_dims(x_tensor, new_axis)
else:
# First expand with size 1
expanded = expand_dims(x_tensor, new_axis)
# Then broadcast to the requested size
result_tensor = broadcast_to(expanded, (size_tensor, *x_tensor.shape))

# Preserve static shape information
result_tensor = specify_shape(result_tensor, out.type.shape)

# Convert result back to xtensor
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
return [result]
120 changes: 120 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from types import EllipsisType
from typing import Literal

import numpy as np

from pytensor.graph import Apply
from pytensor.graph.basic import Constant
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor

Expand Down Expand Up @@ -380,3 +384,119 @@ def squeeze(x, dim=None):
return x # no-op if nothing to squeeze

return Squeeze(dims=dims)(x)


class ExpandDims(XOp):
"""Add a new dimension to an XTensorVariable."""

__props__ = ("dim",)

def __init__(self, dim):
if not isinstance(dim, str):
raise TypeError(f"`dim` must be a string, got: {type(self.dim)}")

self.dim = dim

def make_node(self, x, size):
x = as_xtensor(x)

if self.dim in x.type.dims:
raise ValueError(f"Dimension {self.dim} already exists in {x.type.dims}")

# Check if size is a valid type before converting
if not (
isinstance(size, int | np.integer)
or (hasattr(size, "ndim") and getattr(size, "ndim", None) == 0)
):
raise TypeError(
f"size must be an int or scalar variable, got: {type(size)}"
)

# Determine shape
try:
static_size = get_scalar_constant_value(size)
except NotScalarConstantError:
static_size = None

if static_size is not None:
new_shape = (int(static_size), *x.type.shape)
else:
new_shape = (None, *x.type.shape) # symbolic size

# Convert size to tensor
size = as_xtensor(size, dims=())

# Check if size is a constant and validate it
if isinstance(size, Constant) and size.data < 0:
raise ValueError(f"size must be 0 or positive, got: {size.data}")

# Insert new dim at front
new_dims = (self.dim, *x.type.dims)

out = xtensor(
dtype=x.type.dtype,
shape=new_shape,
dims=new_dims,
)
return Apply(self, [x, size], [out])


def expand_dims(x, dim=None, create_index_for_new_dim=True, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable.

Parameters
----------
x : XTensorVariable
Input tensor.
dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
(Ignored for now) Matches xarray API, reserved for future use.
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.

Returns
-------
XTensorVariable
A tensor with additional dimensions inserted at the front.
"""
x = as_xtensor(x)

# Extract size from dim_kwargs if present
size = dim_kwargs.pop("size", 1) if dim_kwargs else 1

if dim is None:
dim = dim_kwargs
elif dim_kwargs:
raise ValueError("Cannot specify both `dim` and `**dim_kwargs`")

# Normalize to a dimension-size mapping
if isinstance(dim, str):
dims_dict = {dim: size}
elif isinstance(dim, Sequence) and not isinstance(dim, dict):
dims_dict = {d: 1 for d in dim}
elif isinstance(dim, dict):
dims_dict = {}
for name, val in dim.items():
if isinstance(val, Sequence | np.ndarray) and not isinstance(val, str):
dims_dict[name] = len(val)
elif isinstance(val, int):
dims_dict[name] = val
else:
dims_dict[name] = val # symbolic/int scalar allowed
else:
raise TypeError(f"Invalid type for `dim`: {type(dim)}")

# Convert to canonical form: list of (dim_name, size)
canonical_dims: list[tuple[str, int | np.integer | TensorVariable]] = []
for name, size in dims_dict.items():
canonical_dims.append((name, size))

# Insert each new dim at the front (reverse order preserves user intent)
for name, size in reversed(canonical_dims):
x = ExpandDims(dim=name)(x, size)

return x
21 changes: 21 additions & 0 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,27 @@ def squeeze(
raise NotImplementedError("Squeeze with axis not Implemented")
return px.shape.squeeze(self, dim)

def expand_dims(
self,
dim: str | None = None,
size: int | Variable = 1,
):
"""Add a new dimension to the tensor.

Parameters
----------
dim : str or None
Name of new dimension. If None, returns self unchanged.
size : int or symbolic, optional
Size of the new dimension (default 1)

Returns
-------
XTensorVariable
Tensor with the new dimension inserted
"""
return px.shape.expand_dims(self, dim, size=size)

# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def clip(self, min, max):
Expand Down
Loading