Closed

Commits (20):
25397c8  WIP Implement index operations for XTensorVariables (ricardoV94, May 21, 2025)
e32d865  Add diff method to XTensorVariable (ricardoV94, May 26, 2025)
5988cec  Add transpose operation for labeled tensors with ellipsis support (AllenDowney, May 27, 2025)
5936ab2  Refactor: Extract ellipsis expansion logic into helper function (AllenDowney, May 27, 2025)
6fc7b89  Fix lint errors: remove trailing whitespace from docstrings (AllenDowney, May 27, 2025)
0778cf7  Format files with ruff (AllenDowney, May 27, 2025)
c7ce0c9  Remove commented out line (AllenDowney, May 27, 2025)
bc2cbc0  Add missing_dims parameter to transpose for XTensorVariable and core,… (AllenDowney, May 28, 2025)
7bfa2b2  Add missing_dims parameter to transpose for XTensorVariable and core,… (AllenDowney, May 28, 2025)
d4f5512  Fix linting issues: remove unused Union import and use dict.fromkeys() (AllenDowney, May 28, 2025)
1ed01c4  Improve expand_ellipsis with validate parameter and update tests (AllenDowney, May 28, 2025)
4f010e0  Apply ruff-format to shape.py, type.py, and test_shape.py for consist… (AllenDowney, May 28, 2025)
f0ea583  Simplify make_node in Transpose class by combining ignore/warn cases (AllenDowney, May 28, 2025)
0125bd2  Format expand_ellipsis call for better readability (AllenDowney, May 28, 2025)
30e1a42  WIP Implement index operations for XTensorVariables (ricardoV94, May 21, 2025)
29b954a  Add diff method to XTensorVariable (ricardoV94, May 26, 2025)
a76b15e  Format and simplify expand_ellipsis; auto-fix with pre-commit; update… (AllenDowney, May 28, 2025)
af14c90  Improve expand_dims: add tests, fix reshape usage, and ensure code st… (AllenDowney, May 28, 2025)
6208092  Merge WIP changes from origin/labeled_tensors (AllenDowney, May 28, 2025)
15f4c48  Implement squeeze (AllenDowney, May 28, 2025)
1 change: 0 additions & 1 deletion pytensor/xtensor/__init__.py
@@ -7,7 +7,6 @@
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
-    XTensorType,
    as_xtensor,
    xtensor,
    xtensor_constant,
142 changes: 142 additions & 0 deletions pytensor/xtensor/indexing.py
@@ -0,0 +1,142 @@
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    else:
        # Must be integer indices, we already accounted for None and slices
        try:
            idx = as_tensor(idx)
        except TypeError:
            idx = as_xtensor(idx)
        if idx.type.dtype == "bool":
            raise NotImplementedError("Boolean indexing not yet supported")
        if idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")
        if idx.type.dtype == "bool" and idx.type.ndim == 0:
            # This can't be triggered right now, but will once we lift the boolean restriction
            raise NotImplementedError("Scalar boolean indices not supported")
    return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable, no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))
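
Note (illustration, not part of the diff): the final line above leans on Python's slice.indices to normalize None and negative bounds against a known dimension length. A minimal standalone check of that arithmetic, in plain Python with no PyTensor objects:

dim_length = 10
for slc in (slice(None), slice(2, None), slice(None, None, 2), slice(-3, None)):
    start, stop, step = slc.indices(dim_length)  # normalized, concrete bounds
    print(slc, "->", len(range(start, stop, step)))
# slice(None, None, None) -> 10
# slice(2, None, None) -> 8
# slice(None, None, 2) -> 5
# slice(-3, None, None) -> 3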


class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []
        has_unlabeled_vector_idx = False
        has_labeled_vector_idx = False
        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))
            elif isinstance(idx.type, XTensorType):
                if has_unlabeled_vector_idx:
                    raise NotImplementedError(
                        "Mixing of labeled and unlabeled vector indexing not implemented"
                    )
                has_labeled_vector_idx = True
                idx_dims = idx.type.dims
                for dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(dim)]
                    if dim in out_dims:
                        # Dim already introduced in output by a previous index
                        # Update static shape or raise if incompatible
                        out_dim_pos = out_dims.index(dim)
                        out_dim_shape = out_shape[out_dim_pos]
                        if out_dim_shape is None:
                            # We don't know the size of the dimension yet
                            out_shape[out_dim_pos] = idx_dim_shape
                        elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                            raise IndexError(
                                f"Dimension of indexers mismatch for dim {dim}"
                            )
                    else:
                        # New dimension
                        out_dims.append(dim)
                        out_shape.append(idx_dim_shape)

            else:  # TensorType
                if idx.type.ndim == 0:
                    # Scalar, dimension is dropped
                    pass
                elif idx.type.ndim == 1:
                    if has_labeled_vector_idx:
                        raise NotImplementedError(
                            "Mixing of labeled and unlabeled vector indexing not implemented"
                        )
                    has_unlabeled_vector_idx = True
                    out_dims.append(x_dims[i])
                    out_shape.append(idx.type.shape[0])
                else:
                    # Same error that xarray raises
                    raise IndexError(
                        "Unlabeled multi-dimensional array cannot be used for indexing"
                    )
        for j in range(len(idxs), x_ndim):
            # Add any unindexed dimensions
            out_dims.append(x_dims[j])
            out_shape.append(x_shape[j])

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
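
Note (illustration, not part of the diff): the dims/shape bookkeeping in make_node mirrors xarray's indexing rules, where an unlabeled 1d index keeps the indexed dimension's name while a labeled index introduces its own dims. A small sketch of those semantics, assuming numpy and xarray are installed:

import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(12).reshape(3, 4), dims=("a", "b"))
print(x[[0, 2], :].dims)  # ('a', 'b'): unlabeled vector index keeps dim "a"
idx = xr.DataArray([0, 1], dims=["new"])
print(x[idx, :].dims)     # ('new', 'b'): labeled index brings in dim "new"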
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -1,4 +1,5 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
67 changes: 67 additions & 0 deletions pytensor/xtensor/rewriting/indexing.py
@@ -0,0 +1,67 @@
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, specify_shape
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType | TensorType)
        and idx.type.ndim == 0
        and idx.type.dtype != "bool"
    ):
        return idx
    raise TypeError("Cannot convert idx to basic idx")


def _count_idx_types(idxs):
    basic, vector, xvector = 0, 0, 0
    for idx in idxs:
        if isinstance(idx.type, SliceType):
            basic += 1
        elif idx.type.ndim == 0:
            basic += 1
        elif isinstance(idx.type, TensorType):
            vector += 1
        else:
            xvector += 1
    return basic, vector, xvector


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    x, *idxs = node.inputs
    [out] = node.outputs
    x_tensor = tensor_from_xtensor(x)
    n_basic, n_vector, n_xvector = _count_idx_types(idxs)
    if n_xvector == 0 and n_vector == 0:
        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
    elif n_vector == 1 and n_xvector == 0:
        # Special case for a single vector index, where no orthogonal indexing is needed
        x_tensor_indexed = x_tensor[tuple(idxs)]
    else:
        # Not yet implemented
        return None

    # Restore any static shape information lost by the indexing
    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
    return [new_out]
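
Note (illustration, not part of the diff): the two lowered cases map onto plain numpy indexing; mixed or multiple vector indices, which would need orthogonal-indexing logic, are deliberately left unhandled. A rough numpy analogy:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
# Only scalars and slices (n_vector == n_xvector == 0): basic indexing
print(x[0, 1:3].shape)     # (2, 4)
# A single vector index (n_vector == 1, n_xvector == 0): plain numpy
# indexing already matches the xarray result, no orthogonal handling needed
print(x[:, [0, 2]].shape)  # (2, 2, 4)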
17 changes: 16 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
@@ -2,7 +2,7 @@
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
-from pytensor.xtensor.shape import Concat, Stack
+from pytensor.xtensor.shape import Concat, Stack, Transpose, expand_ellipsis


@register_xcanonicalize
@@ -70,3 +70,18 @@ def lower_concat(fgraph, node):
    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
    return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[Transpose])
def lower_transpose(fgraph, node):
    [x] = node.inputs
    # Determine the permutation of axes
    out_dims = node.op.dims
    in_dims = x.type.dims
    expanded_dims = expand_ellipsis(out_dims, in_dims)
Review comment (Member): you don't need to expand_ellipsis again; you have the ground truth in node.outputs[0].type.dims that you already computed in make_node.

    perm = tuple(in_dims.index(d) for d in expanded_dims)
    x_tensor = tensor_from_xtensor(x)
    x_tensor_transposed = x_tensor.transpose(perm)
    new_out = xtensor_from_tensor(x_tensor_transposed, dims=expanded_dims)
    return [new_out]
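
Note (illustration, not part of the diff): the permutation is just the position of each output dim within the input dims, e.g.:

in_dims = ("a", "b", "c")
expanded_dims = ("c", "a", "b")
perm = tuple(in_dims.index(d) for d in expanded_dims)
print(perm)  # (2, 0, 1): axis 2 comes first, then axis 0, then axis 1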
61 changes: 61 additions & 0 deletions pytensor/xtensor/shape.py
@@ -73,6 +73,67 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
    return y


Review comment (Member): Should this check that there's at most one ellipsis in dims and raise otherwise?

def expand_ellipsis(
    dims: tuple[str, ...], all_dims: tuple[str, ...]
) -> tuple[str, ...]:
"""Expand ellipsis in dimension permutation.

Parameters
----------
dims : tuple[str, ...]
The dimension permutation, which may contain ellipsis
all_dims : tuple[str, ...]
All available dimensions

Returns
-------
tuple[str, ...]
The expanded dimension permutation
"""
if dims == () or dims == (...,):
return tuple(reversed(all_dims))

if ... not in dims:
return dims

pre = []
post = []
found = False
for d in dims:
if d is ...:
found = True
elif not found:
pre.append(d)
else:
post.append(d)
middle = [d for d in all_dims if d not in pre + post]
return tuple(pre + middle + post)
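
Note (illustration, not part of the diff): the behavior of expand_ellipsis on plain tuples, using the definition above:

print(expand_ellipsis((), ("a", "b", "c")))          # ('c', 'b', 'a')
print(expand_ellipsis((...,), ("a", "b", "c")))      # ('c', 'b', 'a')
print(expand_ellipsis(("b", ...), ("a", "b", "c")))  # ('b', 'a', 'c')
print(expand_ellipsis((..., "a"), ("a", "b", "c")))  # ('b', 'c', 'a')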


class Transpose(XOp):
    __props__ = ("dims",)

    def __init__(self, dims: tuple[str, ...]):
Review comment (Member): I guess the type hint is wrong, because dims can include an ellipsis.

Suggested change:
-    def __init__(self, dims: tuple[str, ...]):
+    def __init__(self, dims: tuple[str | Ellipsis, ...]):

        super().__init__()
        self.dims = dims

    def make_node(self, x):
        x = as_xtensor(x)
        dims = expand_ellipsis(self.dims, x.type.dims)
        if set(dims) != set(x.type.dims):
            raise ValueError(f"Transpose dims {dims} must match {x.type.dims}")
        output = xtensor(
            dtype=x.type.dtype,
            shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims),
            dims=dims,
        )
        return Apply(self, [x], [output])


def transpose(x, *dims):
    return Transpose(dims)(x)
Review comment (ricardoV94, May 27, 2025): xarray has a missing_dims argument that we could provide here as well: https://docs.xarray.dev/en/latest/generated/xarray.DataArray.transpose.html

We already do that for isel:

    def isel(
        self,
        indexers: dict[str, Any] | None = None,
        drop: bool = False,  # Unused by PyTensor
        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
        **indexers_kwargs,
    ):
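
Note (sketch, not part of the diff): one way the suggested missing_dims could be handled before constructing the Op. The helper name and messages are hypothetical; only the three policy values mirror xarray's API:

import warnings
from typing import Literal

def _validate_transpose_dims(  # hypothetical helper, not in the PR
    dims: tuple, x_dims: tuple, missing_dims: Literal["raise", "warn", "ignore"]
) -> tuple:
    missing = [d for d in dims if d is not ... and d not in x_dims]
    if missing and missing_dims == "raise":
        raise ValueError(f"Dimensions {missing} do not exist in {x_dims}")
    if missing and missing_dims == "warn":
        warnings.warn(f"Dimensions {missing} do not exist in {x_dims}")
    # Keep only dims that exist (plus any ellipsis) and transpose on those
    return tuple(d for d in dims if d is ... or d in x_dims)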



class Concat(XOp):
__props__ = ("dim",)
