Closed

Changes from 14 commits (20 commits total)

Commits:
25397c8  WIP Implement index operations for XTensorVariables (ricardoV94, May 21, 2025)
e32d865  Add diff method to XTensorVariable (ricardoV94, May 26, 2025)
5988cec  Add transpose operation for labeled tensors with ellipsis support (AllenDowney, May 27, 2025)
5936ab2  Refactor: Extract ellipsis expansion logic into helper function (AllenDowney, May 27, 2025)
6fc7b89  Fix lint errors: remove trailing whitespace from docstrings (AllenDowney, May 27, 2025)
0778cf7  Format files with ruff (AllenDowney, May 27, 2025)
c7ce0c9  Remove commented out line (AllenDowney, May 27, 2025)
bc2cbc0  Add missing_dims parameter to transpose for XTensorVariable and core,… (AllenDowney, May 28, 2025)
7bfa2b2  Add missing_dims parameter to transpose for XTensorVariable and core,… (AllenDowney, May 28, 2025)
d4f5512  Fix linting issues: remove unused Union import and use dict.fromkeys() (AllenDowney, May 28, 2025)
1ed01c4  Improve expand_ellipsis with validate parameter and update tests (AllenDowney, May 28, 2025)
4f010e0  Apply ruff-format to shape.py, type.py, and test_shape.py for consist… (AllenDowney, May 28, 2025)
f0ea583  Simplify make_node in Transpose class by combining ignore/warn cases (AllenDowney, May 28, 2025)
0125bd2  Format expand_ellipsis call for better readability (AllenDowney, May 28, 2025)
30e1a42  WIP Implement index operations for XTensorVariables (ricardoV94, May 21, 2025)
29b954a  Add diff method to XTensorVariable (ricardoV94, May 26, 2025)
a76b15e  Format and simplify expand_ellipsis; auto-fix with pre-commit; update… (AllenDowney, May 28, 2025)
af14c90  Improve expand_dims: add tests, fix reshape usage, and ensure code st… (AllenDowney, May 28, 2025)
6208092  Merge WIP changes from origin/labeled_tensors (AllenDowney, May 28, 2025)
15f4c48  Implement squeeze (AllenDowney, May 28, 2025)
1 change: 0 additions & 1 deletion pytensor/xtensor/__init__.py
@@ -7,7 +7,6 @@
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
XTensorType,
as_xtensor,
xtensor,
xtensor_constant,
142 changes: 142 additions & 0 deletions pytensor/xtensor/indexing.py
@@ -0,0 +1,142 @@
# HERE LIE DRAGONS
# Useful links for making sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
raise TypeError(
"XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
)
if isinstance(idx, slice):
idx = make_slice(idx)
elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
pass
else:
        # Must be integer indices; we already accounted for None and slices above
try:
idx = as_tensor(idx)
except TypeError:
idx = as_xtensor(idx)
if idx.type.dtype == "bool":
raise NotImplementedError("Boolean indexing not yet supported")
if idx.type.dtype not in discrete_dtypes:
raise TypeError("Numerical indices must be integers or boolean")
if idx.type.dtype == "bool" and idx.type.ndim == 0:
# This can't be triggered right now, but will once we lift the boolean restriction
raise NotImplementedError("Scalar boolean indices not supported")
return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
if dim_length is None:
return None
if isinstance(slc, Constant):
d = slc.data
start, stop, step = d.start, d.stop, d.step
elif slc.owner is None:
        # It's a root variable; there's no way of knowing what we're getting
return None
else:
# It's a MakeSliceOp
start, stop, step = slc.owner.inputs
if isinstance(start, Constant):
start = start.data
else:
return None
if isinstance(stop, Constant):
stop = stop.data
else:
return None
if isinstance(step, Constant):
step = step.data
else:
return None
return len(range(*slice(start, stop, step).indices(dim_length)))
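
# Illustrative aside (not part of the new module): the static length computed above
# follows Python's own slice semantics, e.g. for a dimension of length 10:
assert len(range(*slice(1, None, 2).indices(10))) == 5  # x[1::2]
assert len(range(*slice(None, None, -1).indices(10))) == 10  # x[::-1]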


class Index(XOp):
__props__ = ()

def make_node(self, x, *idxs):
x = as_xtensor(x)
idxs = [as_idx_variable(idx) for idx in idxs]

x_ndim = x.type.ndim
x_dims = x.type.dims
x_shape = x.type.shape
out_dims = []
out_shape = []
has_unlabeled_vector_idx = False
has_labeled_vector_idx = False
for i, idx in enumerate(idxs):
if i == x_ndim:
raise IndexError("Too many indices")
if isinstance(idx.type, SliceType):
out_dims.append(x_dims[i])
out_shape.append(get_static_slice_length(idx, x_shape[i]))
elif isinstance(idx.type, XTensorType):
if has_unlabeled_vector_idx:
raise NotImplementedError(
"Mixing of labeled and unlabeled vector indexing not implemented"
)
has_labeled_vector_idx = True
idx_dims = idx.type.dims
for dim in idx_dims:
idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
if dim in out_dims:
# Dim already introduced in output by a previous index
# Update static shape or raise if incompatible
out_dim_pos = out_dims.index(dim)
out_dim_shape = out_shape[out_dim_pos]
if out_dim_shape is None:
# We don't know the size of the dimension yet
out_shape[out_dim_pos] = idx_dim_shape
elif (
idx_dim_shape is not None and idx_dim_shape != out_dim_shape
):
raise IndexError(
f"Dimension of indexers mismatch for dim {dim}"
)
else:
# New dimension
out_dims.append(dim)
out_shape.append(idx_dim_shape)

else: # TensorType
if idx.type.ndim == 0:
# Scalar, dimension is dropped
pass
elif idx.type.ndim == 1:
if has_labeled_vector_idx:
raise NotImplementedError(
"Mixing of labeled and unlabeled vector indexing not implemented"
)
has_unlabeled_vector_idx = True
out_dims.append(x_dims[i])
out_shape.append(idx.type.shape[0])
else:
# Same error that xarray raises
raise IndexError(
"Unlabeled multi-dimensional array cannot be used for indexing"
)
for j in range(i + 1, x_ndim):
# Add any unindexed dimensions
out_dims.append(x_dims[j])
out_shape.append(x_shape[j])

output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])


index = Index()
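
# Illustrative aside (not part of the new module): a sketch of the outer/orthogonal
# semantics Index targets, mirroring xarray, using only constructors shown in this diff.
# Assumes this branch is installed so pytensor.xtensor.indexing is importable.
from pytensor.xtensor.indexing import index
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", shape=(3, 5), dims=("a", "b"))
assert index(x, 0).type.dims == ("b",)                # scalar index drops "a"
assert index(x, slice(1, 3)).type.dims == ("a", "b")  # slice keeps "a" (static length 2)
assert index(x, [0, 2]).type.dims == ("a", "b")       # unlabeled vector keeps "a"
idx = xtensor(dtype="int64", shape=(4,), dims=("new",))
assert index(x, idx).type.dims == ("new", "b")        # labeled vector introduces its own dim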
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -1,4 +1,5 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
67 changes: 67 additions & 0 deletions pytensor/xtensor/rewriting/indexing.py
@@ -0,0 +1,67 @@
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, specify_shape
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
if isinstance(idx.type, SliceType):
if isinstance(idx, Constant):
return idx.data
elif idx.owner:
# MakeSlice Op
# We transform NoneConsts to regular None so that basic Subtensor can be used if possible
return slice(
*[
None if isinstance(i.type, NoneTypeT) else i
for i in idx.owner.inputs
]
)
else:
return idx
if (
isinstance(idx.type, XTensorType | TensorType)
and idx.type.ndim == 0
        and idx.type.dtype != "bool"  # dtype is a string, not a Python type
):
return idx
raise TypeError("Cannot convert idx to basic idx")


def _count_idx_types(idxs):
basic, vector, xvector = 0, 0, 0
for idx in idxs:
if isinstance(idx.type, SliceType):
basic += 1
elif idx.type.ndim == 0:
basic += 1
elif isinstance(idx.type, TensorType):
vector += 1
else:
xvector += 1
return basic, vector, xvector


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
x, *idxs = node.inputs
[out] = node.outputs
x_tensor = tensor_from_xtensor(x)
n_basic, n_vector, n_xvector = _count_idx_types(idxs)
if n_xvector == 0 and n_vector == 0:
x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
elif n_vector == 1 and n_xvector == 0:
# Special case for single vector index, no orthogonal indexing
x_tensor_indexed = x_tensor[tuple(idxs)]
else:
# Not yet implemented
return None

    # Re-attach any static shape information lost during the lowering
x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
return [new_out]
18 changes: 17 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
@@ -2,7 +2,7 @@
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Concat, Stack
from pytensor.xtensor.shape import Concat, Stack, Transpose


@register_xcanonicalize
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
joined_tensor = join(concat_axis, *bcast_tensor_inputs)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
[x] = node.inputs
# Use the final dimensions that were already computed in make_node
out_dims = node.outputs[0].type.dims
in_dims = x.type.dims

# Compute the permutation based on the final dimensions
perm = tuple(in_dims.index(d) for d in out_dims)
x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = x_tensor.transpose(perm)
new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims)
return [new_out]
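
# Illustrative aside (not part of the rewrite above): the permutation logic in plain Python.
in_dims = ("a", "b", "c")
out_dims = ("c", "a", "b")
perm = tuple(in_dims.index(d) for d in out_dims)
assert perm == (2, 0, 1)  # axis 2 ("c") moves to the front, matching the labeled order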
131 changes: 131 additions & 0 deletions pytensor/xtensor/shape.py
@@ -1,4 +1,6 @@
import warnings
from collections.abc import Sequence
from typing import Literal

from pytensor import Variable
from pytensor.graph import Apply
@@ -73,6 +75,135 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return y


Review comment (Member): Should this check that there's at most one dims and raise otherwise?

def expand_ellipsis(
dims: tuple[str, ...], all_dims: tuple[str, ...], validate: bool = True
) -> tuple[str, ...]:
"""Expand ellipsis in dimension permutation.

Parameters
----------
dims : tuple[str, ...]
The dimension permutation, which may contain ellipsis
all_dims : tuple[str, ...]
All available dimensions
validate : bool, default True
Whether to check that all non-ellipsis elements in dims are valid dimension names.

Returns
-------
tuple[str, ...]
The expanded dimension permutation

Raises
------
ValueError
If more than one ellipsis is present in dims.
If any non-ellipsis element in dims is not a valid dimension name and validate is True.
"""
if dims == () or dims == (...,):
return tuple(reversed(all_dims))

if ... not in dims:
if validate:
invalid_dims = set(dims) - set(all_dims)
if invalid_dims:
raise ValueError(
f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}"
)
return dims

if sum(d is ... for d in dims) > 1:
raise ValueError("an index can only have a single ellipsis ('...')")

pre = []
post = []
found = False
for d in dims:
if d is ...:
found = True
elif not found:
pre.append(d)
else:
post.append(d)
if validate:
invalid_dims = set(pre + post) - set(all_dims)
if invalid_dims:
raise ValueError(
f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}"
)
middle = [d for d in all_dims if d not in pre + post]
return tuple(pre + middle + post)
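
# Illustrative aside (not part of the new function): expected expansions, assuming this
# module is importable as pytensor.xtensor.shape (per the file path above).
from pytensor.xtensor.shape import expand_ellipsis

assert expand_ellipsis((), ("a", "b", "c")) == ("c", "b", "a")  # empty means full reversal
assert expand_ellipsis(("b", ...), ("a", "b", "c")) == ("b", "a", "c")
assert expand_ellipsis((..., "a"), ("a", "b", "c")) == ("b", "c", "a")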


class Transpose(XOp):
__props__ = ("dims", "missing_dims")

def __init__(
self,
dims: tuple[str, ...],
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
):
super().__init__()
self.dims = dims
self.missing_dims = missing_dims

def make_node(self, x):
x = as_xtensor(x)
dims = expand_ellipsis(
self.dims, x.type.dims, validate=(self.missing_dims == "raise")
)

# Handle missing dimensions based on missing_dims setting
if self.missing_dims in ("ignore", "warn"):
if self.missing_dims == "warn":
missing = set(dims) - set(x.type.dims)
if missing:
warnings.warn(f"Dimensions {missing} do not exist in {x.type.dims}")
# Filter out dimensions that don't exist and add remaining ones
dims = tuple(d for d in dims if d in x.type.dims)
remaining_dims = tuple(d for d in x.type.dims if d not in dims)
dims = dims + remaining_dims
else: # "raise"
if set(dims) != set(x.type.dims):
raise ValueError(f"Transpose dims {dims} must match {x.type.dims}")

output = xtensor(
dtype=x.type.dtype,
shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims),
dims=dims,
)
return Apply(self, [x], [output])


def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"):
"""Transpose dimensions of the tensor.

Parameters
----------
x : XTensorVariable
Input tensor to transpose.
*dims : str
Dimensions to transpose to. Can include ellipsis (...) to represent
remaining dimensions in their original order.
missing_dims : {"raise", "warn", "ignore"}, optional
How to handle dimensions that don't exist in the input tensor:
- "raise": Raise an error if any dimensions don't exist (default)
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist

Returns
-------
XTensorVariable
Transposed tensor with reordered dimensions.

Raises
------
ValueError
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
"""
return Transpose(dims, missing_dims=missing_dims)(x)
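
# Illustrative aside (not part of the new function): a usage sketch, assuming this branch
# is installed; the xtensor(...) call mirrors the constructor usage in make_node above.
from pytensor.xtensor.shape import transpose
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", shape=(2, 3, 4), dims=("a", "b", "c"))
assert transpose(x).type.dims == ("c", "b", "a")            # no dims means full reversal
assert transpose(x, "b", ...).type.dims == ("b", "a", "c")
assert transpose(x, "b", "z", missing_dims="ignore").type.dims == ("b", "a", "c")  # "z" dropped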


class Concat(XOp):
__props__ = ("dim",)
