From 87ecb5fb223819d98643df8342a7c038a1621fea Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 6 Jun 2025 13:21:12 +0200 Subject: [PATCH 01/13] Allow lazy import of linker dispatchers --- pytensor/link/jax/dispatch/__init__.py | 2 +- pytensor/link/jax/dispatch/basic.py | 16 ++++------------ pytensor/link/jax/linker.py | 21 ++++++++++++++++++--- pytensor/link/numba/dispatch/__init__.py | 2 +- pytensor/link/numba/dispatch/basic.py | 18 ++++-------------- pytensor/link/numba/linker.py | 16 ++++++++++++++-- pytensor/link/pytorch/dispatch/__init__.py | 2 +- pytensor/link/pytorch/dispatch/basic.py | 12 ++++-------- pytensor/link/pytorch/linker.py | 16 ++++++++++++---- 9 files changed, 59 insertions(+), 46 deletions(-) diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py index 5da81bf80c..d1cd3c172e 100644 --- a/pytensor/link/jax/dispatch/__init__.py +++ b/pytensor/link/jax/dispatch/__init__.py @@ -1,5 +1,5 @@ # isort: off -from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify +from pytensor.link.jax.linker import jax_funcify, jax_typify # Load dispatch specializations import pytensor.link.jax.dispatch.blas diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index bd559ee716..cbff77261b 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -1,6 +1,5 @@ import warnings from collections.abc import Callable -from functools import singledispatch import jax import jax.numpy as jnp @@ -10,8 +9,10 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.configdefaults import config +from pytensor.graph import Op from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse +from pytensor.link.jax.linker import jax_funcify, jax_typify from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import Assert, CheckAndRaise @@ -22,15 +23,6 @@ jax.config.update("jax_enable_x64", False) -@singledispatch -def jax_typify(data, dtype=None, **kwargs): - r"""Convert instances of PyTensor `Type`\s to JAX types.""" - if dtype is None: - return data - else: - return jnp.array(data, dtype=dtype) - - @jax_typify.register(np.ndarray) def jax_typify_ndarray(data, dtype=None, **kwargs): if len(data.shape) == 0: @@ -38,8 +30,8 @@ def jax_typify_ndarray(data, dtype=None, **kwargs): return jnp.array(data, dtype=dtype) -@singledispatch -def jax_funcify(op, node=None, storage_map=None, **kwargs): +@jax_funcify.register(Op) +def jax_funcify_op(op, node=None, storage_map=None, **kwargs): """Create a JAX compatible function from an PyTensor `Op`.""" raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}") diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index eb2f4fb267..8e88645c88 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -1,4 +1,5 @@ import warnings +from functools import singledispatch from numpy.random import Generator @@ -6,6 +7,23 @@ from pytensor.link.basic import JITLinker +@singledispatch +def jax_typify(data, dtype=None, **kwargs): + r"""Convert instances of PyTensor `Type`\s to JAX types.""" + import jax.numpy as jnp + + if dtype is None: + return data + else: + return jnp.array(data, dtype=dtype) + + +@singledispatch +def jax_funcify(obj, *args, **kwargs): + """Create a JAX compatible function from an PyTensor `Op`.""" + raise NotImplementedError(f"No JAX conversion for the given type: {type(obj)}") + + class 
JAXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using JAX.""" @@ -14,7 +32,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): - from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch.shape import JAXShapeTuple from pytensor.tensor.random.type import RandomType @@ -111,8 +128,6 @@ def convert_scalar_shape_inputs( return convert_scalar_shape_inputs def create_thunk_inputs(self, storage_map): - from pytensor.link.jax.dispatch import jax_typify - thunk_inputs = [] for n in self.fgraph.inputs: sinput = storage_map[n] diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1fefb1d06d..4aa147341e 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -1,5 +1,5 @@ # isort: off -from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify +from pytensor.link.numba.linker import numba_funcify, numba_typify # Load dispatch specializations import pytensor.link.numba.dispatch.blockwise diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index f6e62ae2f8..8d4404dc0e 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -2,7 +2,6 @@ import sys import warnings from copy import copy -from functools import singledispatch from textwrap import dedent import numba @@ -21,11 +20,13 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.ops import DeepCopyOp +from pytensor.graph import Op from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.ifelse import IfElse from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType +from pytensor.link.numba.linker import numba_funcify, numba_typify from pytensor.link.utils import ( compile_function_src, fgraph_to_python, @@ -276,11 +277,6 @@ def create_arg_string(x): return args -@singledispatch -def numba_typify(data, dtype=None, **kwargs): - return data - - def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): """Create a Numba compatible function from a Pytensor `Op`.""" @@ -326,14 +322,8 @@ def perform(*inputs): return perform -@singledispatch -def numba_funcify(op, node=None, storage_map=None, **kwargs): - """Generate a numba function for a given op and apply node. - - The resulting function will usually use the `no_cpython_wrapper` - argument in numba, so it can not be called directly from python, - but only from other jit functions. 
- """ +@numba_funcify.register(Op) +def numba_funcify_op(op, node=None, storage_map=None, **kwargs): return generate_fallback_impl(op, node, storage_map, **kwargs) diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 59dc81e1b0..1380b3d745 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -1,12 +1,24 @@ +from functools import singledispatch + from pytensor.link.basic import JITLinker +@singledispatch +def numba_typify(data, dtype=None, **kwargs): + raise NotImplementedError( + f"Numba funcify not implemented for data type {type(data)}" + ) + + +@singledispatch +def numba_funcify(obj, *args, **kwargs): + raise NotImplementedError(f"Numba funcify not implemented for type {type(obj)}") + + class NumbaLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" def fgraph_convert(self, fgraph, **kwargs): - from pytensor.link.numba.dispatch import numba_funcify - return numba_funcify(fgraph, **kwargs) def jit_compile(self, fn): diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 4caabf3e03..8b48eca89a 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -1,5 +1,5 @@ # isort: off -from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify +from pytensor.link.pytorch.linker import pytorch_funcify, pytorch_typify # # Load dispatch specializations import pytensor.link.pytorch.dispatch.blas diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index c33b2e6227..48b766cc44 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -1,4 +1,3 @@ -from functools import singledispatch from types import NoneType import numpy as np @@ -10,9 +9,11 @@ from pytensor.compile.builders import OpFromGraph from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.ops import DeepCopyOp +from pytensor.graph import Op from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.ifelse import IfElse +from pytensor.link.pytorch.linker import pytorch_funcify, pytorch_typify from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise from pytensor.tensor.basic import ( @@ -27,11 +28,6 @@ ) -@singledispatch -def pytorch_typify(data, **kwargs): - raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}") - - @pytorch_typify.register(np.ndarray) @pytorch_typify.register(torch.Tensor) def pytorch_typify_tensor(data, dtype=None, **kwargs): @@ -45,8 +41,8 @@ def pytorch_typify_no_conversion_needed(data, **kwargs): return data -@singledispatch -def pytorch_funcify(op, node=None, storage_map=None, **kwargs): +@pytorch_funcify.register(Op) +def pytorch_funcify_op(op, node=None, storage_map=None, **kwargs): """Create a PyTorch compatible function from an PyTensor `Op`.""" raise NotImplementedError( f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation" diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index b8475e3157..3a19b34fcc 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,7 +1,19 @@ +from functools import singledispatch + from pytensor.link.basic import JITLinker from pytensor.link.utils import unique_name_generator 
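+ +# The singledispatch roots are defined here, rather than in pytensor.link.pytorch.dispatch, +# so that importing the linker does not import torch; the concrete implementations +# register themselves when the pytensor.link.pytorch.dispatch package is imported.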
+@singledispatch +def pytorch_typify(data, **kwargs): + raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}") + + +@singledispatch +def pytorch_funcify(obj, *args, **kwargs): + raise NotImplementedError(f"pytorch_funcify is not implemented for {type(obj)}") + + class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" @@ -10,8 +22,6 @@ def __init__(self, *args, **kwargs): self.gen_functors = [] def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): - from pytensor.link.pytorch.dispatch import pytorch_funcify - # We want to have globally unique names # across the entire pytensor graph, not # just the subgraph @@ -40,8 +50,6 @@ def jit_compile(self, fn): # flag that tend to help our graphs torch._dynamo.config.capture_dynamic_output_shape_ops = True - from pytensor.link.pytorch.dispatch import pytorch_typify - class wrapper: """ Pytorch would fail compiling our method when trying From 74ae7ecb7ca1d6c3032c6a37651b0e22d3fe6d2a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 2 Aug 2023 14:59:22 +0200 Subject: [PATCH 02/13] Basic labeled tensor functionality --- .github/workflows/test.yml | 11 + pytensor/xtensor/__init__.py | 12 + pytensor/xtensor/basic.py | 127 +++++++++ pytensor/xtensor/readme.md | 69 +++++ pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/basic.py | 62 +++++ pytensor/xtensor/rewriting/utils.py | 33 +++ pytensor/xtensor/type.py | 351 +++++++++++++++++++++++++ tests/xtensor/__init__.py | 0 tests/xtensor/util.py | 60 +++++ 10 files changed, 726 insertions(+) create mode 100644 pytensor/xtensor/__init__.py create mode 100644 pytensor/xtensor/basic.py create mode 100644 pytensor/xtensor/readme.md create mode 100644 pytensor/xtensor/rewriting/__init__.py create mode 100644 pytensor/xtensor/rewriting/basic.py create mode 100644 pytensor/xtensor/rewriting/utils.py create mode 100644 pytensor/xtensor/type.py create mode 100644 tests/xtensor/__init__.py create mode 100644 tests/xtensor/util.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..1f6cbc798f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -82,6 +82,7 @@ jobs: install-numba: [0] install-jax: [0] install-torch: [0] + install-xarray: [0] part: - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests/scan" @@ -115,6 +116,7 @@ jobs: install-numba: 0 install-jax: 0 install-torch: 0 + install-xarray: 0 - install-numba: 1 os: "ubuntu-latest" python-version: "3.10" @@ -150,6 +152,13 @@ jobs: fast-compile: 0 float32: 0 part: "tests/link/pytorch" + - install-xarray: 1 + os: "ubuntu-latest" + python-version: "3.13" + numpy-version: ">=2.0" + fast-compile: 0 + float32: 0 + part: "tests/xtensor" - os: macos-15 python-version: "3.13" numpy-version: ">=2.0" @@ -196,6 +205,7 @@ jobs: if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi + if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi pip install pytest-sphinx pip install -e ./ @@ -212,6 
+222,7 @@ jobs: INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}} + INSTALL_XARRAY: ${{ matrix.install-xarray }} OS: ${{ matrix.os}} - name: Run tests diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py new file mode 100644 index 0000000000..6c25adc05f --- /dev/null +++ b/pytensor/xtensor/__init__.py @@ -0,0 +1,12 @@ +import warnings + +import pytensor.xtensor.rewriting +from pytensor.xtensor.type import ( + XTensorType, + as_xtensor, + xtensor, + xtensor_constant, +) + + +warnings.warn("xtensor module is experimental and full of bugs") diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py new file mode 100644 index 0000000000..925fb01fb3 --- /dev/null +++ b/pytensor/xtensor/basic.py @@ -0,0 +1,127 @@ +from collections.abc import Sequence + +from pytensor.compile import ViewOp +from pytensor.graph import Apply, Op +from pytensor.link.c.op import COp +from pytensor.link.jax.linker import jax_funcify +from pytensor.link.numba.linker import numba_funcify +from pytensor.link.pytorch.linker import pytorch_funcify +from pytensor.tensor.type import TensorType +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor + + +class XOp(Op): + """A base class for XOps that shouldn't be materialized""" + + def perform(self, node, inputs, outputs): + raise NotImplementedError( + f"xtensor operation {self} must be lowered to equivalent tensor operations" + ) + + +class XTypeCastOp(COp): + """Base class for Ops that type cast between TensorType and XTensorType. + + This is like a `ViewOp` but without the expectation the input and output have identical types. + """ + + view_map = {0: [0]} + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = inputs[0] + + def c_code(self, node, nodename, inp, out, sub): + (iname,) = inp + (oname,) = out + fail = sub["fail"] + + code, _ = ViewOp.c_code_and_version[TensorType] + return code % locals() + + def c_code_cache_version(self): + _, version = ViewOp.c_code_and_version[TensorType] + return (version,) + + +@numba_funcify.register(XTypeCastOp) +def numba_funcify_XCast(op, *args, **kwargs): + from pytensor.link.numba.dispatch.basic import numba_njit + + @numba_njit + def xcast(x): + return x + + return xcast + + +@jax_funcify.register(XTypeCastOp) +@pytorch_funcify.register(XTypeCastOp) +def funcify_XCast(op, *args, **kwargs): + def xcast(x): + return x + + return xcast + + +class TensorFromXTensor(XTypeCastOp): + __props__ = () + + def make_node(self, x) -> Apply: + if not isinstance(x.type, XTensorType): + raise TypeError(f"x must be have an XTensorType, got {type(x.type)}") + output = TensorType(x.type.dtype, shape=x.type.shape)() + return Apply(self, [x], [output]) + + +tensor_from_xtensor = TensorFromXTensor() + + +class XTensorFromTensor(XTypeCastOp): + __props__ = ("dims",) + + def __init__(self, dims: Sequence[str]): + super().__init__() + self.dims = tuple(dims) + + def make_node(self, x) -> Apply: + if not isinstance(x.type, TensorType): + raise TypeError(f"x must be an TensorType type, got {type(x.type)}") + output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape) + return Apply(self, [x], [output]) + + +def xtensor_from_tensor(x, dims): + return XTensorFromTensor(dims=dims)(x) + + +class Rename(XTypeCastOp): + __props__ = ("new_dims",) + + def __init__(self, new_dims: tuple[str, ...]): + super().__init__() + self.new_dims = new_dims + + def make_node(self, x): + x = as_xtensor(x) + output = 
x.type.clone(dims=self.new_dims)() + return Apply(self, [x], [output]) + + +def rename(x, name_dict: dict[str, str] | None = None, **names: str): + if name_dict is not None: + if names: + raise ValueError("Cannot use both positional and keyword names in rename") + names = name_dict + + x = as_xtensor(x) + old_names = x.type.dims + new_names = list(old_names) + for old_name, new_name in names.items(): + try: + new_names[old_names.index(old_name)] = new_name + except ValueError: + raise ValueError( + f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}" + ) + + return Rename(tuple(new_names))(x) diff --git a/pytensor/xtensor/readme.md b/pytensor/xtensor/readme.md new file mode 100644 index 0000000000..b3511f56ad --- /dev/null +++ b/pytensor/xtensor/readme.md @@ -0,0 +1,69 @@ +# XTensor Module + +This module implements an abstraction layer on top of regular tensor operations that behaves like Xarray. + +A new type, `XTensorType`, generalizes `TensorType` with the addition of a `dims` attribute +that labels the dimensions of the tensor. + +Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects. + +The module implements several PyTensor operations (`XOp`s) whose signatures mimic those of xarray (and xarray-einstats) DataArray operations. +These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into +a regular tensor graph that can itself be evaluated as usual. + +Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. +If the existing XOps can be composed to produce the desired result, then we can use them directly. + +## Coordinates +For now, there's no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. +The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. +Coords involve aspects of Pandas/database querying and joining that are not trivially expressible in PyTensor. + +## Example + +```python +import pytensor.tensor as pt +import pytensor.xtensor as px + +a = pt.tensor("a", shape=(3,)) +b = pt.tensor("b", shape=(4,)) + +ax = px.as_xtensor(a, dims=["x"]) +bx = px.as_xtensor(b, dims=["y"]) + +zx = ax + bx +assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) + +z = zx.values +z.dprint() +# TensorFromXTensor [id A] +# └─ XElemwise{scalar_op=Add()} [id B] +# ├─ XTensorFromTensor{dims=('x',)} [id C] +# │ └─ a [id D] +# └─ XTensorFromTensor{dims=('y',)} [id E] +# └─ b [id F] +``` + +Once we compile the graph, no `XOp`s are left. 
+ +```python +import pytensor + +with pytensor.config.change_flags(optimizer_verbose=True): + fn = pytensor.function([a, b], z) + +# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None +# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) + +fn.dprint() +# Add [id A] 2 +# ├─ ExpandDims{axis=1} [id B] 1 +# │ └─ a [id C] +# └─ ExpandDims{axis=0} [id D] 0 +# └─ b [id E] +``` + + + diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py new file mode 100644 index 0000000000..6ff8b80822 --- /dev/null +++ b/pytensor/xtensor/rewriting/__init__.py @@ -0,0 +1 @@ +import pytensor.xtensor.rewriting.basic diff --git a/pytensor/xtensor/rewriting/basic.py b/pytensor/xtensor/rewriting/basic.py new file mode 100644 index 0000000000..b9043c7126 --- /dev/null +++ b/pytensor/xtensor/rewriting/basic.py @@ -0,0 +1,62 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor.basic import register_infer_shape +from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless +from pytensor.xtensor.basic import ( + Rename, + TensorFromXTensor, + XTensorFromTensor, + xtensor_from_tensor, +) +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +@register_infer_shape +@register_useless +@register_canonicalize +@register_xcanonicalize +@node_rewriter(tracks=[TensorFromXTensor]) +def useless_tensor_from_xtensor(fgraph, node): + """TensorFromXTensor(XTensorFromTensor(x)) -> x""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, XTensorFromTensor): + return [x.owner.inputs[0]] + + +@register_infer_shape +@register_useless +@register_canonicalize +@register_xcanonicalize +@node_rewriter(tracks=[XTensorFromTensor]) +def useless_xtensor_from_tensor(fgraph, node): + """XTensorFromTensor(TensorFromXTensor(x)) -> x""" + [x] = node.inputs + if x.owner and isinstance(x.owner.op, TensorFromXTensor): + return [x.owner.inputs[0]] + + +@register_xcanonicalize +@node_rewriter(tracks=[TensorFromXTensor]) +def useless_tensor_from_xtensor_of_rename(fgraph, node): + """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)""" + [renamed_x] = node.inputs + if renamed_x.owner and isinstance(renamed_x.owner.op, Rename): + [x] = renamed_x.owner.inputs + return node.op(x, return_list=True) + + +@register_xcanonicalize +@node_rewriter(tracks=[Rename]) +def useless_rename(fgraph, node): + """ + + Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims) + Rename(X, XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFrom_tensor(x, outer_dims) + """ + [renamed_x] = node.inputs + if renamed_x.owner: + if isinstance(renamed_x.owner.op, Rename): + [x] = renamed_x.owner.inputs + return [node.op(x)] + elif isinstance(renamed_x.owner.op, TensorFromXTensor): + [x] = renamed_x.owner.inputs + return [xtensor_from_tensor(x, dims=node.op.new_dims)] diff --git a/pytensor/xtensor/rewriting/utils.py 
b/pytensor/xtensor/rewriting/utils.py new file mode 100644 index 0000000000..03de2c67a9 --- /dev/null +++ b/pytensor/xtensor/rewriting/utils.py @@ -0,0 +1,33 @@ +from pytensor.compile import optdb +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase + + +optdb.register( + "xcanonicalize", + EquilibriumDB(ignore_newtrees=False), + "fast_run", + "fast_compile", + "xtensor", + position=0, +) + + +def register_xcanonicalize( + node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: RewriteDatabase | NodeRewriter): + return register_xcanonicalize( + inner_rewriter, node_rewriter, *tags, **kwargs + ) + + return register + + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + optdb["xtensor"].register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py new file mode 100644 index 0000000000..77f6c2a715 --- /dev/null +++ b/pytensor/xtensor/type.py @@ -0,0 +1,351 @@ +from pytensor.compile import ( + DeepCopyOp, + ViewOp, + register_deep_copy_op_c_code, + register_view_op_c_code, +) +from pytensor.tensor import TensorType +from pytensor.tensor.math import variadic_mul + + +try: + import xarray as xr + + XARRAY_AVAILABLE = True +except ModuleNotFoundError: + XARRAY_AVAILABLE = False + +from collections.abc import Sequence +from typing import TypeVar + +import numpy as np + +import pytensor.xtensor as px +from pytensor import _as_symbolic, config +from pytensor.graph import Apply, Constant +from pytensor.graph.basic import OptionalApplyType, Variable +from pytensor.graph.type import HasDataType, HasShape, Type +from pytensor.tensor.basic import constant as tensor_constant +from pytensor.tensor.utils import hash_from_ndarray +from pytensor.tensor.variable import TensorVariable + + +class XTensorType(Type, HasDataType, HasShape): + """A `Type` for Xtensors (Xarray-like tensors with dims).""" + + __props__ = ("dtype", "shape", "dims") + + def __init__( + self, + dtype: str | np.dtype, + *, + dims: Sequence[str], + shape: Sequence[int | None] | None = None, + name: str | None = None, + ): + if dtype == "floatX": + self.dtype = config.floatX + else: + self.dtype = np.dtype(dtype).name + + self.dims = tuple(dims) + if len(set(dims)) < len(dims): + raise ValueError(f"Dimensions must be unique. 
Found duplicates in {dims}: ") + if shape is None: + self.shape = (None,) * len(self.dims) + else: + self.shape = tuple(shape) + if len(self.shape) != len(self.dims): + raise ValueError( + f"Shape {self.shape} must have the same length as dims {self.dims}" + ) + self.ndim = len(self.dims) + self.name = name + + def clone( + self, + dtype=None, + dims=None, + shape=None, + **kwargs, + ): + if dtype is None: + dtype = self.dtype + if dims is None: + dims = self.dims + if shape is None: + shape = self.shape + return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs) + + def filter(self, value, strict=False, allow_downcast=None): + # TODO implement this + return value + + def convert_variable(self, var): + # TODO: Implement this + return var + + def __repr__(self): + return f"XTensorType({self.dtype}, {self.dims}, {self.shape})" + + def __hash__(self): + return hash((type(self), self.dtype, self.shape, self.dims)) + + def __eq__(self, other): + return ( + type(self) is type(other) + and self.dims == other.dims + and self.shape == other.shape + ) + + def is_super(self, otype): + if type(self) is not type(otype): + return False + if self.dtype != otype.dtype: + return False + if self.dims != otype.dims: + return False + if any( + s_dim_length is not None and s_dim_length != o_dim_length + for s_dim_length, o_dim_length in zip(self.shape, otype.shape) + ): + return False + return True + + +def xtensor( + name: str | None = None, + *, + dims: Sequence[str], + shape: Sequence[int | None] | None = None, + dtype: str | np.dtype = "floatX", +): + return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name) + + +_XTensorTypeType = TypeVar("_XTensorTypeType", bound=XTensorType) + + +class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): + # These can't work because Python requires native output types + def __bool__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python boolean. " + "Call `.astype(bool)` for the symbolic equivalent." + ) + + def __index__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python integer. " + "Call `.astype(int)` for the symbolic equivalent." + ) + + def __int__(self): + raise TypeError( + "XTensorVariable cannot be converted to Python integer. " + "Call `.astype(int)` for the symbolic equivalent." + ) + + def __float__(self): + raise TypeError( + "XTensorVariables cannot be converted to Python float. " + "Call `.astype(float)` for the symbolic equivalent." + ) + + def __complex__(self): + raise TypeError( + "XTensorVariables cannot be converted to Python complex number. " + "Call `.astype(complex)` for the symbolic equivalent." + ) + + # DataArray-like attributes + # https://docs.xarray.dev/en/latest/api.html#id1 + @property + def values(self) -> TensorVariable: + return px.basic.tensor_from_xtensor(self) + + # Can't provide property data because that's already taken by Constants! 
+ # data = values + + @property + def coords(self): + raise NotImplementedError("coords not implemented for XTensorVariable") + + @property + def dims(self) -> tuple[str]: + return self.type.dims + + @property + def sizes(self) -> dict[str, TensorVariable]: + return dict(zip(self.dims, self.shape)) + + @property + def as_numpy(self): + # No-op, since the underlying data is always a numpy array + return self + + # ndarray attributes + # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes + @property + def ndim(self) -> int: + return self.type.ndim + + @property + def shape(self) -> tuple[TensorVariable]: + return tuple(px.basic.tensor_from_xtensor(self).shape) + + @property + def size(self): + return variadic_mul(*self.shape) + + @property + def dtype(self): + return self.type.dtype + + # DataArray contents + # https://docs.xarray.dev/en/latest/api.html#dataarray-contents + def rename(self, new_name_or_name_dict=None, **names): + if isinstance(new_name_or_name_dict, str): + new_name = new_name_or_name_dict + name_dict = None + else: + new_name = None + name_dict = new_name_or_name_dict + new_out = px.basic.rename(self, name_dict, **names) + new_out.name = new_name + return new_out + + def item(self): + raise NotImplementedError("item not implemented for XTensorVariable") + + # Indexing + # https://docs.xarray.dev/en/latest/api.html#id2 + def __setitem__(self, key, value): + raise TypeError("XTensorVariable does not support item assignment.") + + @property + def loc(self): + raise NotImplementedError("loc not implemented for XTensorVariable") + + def sel(self, *args, **kwargs): + raise NotImplementedError("sel not implemented for XTensorVariable") + + def __getitem__(self, idx): + raise NotImplementedError("Indexing not yet implemnented") + + +class XTensorConstantSignature(tuple): + def __eq__(self, other): + if type(self) is not type(other): + return False + + (ttype0, data0), (ttype1, data1) = self, other + if ttype0 != ttype1 or data0.shape != data1.shape: + return False + + # TODO: Cash sum and use it in hash like TensorConstant does + return (data0 == data1).all() + + def __ne__(self, other): + return not self == other + + def __hash__(self): + (ttype, data) = self + return hash((type(self), ttype, data.shape)) + + def pytensor_hash(self): + _, data = self + return hash_from_ndarray(data) + + +class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]): + def __init__(self, type: _XTensorTypeType, data, name=None): + # TODO: Add checks that type and data are compatible + Constant.__init__(self, type, data, name) + + def signature(self): + return XTensorConstantSignature((self.type, self.data)) + + +XTensorType.variable_type = XTensorVariable +XTensorType.constant_type = XTensorConstant + + +def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): + if isinstance(x, xr.DataArray): + x_dims = x.dims + x_data = x.values + + if dims is not None and dims != x_dims: + raise ValueError( + f"xr.DataArray dims {x_dims} don't match requested specified {dims}. " + "Use transpose or rename" + ) + else: + x_data = tensor_constant(x).data + if dims is not None: + x_dims = dims + else: + if x_data.ndim == 0: + x_dims = () + else: + raise TypeError( + "Cannot convert TensorLike constant to XTensorConstant without specifying dims." 
+ ) + try: + return XTensorConstant( + XTensorType(dtype=x_data.dtype, dims=x_dims, shape=x_data.shape), + x_data, + name=name, + ) + except TypeError: + raise TypeError(f"Could not convert {x} to XTensorType") + + +if XARRAY_AVAILABLE: + + @_as_symbolic.register(xr.DataArray) + def as_symbolic_xarray(x, **kwargs): + return xtensor_constant(x, **kwargs) + + +def as_xtensor(x, name=None, dims: Sequence[str] | None = None): + if isinstance(x, Apply): + if len(x.outputs) != 1: + raise ValueError( + "It is ambiguous which output of a multi-output Op has to be fetched.", + x, + ) + else: + x = x.outputs[0] + + if isinstance(x, Variable): + if isinstance(x.type, XTensorType): + return x + if isinstance(x.type, TensorType): + if x.type.ndim > 0 and dims is None: + raise TypeError( + "non-scalar TensorVariable cannot be converted to XTensorVariable without dims." + ) + return px.basic.xtensor_from_tensor(x, dims) + else: + raise TypeError( + "Variable with type {x.type} cannot be converted to XTensorVariable." + ) + try: + return xtensor_constant(x, name=name, dims=dims) + except TypeError as err: + raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err + + +register_view_op_c_code( + XTensorType, + # XTensorType is just TensorType under the hood + *ViewOp.c_code_and_version[TensorType], +) + +register_deep_copy_op_c_code( + XTensorType, + # XTensorType is just TensorType under the hood + *DeepCopyOp.c_code_and_version[TensorType], +) diff --git a/tests/xtensor/__init__.py b/tests/xtensor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/xtensor/util.py b/tests/xtensor/util.py new file mode 100644 index 0000000000..81dc98a75c --- /dev/null +++ b/tests/xtensor/util.py @@ -0,0 +1,60 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +import numpy as np +from xarray import DataArray +from xarray.testing import assert_allclose + +from pytensor import function +from pytensor.xtensor.type import XTensorType + + +def xr_function(*args, **kwargs): + """Compile and wrap a PyTensor function to return xarray DataArrays.""" + fn = function(*args, **kwargs) + symbolic_outputs = fn.maker.fgraph.outputs + assert all( + isinstance(out.type, XTensorType) for out in symbolic_outputs + ), "All outputs must be xtensor" + + def xfn(*xr_inputs): + np_inputs = [ + inp.values if isinstance(inp, DataArray) else inp for inp in xr_inputs + ] + np_outputs = fn(*np_inputs) + if not isinstance(np_outputs, tuple | list): + return DataArray(np_outputs, dims=symbolic_outputs[0].type.dims) + else: + return tuple( + DataArray(res, dims=out.type.dims) + for res, out in zip(np_outputs, symbolic_outputs) + ) + + xfn.fn = fn + return xfn + + +def xr_assert_allclose(x, y, *args, **kwargs): + # Assert that two xarray DataArrays are close, ignoring coordinates + x = x.drop_vars(x.coords) + y = y.drop_vars(y.coords) + assert_allclose(x, y, *args, **kwargs) + + +def xr_arange_like(x): + return DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + + +def xr_random_like(x, rng=None): + if rng is None: + rng = np.random.default_rng() + + return DataArray( + rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims + ) From 0e3135081673136617460e4839aa37c28e411138 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 6 Jun 2025 11:49:32 +0200 Subject: [PATCH 03/13] Implement stack for XTensorVariables --- pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/shape.py | 29 
+++++++++++ pytensor/xtensor/shape.py | 71 ++++++++++++++++++++++++++ pytensor/xtensor/type.py | 5 ++ tests/xtensor/test_shape.py | 67 ++++++++++++++++++++++++ 5 files changed, 173 insertions(+) create mode 100644 pytensor/xtensor/rewriting/shape.py create mode 100644 pytensor/xtensor/shape.py create mode 100644 tests/xtensor/test_shape.py diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index 6ff8b80822..d4bb32ad66 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1 +1,2 @@ import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.shape diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py new file mode 100644 index 0000000000..b2eabb5c8e --- /dev/null +++ b/pytensor/xtensor/rewriting/shape.py @@ -0,0 +1,29 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import moveaxis +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.rewriting.basic import register_xcanonicalize +from pytensor.xtensor.shape import Stack + + +@register_xcanonicalize +@node_rewriter(tracks=[Stack]) +def lower_stack(fgraph, node): + [x] = node.inputs + batch_ndim = x.type.ndim - len(node.op.stacked_dims) + stacked_axes = [ + i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims + ] + end = tuple(range(-len(stacked_axes), 0)) + + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end) + if batch_ndim == (x.type.ndim - 1): + # This happens when we stack a "single" dimension, in this case all we need is the transpose + # Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename + final_tensor = x_tensor_transposed + else: + final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1) + final_tensor = x_tensor_transposed.reshape(final_shape) + + new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py new file mode 100644 index 0000000000..8fa0f42630 --- /dev/null +++ b/pytensor/xtensor/shape.py @@ -0,0 +1,71 @@ +from collections.abc import Sequence + +from pytensor.graph import Apply +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.type import as_xtensor, xtensor + + +class Stack(XOp): + __props__ = ("new_dim_name", "stacked_dims") + + def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]): + super().__init__() + if new_dim_name in stacked_dims: + raise ValueError( + f"Stacking dim {new_dim_name} must not be in {stacked_dims}" + ) + if not stacked_dims: + raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}") + self.new_dim_name = new_dim_name + self.stacked_dims = stacked_dims + + def make_node(self, x): + x = as_xtensor(x) + if not (set(self.stacked_dims) <= set(x.type.dims)): + raise ValueError( + f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}" + ) + if self.new_dim_name in x.type.dims: + raise ValueError( + f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}" + ) + if len(self.stacked_dims) == x.type.ndim: + batch_dims, batch_shape = (), () + else: + batch_dims, batch_shape = zip( + *( + (dim, shape) + for dim, shape in zip(x.type.dims, x.type.shape) + if dim not in self.stacked_dims + ) + ) + stack_shape = 1 + for dim, shape in zip(x.type.dims, x.type.shape): + if dim in self.stacked_dims: + if shape 
is None: + stack_shape = None + break + else: + stack_shape *= shape + output = xtensor( + dtype=x.type.dtype, + shape=(*batch_shape, stack_shape), + dims=(*batch_dims, self.new_dim_name), + ) + return Apply(self, [x], [output]) + + +def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]): + if dim is not None: + if dims: + raise ValueError("Cannot use both positional dim and keyword dims in stack") + dims = dim + + y = x + for new_dim_name, stacked_dims in dims.items(): + if isinstance(stacked_dims, str): + raise TypeError( + f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}" + ) + y = Stack(new_dim_name, tuple(stacked_dims))(y) + return y diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 77f6c2a715..3dbb1bac6c 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -233,6 +233,11 @@ def sel(self, *args, **kwargs): def __getitem__(self, idx): raise NotImplementedError("Indexing not yet implemnented") + # Reshaping and reorganizing + # https://docs.xarray.dev/en/latest/api.html#id8 + def stack(self, dim, **dims): + return px.shape.stack(self, dim, **dims) + class XTensorConstantSignature(tuple): def __eq__(self, other): diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py new file mode 100644 index 0000000000..42c8eb069d --- /dev/null +++ b/tests/xtensor/test_shape.py @@ -0,0 +1,67 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +from itertools import chain, combinations + +from pytensor.xtensor.shape import stack +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, +) + + +def powerset(iterable, min_group_size=0): + "Subsequences of the iterable from shortest to longest." 
+ # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) + s = list(iterable) + return chain.from_iterable( + combinations(s, r) for r in range(min_group_size, len(s) + 1) + ) + + +def test_stack(): + dims = ("a", "b", "c", "d") + x = xtensor("x", dims=dims, shape=(2, 3, 5, 7)) + outs = [ + stack(x, new_dim=dims_to_stack) + for dims_to_stack in powerset(dims, min_group_size=2) + ] + + fn = xr_function([x], outs) + x_test = xr_arange_like(x) + res = fn(x_test) + + expected_res = [ + x_test.stack(new_dim=dims_to_stack) + for dims_to_stack in powerset(dims, min_group_size=2) + ] + for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): + xr_assert_allclose(res_i, expected_res_i) + + +def test_stack_single_dim(): + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 5)) + out = stack(x, {"d": ["a"]}) + assert out.type.dims == ("b", "c", "d") + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = x_test.stack(d=["a"]) + xr_assert_allclose(res, expected_res) + + +def test_multiple_stacks(): + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) + out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d")) + + fn = xr_function([x], [out]) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) + xr_assert_allclose(res[0], expected_res) From d9760d594f167f5a1bca7f1ec2ec3c32052d3901 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 26 May 2025 13:14:41 +0200 Subject: [PATCH 04/13] Implement Elemwise and Blockwise operations for XTensorVariables --- pytensor/xtensor/__init__.py | 3 + pytensor/xtensor/linalg.py | 72 ++++++++++++ pytensor/xtensor/math.py | 31 +++++ pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/vectorization.py | 64 ++++++++++ pytensor/xtensor/type.py | 124 ++++++++++++++++++++ pytensor/xtensor/vectorization.py | 122 +++++++++++++++++++ tests/xtensor/test_linalg.py | 83 +++++++++++++ tests/xtensor/test_math.py | 109 +++++++++++++++++ 9 files changed, 609 insertions(+) create mode 100644 pytensor/xtensor/linalg.py create mode 100644 pytensor/xtensor/math.py create mode 100644 pytensor/xtensor/rewriting/vectorization.py create mode 100644 pytensor/xtensor/vectorization.py create mode 100644 tests/xtensor/test_linalg.py create mode 100644 tests/xtensor/test_math.py diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 6c25adc05f..d8c901c75f 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -1,6 +1,9 @@ import warnings import pytensor.xtensor.rewriting +from pytensor.xtensor import ( + linalg, +) from pytensor.xtensor.type import ( XTensorType, as_xtensor, diff --git a/pytensor/xtensor/linalg.py b/pytensor/xtensor/linalg.py new file mode 100644 index 0000000000..5bd8849e98 --- /dev/null +++ b/pytensor/xtensor/linalg.py @@ -0,0 +1,72 @@ +from collections.abc import Sequence +from typing import Literal + +from pytensor.tensor.slinalg import Cholesky, Solve +from pytensor.xtensor.type import as_xtensor +from pytensor.xtensor.vectorization import XBlockwise + + +def cholesky( + x, + lower: bool = True, + *, + check_finite: bool = False, + overwrite_a: bool = False, + on_error: Literal["raise", "nan"] = "raise", + dims: Sequence[str], +): + if len(dims) != 2: + raise ValueError(f"Cholesky needs two dims, got {len(dims)}") + + core_op = Cholesky( + lower=lower, + check_finite=check_finite, + overwrite_a=overwrite_a, + on_error=on_error, + ) + core_dims = ( + ((dims[0], dims[1]),), 
+ ((dims[0], dims[1]),), + ) + x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims) + return x_op(x) + + +def solve( + a, + b, + dims: Sequence[str], + assume_a="gen", + lower: bool = False, + check_finite: bool = False, +): + a, b = as_xtensor(a), as_xtensor(b) + if len(dims) == 2: + b_ndim = 1 + [m1_dim] = [dim for dim in dims if dim not in b.type.dims] + m2_dim = dims[0] if dims[0] != m1_dim else dims[1] + input_core_dims = ((m1_dim, m2_dim), (m2_dim,)) + output_core_dims = ((m2_dim,),) + elif len(dims) == 3: + b_ndim = 2 + [n_dim] = [dim for dim in dims if dim not in a.type.dims] + [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim] + input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim)) + output_core_dims = ( + ( + m2_dim, + n_dim, + ), + ) + else: + raise ValueError("Solve dims must have length 2 or 3") + + core_op = Solve( + b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite + ) + x_op = XBlockwise( + core_op, + signature=core_op.gufunc_signature, + core_dims=(input_core_dims, output_core_dims), + ) + return x_op(a, b) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py new file mode 100644 index 0000000000..73ee5c8c7b --- /dev/null +++ b/pytensor/xtensor/math.py @@ -0,0 +1,31 @@ +import inspect +import sys + +import pytensor.scalar as ps +from pytensor.scalar import ScalarOp +from pytensor.xtensor.vectorization import XElemwise + + +this_module = sys.modules[__name__] + + +def get_all_scalar_ops(): + """ + Find all scalar operations in the pytensor.scalar module that can be wrapped with XElemwise. + + Returns: + dict: A dictionary mapping operation names to XElemwise instances + """ + result = {} + + # Get all module members + for name, obj in inspect.getmembers(ps): + # Check if the object is a scalar op (has make_node method and is not an abstract class) + if isinstance(obj, ScalarOp): + result[name] = XElemwise(obj) + + return result + + +for name, op in get_all_scalar_ops().items(): + setattr(this_module, name, op) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index d4bb32ad66..ac74ddd73d 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,2 +1,3 @@ import pytensor.xtensor.rewriting.basic import pytensor.xtensor.rewriting.shape +import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py new file mode 100644 index 0000000000..500d6f72d9 --- /dev/null +++ b/pytensor/xtensor/rewriting/vectorization.py @@ -0,0 +1,64 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import Elemwise +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.rewriting.utils import register_xcanonicalize +from pytensor.xtensor.vectorization import XBlockwise, XElemwise + + +@register_xcanonicalize +@node_rewriter(tracks=[XElemwise]) +def lower_elemwise(fgraph, node): + out_dims = node.outputs[0].type.dims + + # Convert input XTensors to Tensors and align batch dimensions + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + order = [ + inp_dims.index(out_dim) if out_dim in inp_dims else "x" + for out_dim in out_dims + ] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) + tensor_inputs.append(tensor_inp) + + tensor_outs = Elemwise(scalar_op=node.op.scalar_op)( + *tensor_inputs, return_list=True + ) + + # 
Convert output Tensors to XTensors + new_outs = [ + xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs + ] + return new_outs + + +@register_xcanonicalize +@node_rewriter(tracks=[XBlockwise]) +def lower_blockwise(fgraph, node): + op: XBlockwise = node.op + batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0]) + batch_dims = node.outputs[0].type.dims[:batch_ndim] + + # Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end + tensor_inputs = [] + for inp, core_dims in zip(node.inputs, op.core_dims[0]): + inp_dims = inp.type.dims + # Align the batch dims of the input, and place the core dims on the right + batch_order = [ + inp_dims.index(batch_dim) if batch_dim in inp_dims else "x" + for batch_dim in batch_dims + ] + core_order = [inp_dims.index(core_dim) for core_dim in core_dims] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order) + tensor_inputs.append(tensor_inp) + + tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature) + tensor_outs = tensor_op(*tensor_inputs, return_list=True) + + # Convert output Tensors to XTensors + new_outs = [ + xtensor_from_tensor(tensor_out, dims=old_out.type.dims) + for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True) + ] + return new_outs diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 3dbb1bac6c..9005d1ff89 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -158,6 +158,109 @@ def __complex__(self): "Call `.astype(complex)` for the symbolic equivalent." ) + # Python valid overloads + def __abs__(self): + return px.math.abs(self) + + def __neg__(self): + return px.math.neg(self) + + def __lt__(self, other): + return px.math.lt(self, other) + + def __le__(self, other): + return px.math.le(self, other) + + def __gt__(self, other): + return px.math.gt(self, other) + + def __ge__(self, other): + return px.math.ge(self, other) + + def __invert__(self): + return px.math.invert(self) + + def __and__(self, other): + return px.math.and_(self, other) + + def __or__(self, other): + return px.math.or_(self, other) + + def __xor__(self, other): + return px.math.xor(self, other) + + def __rand__(self, other): + return px.math.and_(other, self) + + def __ror__(self, other): + return px.math.or_(other, self) + + def __rxor__(self, other): + return px.math.xor(other, self) + + def __add__(self, other): + return px.math.add(self, other) + + def __sub__(self, other): + return px.math.sub(self, other) + + def __mul__(self, other): + return px.math.mul(self, other) + + def __div__(self, other): + return px.math.div(self, other) + + def __pow__(self, other): + return px.math.pow(self, other) + + def __mod__(self, other): + return px.math.mod(self, other) + + def __divmod__(self, other): + return px.math.divmod(self, other) + + def __truediv__(self, other): + return px.math.true_div(self, other) + + def __floordiv__(self, other): + return px.math.floor_div(self, other) + + def __rtruediv__(self, other): + return px.math.true_div(other, self) + + def __rfloordiv__(self, other): + return px.math.floor_div(other, self) + + def __radd__(self, other): + return px.math.add(other, self) + + def __rsub__(self, other): + return px.math.sub(other, self) + + def __rmul__(self, other): + return px.math.mul(other, self) + + def __rdiv__(self, other): + return px.math.div_proxy(other, self) + + def __rmod__(self, other): + return px.math.mod(other, self) + + def __rdivmod__(self, other): + return px.math.divmod(other, 
self) + + def __rpow__(self, other): + return px.math.pow(other, self) + + def __ceil__(self): + return px.math.ceil(self) + + def __floor__(self): + return px.math.floor(self) + + def __trunc__(self): + return px.math.trunc(self) + # DataArray-like attributes # https://docs.xarray.dev/en/latest/api.html#id1 @property @@ -215,6 +318,11 @@ def rename(self, new_name_or_name_dict=None, **names): new_out.name = new_name return new_out + def copy(self, name: str | None = None): + out = px.math.identity(self) + out.name = name + return out + def item(self): raise NotImplementedError("item not implemented for XTensorVariable") @@ -233,6 +341,22 @@ def sel(self, *args, **kwargs): def __getitem__(self, idx): raise NotImplementedError("Indexing not yet implemnented") + # ndarray methods + # https://docs.xarray.dev/en/latest/api.html#id7 + def clip(self, min, max): + return px.math.clip(self, min, max) + + def conj(self): + return px.math.conj(self) + + @property + def imag(self): + return px.math.imag(self) + + @property + def real(self): + return px.math.real(self) + # Reshaping and reorganizing # https://docs.xarray.dev/en/latest/api.html#id8 def stack(self, dim, **dims): diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py new file mode 100644 index 0000000000..1fe7dd99d7 --- /dev/null +++ b/pytensor/xtensor/vectorization.py @@ -0,0 +1,122 @@ +from itertools import chain + +from pytensor import scalar as ps +from pytensor.graph import Apply, Op +from pytensor.tensor import tensor +from pytensor.tensor.utils import _parse_gufunc_signature +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.type import as_xtensor, xtensor + + +class XElemwise(XOp): + __props__ = ("scalar_op",) + + def __init__(self, scalar_op): + super().__init__() + self.scalar_op = scalar_op + + def make_node(self, *inputs): + inputs = [as_xtensor(inp) for inp in inputs] + if (self.scalar_op.nin != -1) and (len(inputs) != self.scalar_op.nin): + raise ValueError( + f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}" + ) + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + elif dim_length is not None: + # Check for conflicting shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError(f"Dimension {dim} has conflicting shapes") + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + if dims_and_shape: + output_dims, output_shape = zip(*dims_and_shape.items()) + else: + output_dims, output_shape = (), () + + dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs] + output_dtypes = [ + out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs + ] + outputs = [ + xtensor(dtype=output_dtype, dims=output_dims, shape=output_shape) + for output_dtype in output_dtypes + ] + return Apply(self, inputs, outputs) + + +class XBlockwise(XOp): + __props__ = ("core_op", "signature", "core_dims") + + def __init__( + self, + core_op: Op, + signature: str, + core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]], + ): + super().__init__() + self.core_op = core_op + self.signature = signature + self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self.core_dims = core_dims + + def make_node(self, *inputs): + inputs = [as_xtensor(i) for i in inputs] + if len(inputs) != len(self.inputs_sig): + raise ValueError( + 
f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}" + ) + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + elif dim_length is not None: + # Check for conflicting shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError(f"Dimension {dim} has conflicting shapes") + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + core_inputs_dims, core_outputs_dims = self.core_dims + # TODO: Avoid intermediate dict + core_dims = set(chain.from_iterable(core_inputs_dims)) + batched_dims_and_shape = { + k: v for k, v in dims_and_shape.items() if k not in core_dims + } + batch_dims, batch_shape = zip(*batched_dims_and_shape.items()) + + dummy_core_inputs = [] + for inp, core_inp_dims in zip(inputs, core_inputs_dims): + try: + core_static_shape = [ + inp.type.shape[inp.type.dims.index(d)] for d in core_inp_dims + ] + except IndexError: + raise ValueError( + f"At least one core dim={core_inp_dims} missing from input {inp} with dims={inp.type.dims}" + ) + dummy_core_inputs.append( + tensor(dtype=inp.type.dtype, shape=core_static_shape) + ) + core_node = self.core_op.make_node(*dummy_core_inputs) + + outputs = [ + xtensor( + dtype=core_out.type.dtype, + shape=batch_shape + core_out.type.shape, + dims=batch_dims + core_out_dims, + ) + for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims) + ] + return Apply(self, inputs, outputs) diff --git a/tests/xtensor/test_linalg.py b/tests/xtensor/test_linalg.py new file mode 100644 index 0000000000..407867070d --- /dev/null +++ b/tests/xtensor/test_linalg.py @@ -0,0 +1,83 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") +pytest.importorskip("xarray_einstats") + +import numpy as np +from xarray import DataArray +from xarray_einstats.linalg import ( + cholesky as xr_cholesky, +) +from xarray_einstats.linalg import ( + solve as xr_solve, +) + +from pytensor import function +from pytensor.xtensor.linalg import cholesky, solve +from pytensor.xtensor.type import xtensor + + +def test_cholesky(): + x = xtensor("x", dims=("a", "batch", "b"), shape=(4, 3, 4)) + y = cholesky(x, dims=["b", "a"]) + assert y.type.dims == ("batch", "b", "a") + assert y.type.shape == (3, 4, 4) + + fn = function([x], y) + rng = np.random.default_rng(25) + x_ = rng.random(size=(4, 3, 3)) + x_ = x_ @ x_.mT + x_test = DataArray(x_.transpose(1, 0, 2), dims=x.type.dims) + np.testing.assert_allclose( + fn(x_test.values), + xr_cholesky(x_test, dims=["b", "a"]).values, + ) + + +def test_solve_vector_b(): + a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) + b = xtensor("b", dims=("city", "planet"), shape=(None, 2)) + x = solve(a, b, dims=["country", "city"]) + assert x.type.dims == ("galaxy", "planet", "city") + assert x.type.shape == ( + 1, + 2, + None, + ) # Core Solve doesn't make use of the fact A must be square in the static shape + + fn = function([a, b], x) + + rng = np.random.default_rng(25) + a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) + b_test = DataArray(rng.random(size=(4, 2)), dims=b.type.dims) + + np.testing.assert_allclose( + fn(a_test.values, b_test.values), + xr_solve(a_test, b_test, dims=["country", "city"]).values, + ) + + +def test_solve_matrix_b(): + a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1)) + b = xtensor("b", dims=("district", "city", 
"planet"), shape=(5, None, 2)) + x = solve(a, b, dims=["country", "city", "district"]) + assert x.type.dims == ("galaxy", "planet", "city", "district") + assert x.type.shape == ( + 1, + 2, + None, + 5, + ) # Core Solve doesn't make use of the fact A must be square in the static shape + + fn = function([a, b], x) + + rng = np.random.default_rng(25) + a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims) + b_test = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims) + + np.testing.assert_allclose( + fn(a_test.values, b_test.values), + xr_solve(a_test, b_test, dims=["country", "city", "district"]).values, + ) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py new file mode 100644 index 0000000000..cc1461afd0 --- /dev/null +++ b/tests/xtensor/test_math.py @@ -0,0 +1,109 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") # + +import numpy as np +from xarray import DataArray + +from pytensor import function +from pytensor.xtensor.basic import rename +from pytensor.xtensor.math import add, exp +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import xr_assert_allclose, xr_function + + +def test_scalar_case(): + x = xtensor("x", dims=(), shape=()) + y = xtensor("y", dims=(), shape=()) + out = add(x, y) + + fn = function([x, y], out) + + x_test = DataArray(2.0, dims=()) + y_test = DataArray(3.0, dims=()) + np.testing.assert_allclose(fn(x_test.values, y_test.values), 5.0) + + +def test_dimension_alignment(): + x = xtensor("x", dims=("city", "country", "planet"), shape=(2, 3, 4)) + y = xtensor( + "y", + dims=("galaxy", "country", "city"), + shape=(5, 3, 2), + ) + z = xtensor("z", dims=("universe",), shape=(1,)) + out = add(x, y, z) + assert out.type.dims == ("city", "country", "planet", "galaxy", "universe") + + fn = function([x, y, z], out) + + rng = np.random.default_rng(41) + test_x, test_y, test_z = ( + DataArray(rng.normal(size=inp.type.shape), dims=inp.type.dims) + for inp in [x, y, z] + ) + np.testing.assert_allclose( + fn(test_x.values, test_y.values, test_z.values), + (test_x + test_y + test_z).values, + ) + + +def test_renamed_dimension_alignment(): + x = xtensor("x", dims=("a", "b1", "b2"), shape=(2, 3, 3)) + y = rename(x, b1="b2", b2="b1") + z = rename(x, b2="b3") + assert y.type.dims == ("a", "b2", "b1") + assert z.type.dims == ("a", "b1", "b3") + + out1 = add(x, x) # self addition + assert out1.type.dims == ("a", "b1", "b2") + out2 = add(x, y) # transposed addition + assert out2.type.dims == ("a", "b1", "b2") + out3 = add(x, z) # outer addition + assert out3.type.dims == ("a", "b1", "b2", "b3") + + fn = xr_function([x], [out1, out2, out3]) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + results = fn(x_test) + expected_results = [ + x_test + x_test, + x_test + x_test.rename(b1="b2", b2="b1"), + x_test + x_test.rename(b2="b3"), + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +def test_chained_operations(): + x = xtensor("x", dims=("city",), shape=(None,)) + y = xtensor("y", dims=("country",), shape=(4,)) + z = add(exp(x), exp(y)) + assert z.type.dims == ("city", "country") + assert z.type.shape == (None, 4) + + fn = function([x, y], z) + + x_test = DataArray(np.zeros(3), dims="city") + y_test = DataArray(np.ones(4), dims="country") + + np.testing.assert_allclose( + fn(x_test.values, y_test.values), + (np.exp(x_test) + np.exp(y_test)).values, + ) + + +def 
test_multiple_constant(): + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + out = exp(x * 2) + 2 + + fn = function([x], out) + + x_test = np.zeros((2, 3), dtype=x.type.dtype) + res = fn(x_test) + expected_res = np.exp(x_test * 2) + 2 + np.testing.assert_allclose(res, expected_res) From 2a31ecb0f5d9a037c7b131882bcb10070afb1874 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 6 Jun 2025 11:14:11 +0200 Subject: [PATCH 05/13] Implement cast for XTensorVariables --- pytensor/xtensor/math.py | 28 ++++++++++++++++++++++++++++ pytensor/xtensor/type.py | 3 +++ tests/xtensor/test_math.py | 20 +++++++++++++++++++- 3 files changed, 50 insertions(+), 1 deletion(-) diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 73ee5c8c7b..0e9d66f232 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -1,8 +1,13 @@ import inspect import sys +import numpy as np + import pytensor.scalar as ps +from pytensor import config from pytensor.scalar import ScalarOp +from pytensor.scalar.basic import _cast_mapping +from pytensor.xtensor.basic import as_xtensor from pytensor.xtensor.vectorization import XElemwise @@ -29,3 +34,26 @@ def get_all_scalar_ops(): for name, op in get_all_scalar_ops().items(): setattr(this_module, name, op) + + +_xelemwise_cast_op: dict[str, XElemwise] = {} + + +def cast(x, dtype): + if dtype == "floatX": + dtype = config.floatX + else: + dtype = np.dtype(dtype).name + + x = as_xtensor(x) + if x.type.dtype == dtype: + return x + if x.type.dtype.startswith("complex") and not dtype.startswith("complex"): + raise TypeError( + "Casting from complex to real is ambiguous: consider" + " real(), imag(), angle() or abs()" + ) + + if dtype not in _xelemwise_cast_op: + _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype]) + return _xelemwise_cast_op[dtype](x) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 9005d1ff89..3ebe1e6771 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -323,6 +323,9 @@ def copy(self, name: str | None = None): out.name = name return out + def astype(self, dtype): + return px.math.cast(self, dtype) + def item(self): raise NotImplementedError("item not implemented for XTensorVariable") diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index cc1461afd0..f8a601f4a9 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -11,7 +11,7 @@ from pytensor.xtensor.basic import rename from pytensor.xtensor.math import add, exp from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_assert_allclose, xr_function +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function def test_scalar_case(): @@ -107,3 +107,21 @@ def test_multiple_constant(): res = fn(x_test) expected_res = np.exp(x_test * 2) + 2 np.testing.assert_allclose(res, expected_res) + + +def test_cast(): + x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32") + yf64 = x.astype("float64") + yi16 = x.astype("int16") + ybool = x.astype("bool") + + fn = xr_function([x], [yf64, yi16, ybool]) + x_test = xr_arange_like(x) + res_f64, res_i16, res_bool = fn(x_test) + xr_assert_allclose(res_f64, x_test.astype("float64")) + xr_assert_allclose(res_i16, x_test.astype("int16")) + xr_assert_allclose(res_bool, x_test.astype("bool")) + + yc64 = x.astype("complex64") + with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): + yc64.astype("float64") From b35b056c677560c0f9e9704f882b76b63b07f210 Mon Sep 17 00:00:00 2001 From: ricardoV94 
Date: Sun, 25 May 2025 22:23:10 +0200 Subject: [PATCH 06/13] Implement reduction operations for XTensorVariables --- pytensor/tensor/extra_ops.py | 18 ---- pytensor/xtensor/__init__.py | 1 + pytensor/xtensor/reduction.py | 124 ++++++++++++++++++++++++ pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/reduction.py | 72 ++++++++++++++ pytensor/xtensor/special.py | 7 ++ pytensor/xtensor/type.py | 35 +++++++ tests/xtensor/test_reduction.py | 27 ++++++ 8 files changed, 267 insertions(+), 18 deletions(-) create mode 100644 pytensor/xtensor/reduction.py create mode 100644 pytensor/xtensor/rewriting/reduction.py create mode 100644 pytensor/xtensor/special.py create mode 100644 tests/xtensor/test_reduction.py diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 7a1bc75b0b..dc92238010 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -473,24 +473,6 @@ def cumprod(x, axis=None): return CumOp(axis=axis, mode="mul")(x) -class CumsumOp(Op): - __props__ = ("axis",) - - def __new__(typ, *args, **kwargs): - obj = object.__new__(CumOp, *args, **kwargs) - obj.mode = "add" - return obj - - -class CumprodOp(Op): - __props__ = ("axis",) - - def __new__(typ, *args, **kwargs): - obj = object.__new__(CumOp, *args, **kwargs) - obj.mode = "mul" - return obj - - def diff(x, n=1, axis=-1): """Calculate the `n`-th order discrete difference along the given `axis`. diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index d8c901c75f..78406810be 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -3,6 +3,7 @@ import pytensor.xtensor.rewriting from pytensor.xtensor import ( linalg, + special, ) from pytensor.xtensor.type import ( XTensorType, diff --git a/pytensor/xtensor/reduction.py b/pytensor/xtensor/reduction.py new file mode 100644 index 0000000000..815cac0020 --- /dev/null +++ b/pytensor/xtensor/reduction.py @@ -0,0 +1,124 @@ +from collections.abc import Sequence +from functools import partial +from types import EllipsisType + +import pytensor.scalar as ps +from pytensor.graph.basic import Apply, Variable +from pytensor.tensor.math import variadic_mul +from pytensor.xtensor.basic import XOp +from pytensor.xtensor.math import neq, sqrt +from pytensor.xtensor.math import sqr as square +from pytensor.xtensor.type import as_xtensor, xtensor + + +REDUCE_DIM = str | Sequence[str] | EllipsisType | None + + +class XReduce(XOp): + __slots__ = ("binary_op", "dims") + + def __init__(self, binary_op, dims: Sequence[str]): + super().__init__() + self.binary_op = binary_op + # Order of reduce dims doens't change the behavior of the Op + self.dims = tuple(sorted(dims)) + + def make_node(self, x: Variable) -> Apply: + x = as_xtensor(x) + x_dims = x.type.dims + x_dims_set = set(x_dims) + reduce_dims_set = set(self.dims) + if x_dims_set == reduce_dims_set: + out_dims, out_shape = [], [] + else: + if not reduce_dims_set.issubset(x_dims_set): + raise ValueError( + f"Reduced dims {self.dims} not found in array dimensions {x_dims}." 
+ ) + out_dims, out_shape = zip( + *[ + (d, s) + for d, s in zip(x_dims, x.type.shape) + if d not in reduce_dims_set + ] + ) + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x], [output]) + + +def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]: + if isinstance(dim, str): + return (dim,) + elif dim is None or dim is Ellipsis: + x = as_xtensor(x) + return x.type.dims + return dim + + +def reduce(x, dim: REDUCE_DIM = None, *, binary_op): + dims = _process_user_dims(x, dim) + return XReduce(binary_op=binary_op, dims=dims)(x) + + +sum = partial(reduce, binary_op=ps.add) +prod = partial(reduce, binary_op=ps.mul) +max = partial(reduce, binary_op=ps.scalar_maximum) +min = partial(reduce, binary_op=ps.scalar_minimum) + + +def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op): + x = as_xtensor(x) + if x.type.dtype != "bool": + x = neq(x, 0) + return reduce(x, dim=dim, binary_op=binary_op) + + +all = partial(bool_reduce, binary_op=ps.and_) +any = partial(bool_reduce, binary_op=ps.or_) + + +def _infer_reduced_size(original_var, reduced_var): + reduced_dims = reduced_var.dims + return variadic_mul( + *[size for dim, size in original_var.sizes if dim not in reduced_dims] + ) + + +def mean(x, dim: REDUCE_DIM): + x = as_xtensor(x) + sum_x = sum(x, dim) + n = _infer_reduced_size(x, sum_x) + return sum_x / n + + +def var(x, dim: REDUCE_DIM, *, ddof: int = 0): + x = as_xtensor(x) + x_mean = mean(x, dim) + n = _infer_reduced_size(x, x_mean) + return square(x - x_mean) / (n - ddof) + + +def std(x, dim: REDUCE_DIM, *, ddof: int = 0): + return sqrt(var(x, dim, ddof=ddof)) + + +class XCumReduce(XOp): + __props__ = ("binary_op", "dims") + + def __init__(self, binary_op, dims: Sequence[str]): + self.binary_op = binary_op + self.dims = tuple(sorted(dims)) # Order doesn't matter + + def make_node(self, x: Variable) -> Apply: + x = as_xtensor(x) + out = x.type() + return Apply(self, [x], [out]) + + +def cumreduce(x, dim: REDUCE_DIM, *, binary_op): + dims = _process_user_dims(x, dim) + return XCumReduce(dims=dims, binary_op=binary_op)(x) + + +cumsum = partial(cumreduce, binary_op=ps.add) +cumprod = partial(cumreduce, binary_op=ps.mul) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index ac74ddd73d..7ce55b9256 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,3 +1,4 @@ import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.reduction import pytensor.xtensor.rewriting.shape import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/reduction.py b/pytensor/xtensor/rewriting/reduction.py new file mode 100644 index 0000000000..8fc0151cb2 --- /dev/null +++ b/pytensor/xtensor/rewriting/reduction.py @@ -0,0 +1,72 @@ +from functools import partial + +import pytensor.scalar as ps +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor.extra_ops import CumOp +from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.reduction import XCumReduce, XReduce +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +@register_xcanonicalize +@node_rewriter(tracks=[XReduce]) +def lower_reduce(fgraph, node): + [x] = node.inputs + [out] = node.outputs + x_dims = x.type.dims + reduce_dims = node.op.dims + reduce_axis = [x_dims.index(dim) for dim in reduce_dims] + + if not reduce_axis: + 
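+        # An empty reduce_axis means the Op was built with no dims to reduce
+        # (e.g. reduce(x, dim=())), so the output is simply the input.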
return [x] + + match node.op.binary_op: + case ps.add: + tensor_op_class = Sum + case ps.mul: + tensor_op_class = Prod + case ps.and_: + tensor_op_class = All + case ps.or_: + tensor_op_class = Any + case ps.scalar_maximum: + tensor_op_class = Max + case ps.scalar_minimum: + tensor_op_class = Min + case _: + # Case without known/predefined Ops + tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op) + + x_tensor = tensor_from_xtensor(x) + out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor) + new_out = xtensor_from_tensor(out_tensor, out.type.dims) + return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[XCumReduce]) +def lower_cumreduce(fgraph, node): + [x] = node.inputs + x_dims = x.type.dims + reduce_dims = node.op.dims + reduce_axis = [x_dims.index(dim) for dim in reduce_dims] + + if not reduce_axis: + return [x] + + match node.op.binary_op: + case ps.add: + tensor_op_class = partial(CumOp, mode="add") + case ps.mul: + tensor_op_class = partial(CumOp, mode="mul") + case _: + # We don't know how to convert an arbitrary binary cum/reduce Op + return None + + # Each dim corresponds to an application of Cumsum/Cumprod + out_tensor = tensor_from_xtensor(x) + for axis in reduce_axis: + out_tensor = tensor_op_class(axis=axis)(out_tensor) + out = xtensor_from_tensor(out_tensor, x.type.dims) + return [out] diff --git a/pytensor/xtensor/special.py b/pytensor/xtensor/special.py new file mode 100644 index 0000000000..d5b2057145 --- /dev/null +++ b/pytensor/xtensor/special.py @@ -0,0 +1,7 @@ +from pytensor.xtensor.math import exp +from pytensor.xtensor.reduction import REDUCE_DIM + + +def softmax(x, dim: REDUCE_DIM = None): + exp_x = exp(x) + return exp_x / exp_x.sum(dim=dim) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 3ebe1e6771..40a67772a4 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -360,6 +360,41 @@ def imag(self): def real(self): return px.math.real(self) + # Aggregation + # https://docs.xarray.dev/en/latest/api.html#id6 + def all(self, dim): + return px.reduction.all(self, dim) + + def any(self, dim): + return px.reduction.any(self, dim) + + def max(self, dim): + return px.reduction.max(self, dim) + + def min(self, dim): + return px.reduction.min(self, dim) + + def mean(self, dim): + return px.reduction.mean(self, dim) + + def prod(self, dim): + return px.reduction.prod(self, dim) + + def sum(self, dim): + return px.reduction.sum(self, dim) + + def std(self, dim): + return px.reduction.std(self, dim) + + def var(self, dim): + return px.reduction.var(self, dim) + + def cumsum(self, dim): + return px.reduction.cumsum(self, dim) + + def cumprod(self, dim): + return px.reduction.cumprod(self, dim) + # Reshaping and reorganizing # https://docs.xarray.dev/en/latest/api.html#id8 def stack(self, dim, **dims): diff --git a/tests/xtensor/test_reduction.py b/tests/xtensor/test_reduction.py new file mode 100644 index 0000000000..7cc9a674f1 --- /dev/null +++ b/tests/xtensor/test_reduction.py @@ -0,0 +1,27 @@ +# ruff: noqa: E402 +import pytest + + +pytest.importorskip("xarray") + +from pytensor.xtensor.type import xtensor +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function + + +@pytest.mark.parametrize( + "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"] +) +@pytest.mark.parametrize( + "method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:] +) +def test_reduction(method, dim): + x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7)) + out = 
getattr(x, method)(dim=dim) + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + + xr_assert_allclose( + fn(x_test), + getattr(x_test, method)(dim=dim), + ) From 60c7bf2991d78e443f28fd0ae2841f3eb292b9c9 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 26 May 2025 15:28:27 +0200 Subject: [PATCH 07/13] Implement concat for XTensorVariables --- pytensor/xtensor/__init__.py | 1 + pytensor/xtensor/rewriting/shape.py | 47 +++++++++++++++++++- pytensor/xtensor/shape.py | 54 +++++++++++++++++++++++ tests/xtensor/test_shape.py | 66 ++++++++++++++++++++++++++++- 4 files changed, 165 insertions(+), 3 deletions(-) diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 78406810be..a72bf66c79 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -5,6 +5,7 @@ linalg, special, ) +from pytensor.xtensor.shape import concat from pytensor.xtensor.type import ( XTensorType, as_xtensor, diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index b2eabb5c8e..06b8c40a32 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,8 +1,8 @@ from pytensor.graph import node_rewriter -from pytensor.tensor import moveaxis +from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Stack +from pytensor.xtensor.shape import Concat, Stack @register_xcanonicalize @@ -27,3 +27,46 @@ def lower_stack(fgraph, node): new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) return [new_out] + + +@register_xcanonicalize("shape_unsafe") +@node_rewriter(tracks=[Concat]) +def lower_concat(fgraph, node): + out_dims = node.outputs[0].type.dims + concat_dim = node.op.dim + concat_axis = out_dims.index(concat_dim) + + # Convert input XTensors to Tensors and align batch dimensions + tensor_inputs = [] + for inp in node.inputs: + inp_dims = inp.type.dims + order = [ + inp_dims.index(out_dim) if out_dim in inp_dims else "x" + for out_dim in out_dims + ] + tensor_inp = tensor_from_xtensor(inp).dimshuffle(order) + tensor_inputs.append(tensor_inp) + + # Broadcast non-concatenated dimensions of each input + non_concat_shape = [None] * len(out_dims) + for tensor_inp in tensor_inputs: + # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime + # I'm running this as "shape_unsafe" to simplify the logic / returned graph + for i, (bcast, sh) in enumerate( + zip(tensor_inp.type.broadcastable, tensor_inp.shape) + ): + if bcast or i == concat_axis or non_concat_shape[i] is not None: + continue + non_concat_shape[i] = sh + + assert non_concat_shape.count(None) == 1 + + bcast_tensor_inputs = [] + for tensor_inp in tensor_inputs: + # We modify the concat_axis in place, as we don't need the list anywhere else + non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis] + bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape)) + + joined_tensor = join(concat_axis, *bcast_tensor_inputs) + new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 8fa0f42630..f39d495285 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,6 +1,8 @@ from collections.abc import Sequence +from pytensor import Variable from pytensor.graph import Apply +from pytensor.scalar 
import upcast from pytensor.xtensor.basic import XOp from pytensor.xtensor.type import as_xtensor, xtensor @@ -69,3 +71,55 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) ) y = Stack(new_dim_name, tuple(stacked_dims))(y) return y + + +class Concat(XOp): + __props__ = ("dim",) + + def __init__(self, dim: str): + self.dim = dim + super().__init__() + + def make_node(self, *inputs: Variable) -> Apply: + inputs = [as_xtensor(inp) for inp in inputs] + concat_dim = self.dim + + dims_and_shape: dict[str, int | None] = {} + for inp in inputs: + for dim, dim_length in zip(inp.type.dims, inp.type.shape): + if dim not in dims_and_shape: + dims_and_shape[dim] = dim_length + else: + if dim == concat_dim: + if dim_length is None: + dims_and_shape[dim] = None + elif dims_and_shape[dim] is not None: + dims_and_shape[dim] += dim_length + elif dim_length is not None: + # Check for conflicting in non-concatenated shapes + if (dims_and_shape[dim] is not None) and ( + dims_and_shape[dim] != dim_length + ): + raise ValueError( + f"Non-concatenated dimension {dim} has conflicting shapes" + ) + # Keep the non-None shape + dims_and_shape[dim] = dim_length + + if concat_dim not in dims_and_shape: + # It's a new dim, that should be located at the start + dims_and_shape = {concat_dim: len(inputs)} | dims_and_shape + elif dims_and_shape[concat_dim] is not None: + # We need to add +1 for every input that doesn't have this dimension + for inp in inputs: + if concat_dim not in inp.type.dims: + dims_and_shape[concat_dim] += 1 + + dims, shape = zip(*dims_and_shape.items()) + dtype = upcast(*[x.type.dtype for x in inputs]) + output = xtensor(dtype=dtype, dims=dims, shape=shape) + return Apply(self, inputs, [output]) + + +def concat(xtensors, dim: str): + return Concat(dim=dim)(*xtensors) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 42c8eb069d..eabae8feb8 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -6,12 +6,16 @@ from itertools import chain, combinations -from pytensor.xtensor.shape import stack +import numpy as np +from xarray import concat as xr_concat + +from pytensor.xtensor.shape import concat, stack from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, xr_assert_allclose, xr_function, + xr_random_like, ) @@ -65,3 +69,63 @@ def test_multiple_stacks(): res = fn(x_test) expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) xr_assert_allclose(res[0], expected_res) + + +@pytest.mark.parametrize("dim", ("a", "b", "new")) +def test_concat(dim): + rng = np.random.default_rng(sum(map(ord, dim))) + + x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3)) + x2 = xtensor("x2", dims=("b", "a"), shape=(3, 2)) + + x3_shape0 = 4 if dim == "a" else 2 + x3_shape1 = 5 if dim == "b" else 3 + x3 = xtensor("x3", dims=("a", "b"), shape=(x3_shape0, x3_shape1)) + + out = concat([x1, x2, x3], dim=dim) + + fn = xr_function([x1, x2, x3], out) + x1_test = xr_random_like(x1, rng) + x2_test = xr_random_like(x2, rng) + x3_test = xr_random_like(x3, rng) + + res = fn(x1_test, x2_test, x3_test) + expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim) + xr_assert_allclose(res, expected_res) + + +@pytest.mark.parametrize("dim", ("a", "b", "c", "d", "new")) +def test_concat_with_broadcast(dim): + rng = np.random.default_rng(sum(map(ord, dim)) + 1) + + x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3)) + x2 = xtensor("x2", dims=("b", "c"), shape=(3, 5)) + x3 = xtensor("x3", dims=("c", "d"), shape=(5, 
7)) + x4 = xtensor("x4", dims=(), shape=()) + + out = concat([x1, x2, x3, x4], dim=dim) + + fn = xr_function([x1, x2, x3, x4], out) + + x1_test = xr_random_like(x1, rng) + x2_test = xr_random_like(x2, rng) + x3_test = xr_random_like(x3, rng) + x4_test = xr_random_like(x4, rng) + res = fn(x1_test, x2_test, x3_test, x4_test) + expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim) + xr_assert_allclose(res, expected_res) + + +def test_concat_scalar(): + x1 = xtensor("x1", dims=(), shape=()) + x2 = xtensor("x2", dims=(), shape=()) + + out = concat([x1, x2], dim="new_dim") + + fn = xr_function([x1, x2], out) + + x1_test = xr_random_like(x1) + x2_test = xr_random_like(x2) + res = fn(x1_test, x2_test) + expected_res = xr_concat([x1_test, x2_test], dim="new_dim") + xr_assert_allclose(res, expected_res) From 12ac737d6b7e31b3bf0fa3cd239cedd1998cdee6 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 10:20:10 -0400 Subject: [PATCH 08/13] Implement transpose for XTensorVariables --- pytensor/xtensor/rewriting/shape.py | 18 +++++- pytensor/xtensor/shape.py | 95 ++++++++++++++++++++++++++++- pytensor/xtensor/type.py | 46 +++++++++++++- tests/xtensor/test_shape.py | 82 ++++++++++++++++++++++++- 4 files changed, 237 insertions(+), 4 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 06b8c40a32..03deb9a91c 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,7 @@ from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack +from pytensor.xtensor.shape import Concat, Stack, Transpose @register_xcanonicalize @@ -70,3 +70,19 @@ def lower_concat(fgraph, node): joined_tensor = join(concat_axis, *bcast_tensor_inputs) new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[Transpose]) +def lower_transpose(fgraph, node): + [x] = node.inputs + # Use the final dimensions that were already computed in make_node + out_dims = node.outputs[0].type.dims + in_dims = x.type.dims + + # Compute the permutation based on the final dimensions + perm = tuple(in_dims.index(d) for d in out_dims) + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = x_tensor.transpose(perm) + new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index f39d495285..cc0a2a2fa6 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,10 +1,12 @@ +import warnings from collections.abc import Sequence +from typing import Literal from pytensor import Variable from pytensor.graph import Apply from pytensor.scalar import upcast from pytensor.xtensor.basic import XOp -from pytensor.xtensor.type import as_xtensor, xtensor +from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor class Stack(XOp): @@ -73,6 +75,97 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y +class Transpose(XOp): + __props__ = ("dims",) + + def __init__( + self, + dims: tuple[str | Literal[...], ...], + ): + super().__init__() + if dims.count(...) 
> 1: + raise ValueError("an index can only have a single ellipsis ('...')") + self.dims = dims + + def make_node(self, x): + x = as_xtensor(x) + + transpose_dims = self.dims + x_dims = x.type.dims + + if transpose_dims == () or transpose_dims == (...,): + out_dims = tuple(reversed(x_dims)) + elif ... in transpose_dims: + # Handle ellipsis expansion + ellipsis_idx = transpose_dims.index(...) + pre = transpose_dims[:ellipsis_idx] + post = transpose_dims[ellipsis_idx + 1 :] + middle = [d for d in x_dims if d not in pre + post] + out_dims = (*pre, *middle, *post) + if set(out_dims) != set(x_dims): + raise ValueError(f"{out_dims} must be a permuted list of {x_dims}") + else: + out_dims = transpose_dims + if set(out_dims) != set(x_dims): + raise ValueError( + f"{out_dims} must be a permuted list of {x_dims}, unless `...` is included" + ) + + output = xtensor( + dtype=x.type.dtype, + shape=tuple(x.type.shape[x.type.dims.index(d)] for d in out_dims), + dims=out_dims, + ) + return Apply(self, [x], [output]) + + +def transpose( + x, + *dims: str | Literal[...], + missing_dims: Literal["raise", "warn", "ignore"] = "raise", +) -> XTensorVariable: + """Transpose dimensions of the tensor. + + Parameters + ---------- + x : XTensorVariable + Input tensor to transpose. + *dims : str + Dimensions to transpose to. Can include ellipsis (...) to represent + remaining dimensions in their original order. + missing_dims : {"raise", "warn", "ignore"}, optional + How to handle dimensions that don't exist in the input tensor: + - "raise": Raise an error if any dimensions don't exist (default) + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise". + """ + # Validate dimensions + x = as_xtensor(x) + all_dims = x.type.dims + invalid_dims = set(dims) - {..., *all_dims} + if invalid_dims: + if missing_dims != "ignore": + msg = f"Dimensions {invalid_dims} do not exist. Expected one or more of: {all_dims}" + if missing_dims == "raise": + raise ValueError(msg) + else: + warnings.warn(msg) + # Handle missing dimensions if not raising + dims = tuple(d for d in dims if d in all_dims or d is ...) + + return Transpose(dims)(x) + + class Concat(XOp): __props__ = ("dim",) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 40a67772a4..eb3f4c6934 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -16,7 +16,7 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import TypeVar +from typing import Literal, TypeVar import numpy as np @@ -360,6 +360,19 @@ def imag(self): def real(self): return px.math.real(self) + @property + def T(self) -> "XTensorVariable": + """Return the full transpose of the tensor. + + This is equivalent to calling transpose() with no arguments. + + Returns + ------- + XTensorVariable + Fully transposed tensor. + """ + return self.transpose() + # Aggregation # https://docs.xarray.dev/en/latest/api.html#id6 def all(self, dim): @@ -397,6 +410,37 @@ def cumprod(self, dim): # Reshaping and reorganizing # https://docs.xarray.dev/en/latest/api.html#id8 + def transpose( + self, + *dims: str | Literal[...], + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + ) -> "XTensorVariable": + """Transpose dimensions of the tensor. 
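+
+        Dimensions are referred to by name rather than by position, and an
+        ellipsis (...) stands for all remaining dimensions in their original
+        order. For example, for a variable x with dims ("a", "b", "c"),
+        x.transpose("c", ...) returns dims ("c", "a", "b"), while x.transpose()
+        reverses all dimensions.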
+ + Parameters + ---------- + *dims : str | Ellipsis + Dimensions to transpose. If empty, performs a full transpose. + Can use ellipsis (...) to represent remaining dimensions. + missing_dims : {"raise", "warn", "ignore"}, default="raise" + How to handle dimensions that don't exist in the tensor: + - "raise": Raise an error if any dimensions don't exist + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If missing_dims="raise" and any dimensions don't exist. + If multiple ellipsis are provided. + """ + return px.shape.transpose(self, *dims, missing_dims=missing_dims) + def stack(self, dim, **dims): return px.shape.stack(self, dim, **dims) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index eabae8feb8..bb2e9d6158 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -4,12 +4,13 @@ pytest.importorskip("xarray") +import re from itertools import chain, combinations import numpy as np from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack +from pytensor.xtensor.shape import concat, stack, transpose from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -28,6 +29,85 @@ def powerset(iterable, min_group_size=0): ) +def test_transpose(): + a, b, c, d, e = "abcde" + + x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11)) + permutations = [ + (a, b, c, d, e), # identity + (e, d, c, b, a), # full tranpose + (), # eqivalent to full transpose + (a, b, c, e, d), # swap last two dims + (..., d, c), # equivalent to (a, b, e, d, c) + (b, a, ..., e, d), # equivalent to (b, a, c, d, e) + (c, a, ...), # equivalent to (c, a, b, d, e) + ] + outs = [transpose(x, *perm) for perm in permutations] + + fn = xr_function([x], outs) + x_test = xr_arange_like(x) + res = fn(x_test) + expected_res = [x_test.transpose(*perm) for perm in permutations] + for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): + xr_assert_allclose(res_i, expected_res_i) + + +def test_xtensor_variable_transpose(): + """Test the transpose() method of XTensorVariable.""" + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + + # Test basic transpose + out = x.transpose() + fn = xr_function([x], out) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.transpose()) + + # Test transpose with specific dimensions + out = x.transpose("c", "a", "b") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b")) + + # Test transpose with ellipsis + out = x.transpose("c", ...) + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test error cases + with pytest.raises( + ValueError, + match=re.escape( + "Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')" + ), + ): + x.transpose("d") + + with pytest.raises(ValueError, match="an index can only have a single ellipsis"): + x.transpose("a", ..., "b", ...) 
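+
+    # With missing_dims="ignore" or "warn", requested dims that don't exist are
+    # dropped, so the results below are equivalent to x_test.transpose("c", ...)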
+ + # Test missing_dims parameter + # Test ignore + out = x.transpose("c", ..., "d", missing_dims="ignore") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test warn + with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"): + out = x.transpose("c", ..., "d", missing_dims="warn") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + +def test_xtensor_variable_T(): + """Test the T property of XTensorVariable.""" + # Test T property with 3D tensor + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + out = x.T + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + xr_assert_allclose(fn(x_test), x_test.T) + + def test_stack(): dims = ("a", "b", "c", "d") x = xtensor("x", dims=dims, shape=(2, 3, 5, 7)) From 010ca7933fa95a3f5cc4a67a9747c165cd8aff8f Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Thu, 22 May 2025 18:05:40 +0200 Subject: [PATCH 09/13] Implement unstack for XTensorVariables --- pytensor/xtensor/rewriting/shape.py | 23 +++++++- pytensor/xtensor/shape.py | 87 ++++++++++++++++++++++++++++- pytensor/xtensor/type.py | 3 + tests/xtensor/test_shape.py | 46 ++++++++++++++- 4 files changed, 155 insertions(+), 4 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 03deb9a91c..84447670c2 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -1,8 +1,8 @@ from pytensor.graph import node_rewriter -from pytensor.tensor import broadcast_to, join, moveaxis +from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack, Transpose +from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack @register_xcanonicalize @@ -29,6 +29,25 @@ def lower_stack(fgraph, node): return [new_out] +@register_xcanonicalize +@node_rewriter(tracks=[UnStack]) +def lower_unstack(fgraph, node): + x = node.inputs[0] + unstacked_lengths = node.inputs[1:] + axis_to_unstack = x.type.dims.index(node.op.old_dim_name) + + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1]) + final_tensor = x_tensor_transposed.reshape( + (*x_tensor_transposed.shape[:-1], *unstacked_lengths) + ) + # Reintroduce any static shape information that was lost during the reshape + final_tensor = specify_shape(final_tensor, node.outputs[0].type.shape) + + new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) + return [new_out] + + @register_xcanonicalize("shape_unsafe") @node_rewriter(tracks=[Concat]) def lower_concat(fgraph, node): diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index cc0a2a2fa6..38b702db84 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -4,7 +4,9 @@ from pytensor import Variable from pytensor.graph import Apply -from pytensor.scalar import upcast +from pytensor.scalar import discrete_dtypes, upcast +from pytensor.tensor import as_tensor, get_scalar_constant_value +from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.xtensor.basic import XOp from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor @@ -75,6 +77,89 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y +class UnStack(XOp): + __props__ = 
("old_dim_name", "unstacked_dims") + + def __init__( + self, + old_dim_name: str, + unstacked_dims: tuple[str, ...], + ): + super().__init__() + if old_dim_name in unstacked_dims: + raise ValueError( + f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}" + ) + if not unstacked_dims: + raise ValueError("Dims to unstack into can't be empty.") + if len(unstacked_dims) == 1: + raise ValueError("Only one dimension to unstack into, use rename instead") + self.old_dim_name = old_dim_name + self.unstacked_dims = unstacked_dims + + def make_node(self, x, *unstacked_length): + x = as_xtensor(x) + if self.old_dim_name not in x.type.dims: + raise ValueError( + f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}" + ) + if not set(self.unstacked_dims).isdisjoint(x.type.dims): + raise ValueError( + f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}" + ) + + if len(unstacked_length) != len(self.unstacked_dims): + raise ValueError( + f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}" + ) + unstacked_lengths = [as_tensor(length, ndim=0) for length in unstacked_length] + if not all(length.dtype in discrete_dtypes for length in unstacked_lengths): + raise TypeError("Unstacked lengths must be discrete dtypes.") + + if x.type.ndim == 1: + batch_dims, batch_shape = (), () + else: + batch_dims, batch_shape = zip( + *( + (dim, shape) + for dim, shape in zip(x.type.dims, x.type.shape) + if dim != self.old_dim_name + ) + ) + + static_unstacked_lengths = [None] * len(unstacked_lengths) + for i, length in enumerate(unstacked_lengths): + try: + static_length = get_scalar_constant_value(length) + except NotScalarConstantError: + pass + else: + static_unstacked_lengths[i] = int(static_length) + + output = xtensor( + dtype=x.type.dtype, + shape=(*batch_shape, *static_unstacked_lengths), + dims=(*batch_dims, *self.unstacked_dims), + ) + return Apply(self, [x, *unstacked_lengths], [output]) + + +def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]): + if dim is not None: + if dims: + raise ValueError( + "Cannot use both positional dim and keyword dims in unstack" + ) + dims = dim + + y = x + for old_dim_name, unstacked_dict in dims.items(): + y = UnStack(old_dim_name, tuple(unstacked_dict.keys()))( + y, *tuple(unstacked_dict.values()) + ) + return y + + class Transpose(XOp): __props__ = ("dims",) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index eb3f4c6934..4d5aaa1a64 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -444,6 +444,9 @@ def transpose( def stack(self, dim, **dims): return px.shape.stack(self, dim, **dims) + def unstack(self, dim, **dims): + return px.shape.unstack(self, dim, **dims) + class XTensorConstantSignature(tuple): def __eq__(self, other): diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index bb2e9d6158..571f6ba131 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,9 +8,10 @@ from itertools import chain, combinations import numpy as np +from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack, transpose +from pytensor.xtensor.shape import concat, stack, transpose, unstack from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -151,6 +152,49 @@ def test_multiple_stacks(): xr_assert_allclose(res[0], expected_res) +def test_unstack_constant_size(): + x = xtensor("x", 
dims=("a", "bc", "d"), shape=(2, 3 * 5, 7)) + y = unstack(x, bc=dict(b=3, c=5)) + assert y.type.dims == ("a", "d", "b", "c") + assert y.type.shape == (2, 7, 3, 5) + + fn = xr_function([x], y) + + x_test = xr_arange_like(x) + x_np = x_test.values + res = fn(x_test) + expected = ( + DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d")) + .stack(bc=("b", "c")) + .unstack("bc") + ) + xr_assert_allclose(res, expected) + + +def test_unstack_symbolic_size(): + x = xtensor(dims=("a", "b", "c")) + y = stack(x, bc=("b", "c")) + y = y / y.sum("bc") + z = unstack(y, bc={"b": x.sizes["b"], "c": x.sizes["c"]}) + x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5))) + fn = xr_function([x], z) + res = fn(x_test) + expected_res = x_test / x_test.sum(["b", "c"]) + xr_assert_allclose(res, expected_res) + + +def test_stack_unstack(): + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) + stack_x = stack(x, bd=("b", "d")) + unstack_x = unstack(stack_x, bd=dict(b=3, d=7)) + + x_test = xr_arange_like(x) + fn = xr_function([x], unstack_x) + res = fn(x_test) + expected_res = x_test.transpose("a", "c", "b", "d") + xr_assert_allclose(res, expected_res) + + @pytest.mark.parametrize("dim", ("a", "b", "new")) def test_concat(dim): rng = np.random.default_rng(sum(map(ord, dim))) From b3deb39d6031481818ece46078b41e9ad7e51d08 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 21 May 2025 19:11:02 +0200 Subject: [PATCH 10/13] Implement index for XTensorVariables --- pytensor/xtensor/__init__.py | 1 - pytensor/xtensor/indexing.py | 186 +++++++++++++ pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/indexing.py | 150 +++++++++++ pytensor/xtensor/type.py | 101 ++++++- tests/xtensor/test_indexing.py | 348 +++++++++++++++++++++++++ 6 files changed, 784 insertions(+), 3 deletions(-) create mode 100644 pytensor/xtensor/indexing.py create mode 100644 pytensor/xtensor/rewriting/indexing.py create mode 100644 tests/xtensor/test_indexing.py diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index a72bf66c79..06265e40de 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -7,7 +7,6 @@ ) from pytensor.xtensor.shape import concat from pytensor.xtensor.type import ( - XTensorType, as_xtensor, xtensor, xtensor_constant, diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py new file mode 100644 index 0000000000..91e74017c9 --- /dev/null +++ b/pytensor/xtensor/indexing.py @@ -0,0 +1,186 @@ +# HERE LIE DRAGONS +# Useful links to make sense of all the numpy/xarray complexity +# https://numpy.org/devdocs//user/basics.indexing.html +# https://numpy.org/neps/nep-0021-advanced-indexing.html +# https://docs.xarray.dev/en/latest/user-guide/indexing.html +# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html + +from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.scalar.basic import discrete_dtypes +from pytensor.tensor.basic import as_tensor +from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice +from pytensor.xtensor.basic import XOp, xtensor_from_tensor +from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor + + +def as_idx_variable(idx, indexed_dim: str): + if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)): + raise TypeError( + "XTensors do not support indexing with None (np.newaxis), use expand_dims instead" + ) + if isinstance(idx, slice): + idx = make_slice(idx) + elif isinstance(idx, Variable) and isinstance(idx.type, 
SliceType): + pass + elif ( + isinstance(idx, tuple) + and len(idx) == 2 + and ( + isinstance(idx[0], str) + or ( + isinstance(idx[0], tuple | list) + and all(isinstance(d, str) for d in idx[0]) + ) + ) + ): + # Special case for ("x", array) that xarray supports + dim, idx = idx + if isinstance(idx, Variable) and isinstance(idx.type, XTensorType): + raise IndexError( + f"Giving a dimension name to an XTensorVariable indexer is not supported: {(dim, idx)}. " + "Use .rename() instead." + ) + if isinstance(dim, str): + dims = (dim,) + else: + dims = tuple(dim) + idx = as_xtensor(as_tensor(idx), dims=dims) + else: + # Must be integer / boolean indices, we already counted for None and slices + try: + idx = as_xtensor(idx) + except TypeError: + idx = as_tensor(idx) + if idx.type.ndim > 1: + # Same error that xarray raises + raise IndexError( + "Unlabeled multi-dimensional array cannot be used for indexing" + ) + # This is implicitly an XTensorVariable with dim matching the indexed one + idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim]) + + if idx.type.dtype == "bool": + if idx.type.ndim != 1: + # xarray allaws `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379 + # Otherwise, it is always restricted to 1d boolean indexing arrays + raise NotImplementedError( + "Only 1d boolean indexing arrays are supported" + ) + if idx.type.dims != (indexed_dim,): + raise IndexError( + "Boolean indexer should be unlabeled or on the same dimension to the indexed array. " + f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}." + ) + + # Convert to nonzero indices + idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims) + + elif idx.type.dtype not in discrete_dtypes: + raise TypeError("Numerical indices must be integers or boolean") + return idx + + +def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None: + if dim_length is None: + return None + if isinstance(slc, Constant): + d = slc.data + start, stop, step = d.start, d.stop, d.step + elif slc.owner is None: + # It's a root variable no way of knowing what we're getting + return None + else: + # It's a MakeSliceOp + start, stop, step = slc.owner.inputs + if isinstance(start, Constant): + start = start.data + else: + return None + if isinstance(stop, Constant): + stop = stop.data + else: + return None + if isinstance(step, Constant): + step = step.data + else: + return None + return len(range(*slice(start, stop, step).indices(dim_length))) + + +class Index(XOp): + __props__ = () + + def make_node(self, x, *idxs): + x = as_xtensor(x) + + if any(idx is Ellipsis for idx in idxs): + if idxs.count(Ellipsis) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Convert intermediate Ellipsis to slice(None) + ellipsis_loc = idxs.index(Ellipsis) + n_implied_none_slices = x.type.ndim - (len(idxs) - 1) + idxs = ( + *idxs[:ellipsis_loc], + *((slice(None),) * n_implied_none_slices), + *idxs[ellipsis_loc + 1 :], + ) + + x_ndim = x.type.ndim + x_dims = x.type.dims + x_shape = x.type.shape + out_dims = [] + out_shape = [] + + def combine_dim_info(idx_dim, idx_dim_shape): + if idx_dim not in out_dims: + # First information about the dimension length + out_dims.append(idx_dim) + out_shape.append(idx_dim_shape) + else: + # Dim already introduced in output by a previous index + # Update static shape or raise if incompatible + out_dim_pos = out_dims.index(idx_dim) + out_dim_shape = out_shape[out_dim_pos] + if out_dim_shape is None: + # We don't know the size 
of the dimension yet + out_shape[out_dim_pos] = idx_dim_shape + elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape: + raise IndexError( + f"Dimension of indexers mismatch for dim {idx_dim}" + ) + + if len(idxs) > x_ndim: + raise IndexError("Too many indices") + + idxs = [ + as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False) + ] + + for i, idx in enumerate(idxs): + if isinstance(idx.type, SliceType): + idx_dim = x_dims[i] + idx_dim_shape = get_static_slice_length(idx, x_shape[i]) + combine_dim_info(idx_dim, idx_dim_shape) + else: + if idx.type.ndim == 0: + # Scalar index, dimension is dropped + continue + + assert isinstance(idx.type, XTensorType) + + idx_dims = idx.type.dims + for idx_dim in idx_dims: + idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)] + combine_dim_info(idx_dim, idx_dim_shape) + + for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]): + # Add back any unindexed dimensions + if dim_i not in out_dims: + # If the dimension was not indexed, we keep it as is + combine_dim_info(dim_i, shape_i) + + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, *idxs], [output]) + + +index = Index() diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index 7ce55b9256..a65ad0db85 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,4 +1,5 @@ import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.indexing import pytensor.xtensor.rewriting.reduction import pytensor.xtensor.rewriting.shape import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py new file mode 100644 index 0000000000..70f232ffb1 --- /dev/null +++ b/pytensor/xtensor/rewriting/indexing.py @@ -0,0 +1,150 @@ +from itertools import zip_longest + +from pytensor import as_symbolic +from pytensor.graph import Constant, node_rewriter +from pytensor.tensor import TensorType, arange, specify_shape +from pytensor.tensor.subtensor import _non_consecutive_adv_indexing +from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.indexing import Index +from pytensor.xtensor.rewriting.utils import register_xcanonicalize +from pytensor.xtensor.type import XTensorType + + +def to_basic_idx(idx): + if isinstance(idx.type, SliceType): + if isinstance(idx, Constant): + return idx.data + elif idx.owner: + # MakeSlice Op + # We transform NoneConsts to regular None so that basic Subtensor can be used if possible + return slice( + *[ + None if isinstance(i.type, NoneTypeT) else i + for i in idx.owner.inputs + ] + ) + else: + return idx + if ( + isinstance(idx.type, XTensorType) + and idx.type.ndim == 0 + and idx.type.dtype != bool + ): + return idx.values + raise TypeError("Cannot convert idx to basic idx") + + +@register_xcanonicalize +@node_rewriter(tracks=[Index]) +def lower_index(fgraph, node): + """Lower XTensorVariable indexing to regular TensorVariable indexing. + + xarray-like indexing has two modes: + 1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices. + 2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing. + + An Index Op can combine both modes. 
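+
+    For example, for x with dims ("a", "b"): indexing with two 1d indexers
+    labeled "a" and "b" respectively is orthogonal (the output keeps separate
+    "a" and "b" dimensions), whereas indexing with two 1d indexers that share
+    the same label, say "z", is vectorized (they are combined point-wise and
+    the output has a single "z" dimension).
+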
+ To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing. + We expand the dims of each index so they are as large as the number of output dimensions, place the indices that + belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes. + + For instance to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]], + This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as have + indices that have more than one dimension at the start. + + In addition, xarray basic index (slices), can be vectorized with other advanced indices (if they act on the same output dimension). + However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices + we have to convert the slices to equivalent advanced indices. + We do this by creating an `arange` tensor that matches the shape of the dimension being indexed, + and then indexing it with the original slice. This index is then handled as a regular advanced index. + + Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices + are converted to the appropriate XTensorVariable by `Index.make_node` + """ + + x, *idxs = node.inputs + [out] = node.outputs + x_tensor = tensor_from_xtensor(x) + + if all( + ( + isinstance(idx.type, SliceType) + or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0) + ) + for idx in idxs + ): + # Special case having just basic indexing + x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)] + + else: + # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing + # May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index + x_dims = x.type.dims + x_shape = tuple(x.shape) + out_ndim = out.type.ndim + out_dims = out.type.dims + aligned_idxs = [] + basic_idx_axis = [] + # zip_longest adds the implicit slice(None) + for i, (idx, x_dim) in enumerate( + zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None))) + ): + if isinstance(idx.type, SliceType): + if not any( + ( + isinstance(other_idx.type, XTensorType) + and x_dim in other_idx.dims + ) + for j, other_idx in enumerate(idxs) + if j != i + ): + # We can use basic indexing directly if no other index acts on this dimension + # This is an optimization that avoids creating an unnecessary arange tensor + # and facilitates the use of the specialized AdvancedSubtensor1 when possible + aligned_idxs.append(idx) + basic_idx_axis.append(out_dims.index(x_dim)) + else: + # Otherwise we need to convert the basic index into an equivalent advanced indexing + # And align it so it interacts correctly with the other advanced indices + adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)] + ds_order = ["x"] * out_ndim + ds_order[out_dims.index(x_dim)] = 0 + aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order)) + else: + assert isinstance(idx.type, XTensorType) + if idx.type.ndim == 0: + # Scalar index, we can use it directly + aligned_idxs.append(idx.values) + else: + # Vector index, we need to align the indexing dimensions with the base_dims + ds_order = ["x"] * out_ndim + for j, idx_dim in enumerate(idx.dims): + ds_order[out_dims.index(idx_dim)] = j + aligned_idxs.append(idx.values.dimshuffle(ds_order)) + + # Squeeze indexing dimensions that were not used because we kept 
basic indexing slices + if basic_idx_axis: + aligned_idxs = [ + idx.squeeze(axis=basic_idx_axis) + if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) + else idx + for idx in aligned_idxs + ] + + x_tensor_indexed = x_tensor[tuple(aligned_idxs)] + + if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs): + # Numpy moves advanced indexing dimensions to the front when they are not consecutive + # We need to transpose them back to the expected output order + x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis] + x_tensor_indexed_dims = [ + dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims + ] + x_tensor_indexed_basic_dims + transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] + x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) + + # Add lost shape information + x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape) + new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims) + return [new_out] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 4d5aaa1a64..f4fef683c5 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -1,3 +1,5 @@ +import warnings + from pytensor.compile import ( DeepCopyOp, ViewOp, @@ -16,7 +18,7 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import Literal, TypeVar +from typing import Any, Literal, TypeVar import numpy as np @@ -342,7 +344,102 @@ def sel(self, *args, **kwargs): raise NotImplementedError("sel not implemented for XTensorVariable") def __getitem__(self, idx): - raise NotImplementedError("Indexing not yet implemnented") + if isinstance(idx, dict): + return self.isel(idx) + + if not isinstance(idx, tuple): + idx = (idx,) + + return px.indexing.index(self, *idx) + + def isel( + self, + indexers: dict[str, Any] | None = None, + drop: bool = False, # Unused by PyTensor + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + **indexers_kwargs, + ): + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to isel" + ) + indexers = indexers_kwargs + + if missing_dims not in {"raise", "warn", "ignore"}: + raise ValueError( + f"Unrecognized options {missing_dims} for missing_dims argument" + ) + + # Sort indices and pass them to index + dims = self.type.dims + indices = [slice(None)] * self.type.ndim + for key, idx in indexers.items(): + if idx is Ellipsis: + # Xarray raises a less informative error, suggesting indices must be integer + # But slices are also fine + raise TypeError("Ellipsis (...) is an invalid labeled index") + try: + indices[dims.index(key)] = idx + except IndexError: + if missing_dims == "raise": + raise ValueError( + f"Dimension {key} does not exist. Expected one of {dims}" + ) + elif missing_dims == "warn": + warnings.warn( + f"Dimension {key} does not exist. 
Expected one of {dims}", + UserWarning, + ) + + return px.indexing.index(self, *indices) + + def _head_tail_or_thin( + self, + indexers: dict[str, Any] | int | None, + indexers_kwargs: dict[str, Any], + *, + kind: Literal["head", "tail", "thin"], + ): + if indexers_kwargs: + if indexers is not None: + raise ValueError( + "Cannot pass both indexers and indexers_kwargs to head" + ) + indexers = indexers_kwargs + + if indexers is None: + if kind == "thin": + raise TypeError( + "thin() indexers must be either dict-like or a single integer" + ) + else: + # Default to 5 for head and tail + indexers = {dim: 5 for dim in self.type.dims} + + elif not isinstance(indexers, dict): + indexers = {dim: indexers for dim in self.type.dims} + + if kind == "head": + indices = {dim: slice(None, value) for dim, value in indexers.items()} + elif kind == "tail": + sizes = self.sizes + # Can't use slice(-value, None), in case value is zero + indices = { + dim: slice(sizes[dim] - value, None) for dim, value in indexers.items() + } + elif kind == "thin": + indices = {dim: slice(None, None, value) for dim, value in indexers.items()} + return self.isel(indices) + + def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head") + + def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail") + + def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs): + return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin") # ndarray methods # https://docs.xarray.dev/en/latest/api.html#id7 diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py new file mode 100644 index 0000000000..721fd1e695 --- /dev/null +++ b/tests/xtensor/test_indexing.py @@ -0,0 +1,348 @@ +import re + +import numpy as np +import pytest +from xarray import DataArray + +from pytensor.tensor import tensor +from pytensor.xtensor import xtensor +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function + + +@pytest.mark.parametrize( + "indices", + [ + (0,), + (slice(1, None),), + (slice(None, -1),), + (slice(None, None, -1),), + (0, slice(None), -1, slice(1, None)), + (..., 0, -1), + (0, ..., -1), + (0, -1, ...), + ], +) +@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"]) +def test_basic_indexing(labeled, indices): + if ... 
in indices and labeled: + pytest.skip("Ellipsis not supported with labeled indexing") + + dims = ("a", "b", "c", "d") + x = xtensor(dims=dims, shape=(2, 3, 5, 7)) + + if labeled: + shufled_dims = tuple(np.random.permutation(dims)) + indices = dict(zip(shufled_dims, indices, strict=False)) + out = x[indices] + + fn = xr_function([x], out) + x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape( + x.type.shape + ) + x_test = DataArray(x_test_values, dims=x.type.dims) + res = fn(x_test) + expected_res = x_test[indices] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_on_existing_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("a",)) + + # Equivalent ways of indexing a->a + y = x[idx] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[idx_test] + xr_assert_allclose(res, expected_res) + + y = x[(("a", idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[(("a", idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[((("a",), idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[((("a",), idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[xidx] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_on_new_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("new_a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("new_a",)) + + # Equivalent ways of indexing a->new_a + y = x[(("new_a", idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[(("new_a", idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[((["new_a"], idx),)] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[((["new_a"], idx_test),)] + xr_assert_allclose(res, expected_res) + + y = x[xidx] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test] + xr_assert_allclose(res, expected_res) + + +def test_single_vector_indexing_interacting_with_existing_dim(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx = tensor("idx", dtype=int, shape=(4,)) + xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",)) + + x_test = xr_arange_like(x) + idx_test = np.array([0, 1, 0, 2], dtype=int) + xidx_test = DataArray(idx_test, dims=("a",)) + + # Two equivalent ways of indexing a->b + # By labeling the index on a, as "b", we cause pointwise indexing between the two dimensions. 
+ y = x[("b", idx), 1:] + fn = xr_function([x, idx], y) + res = fn(x_test, idx_test) + expected_res = x_test[("b", idx_test), 1:] + xr_assert_allclose(res, expected_res) + + y = x[xidx.rename(a="b"), 1:] + fn = xr_function([x, xidx], y) + res = fn(x_test, xidx_test) + expected_res = x_test[xidx_test.rename(a="b"), 1:] + xr_assert_allclose(res, expected_res) + + +@pytest.mark.parametrize( + "dims_order", + [ + ("a", "b", "ar", "br", "o"), + ("o", "br", "ar", "b", "a"), + ("a", "b", "o", "ar", "br"), + ("a", "o", "ar", "b", "br"), + ], +) +def test_multiple_vector_indexing(dims_order): + x = xtensor(dims=dims_order, shape=(5, 7, 11, 13, 17)) + idx_a = xtensor("idx_a", dtype=int, shape=(4,), dims=("a",)) + idx_b = xtensor("idx_b", dtype=int, shape=(3,), dims=("b",)) + + idxs = [slice(None)] * 5 + idxs[x.type.dims.index("a")] = idx_a + idxs[x.type.dims.index("b")] = idx_b + idxs[x.type.dims.index("ar")] = idx_a[::-1] + idxs[x.type.dims.index("br")] = idx_b[::-1] + + out = x[tuple(idxs)] + fn = xr_function([x, idx_a, idx_b], out) + + x_test = xr_arange_like(x) + idx_a_test = DataArray(np.array([0, 1, 0, 2], dtype=int), dims=("a",)) + idx_b_test = DataArray(np.array([1, 3, 0], dtype=int), dims=("b",)) + res = fn(x_test, idx_a_test, idx_b_test) + idxs_test = [slice(None)] * 5 + idxs_test[x.type.dims.index("a")] = idx_a_test + idxs_test[x.type.dims.index("b")] = idx_b_test + idxs_test[x.type.dims.index("ar")] = idx_a_test[::-1] + idxs_test[x.type.dims.index("br")] = idx_b_test[::-1] + expected_res = x_test[tuple(idxs_test)] + xr_assert_allclose(res, expected_res) + + +def test_matrix_indexing(): + x = xtensor(dims=("a", "b", "c"), shape=(3, 5, 7)) + idx_ab = xtensor("idx_ab", dtype=int, shape=(4, 2), dims=("a", "b")) + idx_cd = xtensor("idx_cd", dtype=int, shape=(4, 3), dims=("c", "d")) + + out = x[idx_ab, slice(1, 3), idx_cd] + fn = xr_function([x, idx_ab, idx_cd], out) + + x_test = xr_arange_like(x) + idx_ab_test = DataArray( + np.array([[0, 1], [1, 2], [0, 2], [-1, -2]], dtype=int), dims=("a", "b") + ) + idx_cd_test = DataArray( + np.array([[1, 2, 3], [0, 4, 5], [2, 6, -1], [3, -2, 0]], dtype=int), + dims=("c", "d"), + ) + res = fn(x_test, idx_ab_test, idx_cd_test) + expected_res = x_test[idx_ab_test, slice(1, 3), idx_cd_test] + xr_assert_allclose(res, expected_res) + + +def test_assign_multiple_out_dims(): + x = xtensor("x", shape=(5, 7), dims=("a", "b")) + idx1 = tensor("idx1", dtype=int, shape=(4, 3)) + idx2 = tensor("idx2", dtype=int, shape=(3, 2)) + out = x[(("out1", "out2"), idx1), (["out2", "out3"], idx2)] + + fn = xr_function([x, idx1, idx2], out) + + rng = np.random.default_rng() + x_test = xr_arange_like(x) + idx1_test = rng.binomial(n=4, p=0.5, size=(4, 3)) + idx2_test = rng.binomial(n=4, p=0.5, size=(3, 2)) + res = fn(x_test, idx1_test, idx2_test) + expected_res = x_test[(("out1", "out2"), idx1_test), (["out2", "out3"], idx2_test)] + xr_assert_allclose(res, expected_res) + + +def test_assign_indexer_dims_fails(): + # Test cases where the implicit naming of the indexer dimensions is not allowed. + x = xtensor("x", shape=(5, 7), dims=("a", "b")) + idx1 = xtensor("idx1", dtype=int, shape=(4,), dims=("c",)) + + with pytest.raises( + IndexError, + match=re.escape( + "Giving a dimension name to an XTensorVariable indexer is not supported: ('d', idx1). " + "Use .rename() instead." + ), + ): + x[("d", idx1),] + + with pytest.raises( + IndexError, + match=re.escape( + "Boolean indexer should be unlabeled or on the same dimension to the indexed array. 
" + "Indexer is on ('c',) but the target dimension is a." + ), + ): + x[idx1.astype("bool")] + + +class TestVectorizedIndexingNotAllowedToBroadcast: + def test_compile_time_error(self): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx_a = xtensor("idx_a", dtype=int, shape=(4,), dims=("b",)) + idx_b = xtensor("idx_b", dtype=int, shape=(1,), dims=("b",)) + with pytest.raises( + IndexError, match="Dimension of indexers mismatch for dim b" + ): + x[idx_a, idx_b] + + @pytest.mark.xfail( + reason="Check that lowered indexing is not allowed to broadcast not implemented yet" + ) + def test_runtime_error(self): + """ + Test that, unlike in numpy, indices with different shapes cannot act on the same dimension, + even if the shapes could broadcast as per numpy semantics. + """ + x = xtensor(dims=("a", "b"), shape=(3, 5)) + idx_a = xtensor("idx_a", dtype=int, shape=(None,), dims=("b",)) + idx_b = xtensor("idx_b", dtype=int, shape=(None,), dims=("b",)) + out = x[idx_a, idx_b] + + fn = xr_function([x, idx_a, idx_b], out) + + x_test = xr_arange_like(x) + valid_idx_a_test = DataArray(np.array([0], dtype=int), dims=("b",)) + idx_b_test = DataArray(np.array([1], dtype=int), dims=("b",)) + xr_assert_allclose( + fn(x_test, valid_idx_a_test, idx_b_test), + x_test[valid_idx_a_test, idx_b_test], + ) + + invalid_idx_a_test = DataArray(np.array([0, 1, 0, 1], dtype=int), dims=("b",)) + with pytest.raises(ValueError): + fn(x_test, invalid_idx_a_test, idx_b_test) + + +@pytest.mark.parametrize( + "dims_order", + [ + ("a", "b", "c", "d"), + ("d", "c", "b", "a"), + ("c", "a", "b", "d"), + ], +) +def test_scalar_integer_indexing(dims_order): + x = xtensor(dims=dims_order, shape=(3, 5, 7, 11)) + scalar_idx = xtensor("scalar_idx", dtype=int, shape=(), dims=()) + vec_idx1 = xtensor("vec_idx", dtype=int, shape=(4,), dims=("a",)) + vec_idx2 = xtensor("vec_idx2", dtype=int, shape=(4,), dims=("c",)) + + idxs = [None] * 4 + idxs[x.type.dims.index("a")] = scalar_idx + idxs[x.type.dims.index("b")] = vec_idx1 + idxs[x.type.dims.index("c")] = vec_idx2 + idxs[x.type.dims.index("d")] = -scalar_idx + out1 = x[tuple(idxs)] + + idxs[x.type.dims.index("a")] = vec_idx1.rename(a="c") + out2 = x[tuple(idxs)] + + fn = xr_function([x, scalar_idx, vec_idx1, vec_idx2], (out1, out2)) + + x_test = xr_arange_like(x) + scalar_idx_test = DataArray(np.array(1, dtype=int), dims=()) + vec_idx_test1 = DataArray(np.array([0, 1, 0, 2], dtype=int), dims=("a",)) + vec_idx_test2 = DataArray(np.array([0, 2, 2, 1], dtype=int), dims=("c",)) + res1, res2 = fn(x_test, scalar_idx_test, vec_idx_test1, vec_idx_test2) + idxs = [None] * 4 + idxs[x.type.dims.index("a")] = scalar_idx_test + idxs[x.type.dims.index("b")] = vec_idx_test1 + idxs[x.type.dims.index("c")] = vec_idx_test2 + idxs[x.type.dims.index("d")] = -scalar_idx_test + expected_res1 = x_test[tuple(idxs)] + idxs[x.type.dims.index("a")] = vec_idx_test1.rename(a="c") + expected_res2 = x_test[tuple(idxs)] + xr_assert_allclose(res1, expected_res1) + xr_assert_allclose(res2, expected_res2) + + +def test_unsupported_boolean_indexing(): + x = xtensor(dims=("a", "b"), shape=(3, 5)) + + mat_idx = xtensor("idx", dtype=bool, shape=(4, 2), dims=("a", "b")) + scalar_idx = mat_idx.isel(a=0, b=1) + + for idx in (mat_idx, scalar_idx, scalar_idx.values): + with pytest.raises( + NotImplementedError, + match="Only 1d boolean indexing arrays are supported", + ): + x[idx] + + +def test_boolean_indexing(): + x = xtensor("x", shape=(8, 7), dims=("a", "b")) + bool_idx = xtensor("bool_idx", dtype=bool, shape=(8,), dims=("a",)) + 
int_idx = xtensor("int_idx", dtype=int, shape=(4, 3), dims=("a", "new_dim")) + + out_vectorized = x[bool_idx, int_idx] + out_orthogonal = x[bool_idx, int_idx.rename(a="b")] + fn = xr_function([x, bool_idx, int_idx], [out_vectorized, out_orthogonal]) + + x_test = xr_arange_like(x) + bool_idx_test = DataArray(np.array([True, False] * 4, dtype=bool), dims=("a",)) + int_idx_test = DataArray( + np.random.binomial(n=4, p=0.5, size=(4, 3)), + dims=("a", "new_dim"), + ) + res1, res2 = fn(x_test, bool_idx_test, int_idx_test) + expected_res1 = x_test[bool_idx_test, int_idx_test] + expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")] + xr_assert_allclose(res1, expected_res1) + xr_assert_allclose(res2, expected_res2) From fa87a384a54f99aa54eda3fd16c3184469fd1021 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 2 Jun 2025 11:27:11 +0200 Subject: [PATCH 11/13] Implement index update for XTensorVariables --- pytensor/xtensor/indexing.py | 33 ++++++ pytensor/xtensor/rewriting/indexing.py | 82 ++++++++++++-- pytensor/xtensor/type.py | 28 ++++- tests/xtensor/test_indexing.py | 142 ++++++++++++++++++++++++- 4 files changed, 272 insertions(+), 13 deletions(-) diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py index 91e74017c9..01517db55d 100644 --- a/pytensor/xtensor/indexing.py +++ b/pytensor/xtensor/indexing.py @@ -4,6 +4,7 @@ # https://numpy.org/neps/nep-0021-advanced-indexing.html # https://docs.xarray.dev/en/latest/user-guide/indexing.html # https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html +from typing import Literal from pytensor.graph.basic import Apply, Constant, Variable from pytensor.scalar.basic import discrete_dtypes @@ -184,3 +185,35 @@ def combine_dim_info(idx_dim, idx_dim_shape): index = Index() + + +class IndexUpdate(XOp): + __props__ = ("mode",) + + def __init__(self, mode: Literal["set", "inc"]): + if mode not in ("set", "inc"): + raise ValueError("mode must be 'set' or 'inc'") + self.mode = mode + + def make_node(self, x, y, *idxs): + # Call Index on (x, *idxs) to process inputs and infer output type + x_view_node = index.make_node(x, *idxs) + x, *idxs = x_view_node.inputs + [x_view] = x_view_node.outputs + + try: + y = as_xtensor(y) + except TypeError: + y = as_xtensor(as_tensor(y), dims=x_view.type.dims) + + if not set(y.type.dims).issubset(x_view.type.dims): + raise ValueError( + f"Value dimensions {y.type.dims} must be a subset of the indexed dimensions {x_view.type.dims}" + ) + + out = x.type() + return Apply(self, [x, y, *idxs], [out]) + + +index_assignment = IndexUpdate("set") +index_increment = IndexUpdate("inc") diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py index 70f232ffb1..6b0b650848 100644 --- a/pytensor/xtensor/rewriting/indexing.py +++ b/pytensor/xtensor/rewriting/indexing.py @@ -3,10 +3,10 @@ from pytensor import as_symbolic from pytensor.graph import Constant, node_rewriter from pytensor.tensor import TensorType, arange, specify_shape -from pytensor.tensor.subtensor import _non_consecutive_adv_indexing +from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor -from pytensor.xtensor.indexing import Index +from pytensor.xtensor.indexing import Index, IndexUpdate, index from pytensor.xtensor.rewriting.utils import register_xcanonicalize from pytensor.xtensor.type import XTensorType @@ -35,9 +35,7 @@ def 
to_basic_idx(idx): raise TypeError("Cannot convert idx to basic idx") -@register_xcanonicalize -@node_rewriter(tracks=[Index]) -def lower_index(fgraph, node): +def _lower_index(node): """Lower XTensorVariable indexing to regular TensorVariable indexing. xarray-like indexing has two modes: @@ -59,12 +57,18 @@ def lower_index(fgraph, node): We do this by creating an `arange` tensor that matches the shape of the dimension being indexed, and then indexing it with the original slice. This index is then handled as a regular advanced index. - Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices - are converted to the appropriate XTensorVariable by `Index.make_node` + Finally, the location of views resulting from advanced indices follows two distinct behaviors in numpy. + When all advanced indices are consecutive, the respective view is located in the "original" location. + However, if advanced indices are separated by basic indices (slices in our case), the output views + always show up at the front of the array. This information is returned as the second output of this function, + which labels the final position of the indexed dimensions under this rule. """ + assert isinstance(node.op, Index) + x, *idxs = node.inputs [out] = node.outputs + x_tensor_indexed_dims = out.type.dims x_tensor = tensor_from_xtensor(x) if all( @@ -141,10 +145,68 @@ def lower_index(fgraph, node): x_tensor_indexed_dims = [ dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims ] + x_tensor_indexed_basic_dims - transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] - x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) + + return x_tensor_indexed, x_tensor_indexed_dims + + +@register_xcanonicalize +@node_rewriter(tracks=[Index]) +def lower_index(fgraph, node): + """Lower XTensorVariable indexing to regular TensorVariable indexing. + + The bulk of the work is done by `_lower_index`, except for special logic to control the + location of non-consecutive advanced indices, and to preserve static shape information. + """ + + [out] = node.outputs + out_dims = out.type.dims + + x_tensor_indexed, x_tensor_indexed_dims = _lower_index(node) + if x_tensor_indexed_dims != out_dims: + # Numpy moves advanced indexing dimensions to the front when they are not consecutive + # We need to transpose them back to the expected output order + transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims] + x_tensor_indexed = x_tensor_indexed.transpose(transpose_order) # Add lost shape information x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape) - new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims) + + new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.dims) + return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[IndexUpdate]) +def lower_index_update(fgraph, node): + """Lower XTensorVariable index update to regular TensorVariable indexing update. + + This rewrite requires converting the index view to a tensor-based equivalent expression, + just like `lower_index`. It then requires aligning the dimensions of y with the + dimensions of the index view, with special care for non-consecutive dimensions being + pulled to the front axis according to numpy rules. 
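+    Concretely, `y` is expanded/transposed with a `dimshuffle` so its dims line up with the
+    dims of the lowered index view, and the update is then applied with `inc_subtensor`
+    (with `set_instead_of_inc` when the mode is "set").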
+ """ + x, y, *idxs = node.inputs + + # Lower the indexing part first + indexed_node = index.make_node(x, *idxs) + x_tensor_indexed, x_tensor_indexed_dims = _lower_index(indexed_node) + y_tensor = tensor_from_xtensor(y) + + # Align dimensions of y with those of the indexed tensor x + y_dims = y.type.dims + y_dims_set = set(y_dims) + y_order = tuple( + y_dims.index(x_dim) if x_dim in y_dims_set else "x" + for x_dim in x_tensor_indexed_dims + ) + # Remove useless left expand_dims + while len(y_order) > 0 and y_order[0] == "x": + y_order = y_order[1:] + if y_order != tuple(range(y_tensor.type.ndim)): + y_tensor = y_tensor.dimshuffle(y_order) + + x_tensor_updated = inc_subtensor( + x_tensor_indexed, y_tensor, set_instead_of_inc=node.op.mode == "set" + ) + new_out = xtensor_from_tensor(x_tensor_updated, dims=x.type.dims) return [new_out] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index f4fef683c5..bfe2ca63c4 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -333,8 +333,10 @@ def item(self): # Indexing # https://docs.xarray.dev/en/latest/api.html#id2 - def __setitem__(self, key, value): - raise TypeError("XTensorVariable does not support item assignment.") + def __setitem__(self, idx, value): + raise TypeError( + "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead." + ) @property def loc(self): @@ -394,6 +396,28 @@ def isel( return px.indexing.index(self, *indices) + def set(self, value): + if not ( + self.owner is not None and isinstance(self.owner.op, px.indexing.Index) + ): + raise ValueError( + f"set can only be called on the output of an index (or isel) operation. Self is the result of {self.owner}" + ) + + x, *idxs = self.owner.inputs + return px.indexing.index_assignment(x, value, *idxs) + + def inc(self, value): + if not ( + self.owner is not None and isinstance(self.owner.op, px.indexing.Index) + ): + raise ValueError( + f"inc can only be called on the output of an index (or isel) operation. 
Self is the result of {self.owner}" + ) + + x, *idxs = self.owner.inputs + return px.indexing.index_increment(x, value, *idxs) + def _head_tail_or_thin( self, indexers: dict[str, Any] | int | None, diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py index 721fd1e695..c7d8572bdc 100644 --- a/tests/xtensor/test_indexing.py +++ b/tests/xtensor/test_indexing.py @@ -6,7 +6,12 @@ from pytensor.tensor import tensor from pytensor.xtensor import xtensor -from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, + xr_random_like, +) @pytest.mark.parametrize( @@ -346,3 +351,138 @@ def test_boolean_indexing(): expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")] xr_assert_allclose(res1, expected_res1) xr_assert_allclose(res2, expected_res2) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +def test_basic_index_update(mode): + x = xtensor("x", shape=(11, 7), dims=("a", "b")) + y = xtensor("y", shape=(7, 5), dims=("a", "b")) + x_indexed = x[2:-2, 2:] + update_method = getattr(x_indexed, mode) + + x_updated = [ + update_method(y), + update_method(y.T), + update_method(y.isel(a=-1)), + update_method(y.isel(b=-1)), + update_method(y.isel(a=-2, b=-2)), + ] + + fn = xr_function([x, y], x_updated) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + results = fn(x_test, y_test) + + def update_fn(y): + x = x_test.copy() + if mode == "set": + x[2:-2, 2:] = y + elif mode == "inc": + x[2:-2, 2:] += y + return x + + expected_results = [ + update_fn(y_test), + update_fn(y_test.T), + update_fn(y_test.isel(a=-1)), + update_fn(y_test.isel(b=-1)), + update_fn(y_test.isel(a=-2, b=-2)), + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("mode", ("set", "inc")) +@pytest.mark.parametrize("idx_dtype", (int, bool)) +def test_adv_index_update(mode, idx_dtype): + x = xtensor("x", shape=(5, 5), dims=("a", "b")) + y = xtensor("y", shape=(3,), dims=("b",)) + idx = xtensor("idx", dtype=idx_dtype, shape=(None,), dims=("a",)) + + orthogonal_update1 = getattr(x[idx, -3:], mode)(y) + orthogonal_update2 = getattr(x[idx, -3:], mode)(y.rename(b="a")) + if idx_dtype is not bool: + # Vectorized booling indexing/update is not allowed + vectorized_update = getattr(x[idx.rename(a="b"), :3], mode)(y) + else: + with pytest.raises( + IndexError, + match="Boolean indexer should be unlabeled or on the same dimension to the indexed array.", + ): + getattr(x[idx.rename(a="b"), :3], mode)(y) + vectorized_update = x + + outs = [orthogonal_update1, orthogonal_update2, vectorized_update] + + fn = xr_function([x, idx, y], outs) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + if idx_dtype is int: + idx_test = DataArray([0, 1, 2], dims=("a",)) + else: + idx_test = DataArray([True, False, True, True, False], dims=("a",)) + results = fn(x_test, idx_test, y_test) + + def update_fn(x, idx, y): + x = x.copy() + if mode == "set": + x[idx] = y + else: + x[idx] += y + return x + + expected_results = [ + update_fn(x_test, (idx_test, slice(-3, None)), y_test), + update_fn( + x_test, + (idx_test, slice(-3, None)), + y_test.rename(b="a"), + ), + update_fn(x_test, (idx_test.rename(a="b"), slice(None, 3)), y_test) + if idx_dtype is not bool + else x_test, + ] + for result, expected_result in zip(results, expected_results): + xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("mode", 
("set", "inc")) +def test_non_consecutive_idx_update(mode): + x = xtensor("x", shape=(2, 3, 5, 7), dims=("a", "b", "c", "d")) + y = xtensor("y", shape=(5, 4), dims=("c", "b")) + x_indexed = x[:, [0, 1, 2, 2], :, ("b", [0, 1, 1, 2])] + out = getattr(x_indexed, mode)(y) + + fn = xr_function([x, y], out) + x_test = xr_random_like(x) + y_test = xr_random_like(y) + + result = fn(x_test, y_test) + expected_result = x_test.copy() + # xarray fails inplace operation with the "tuple trick" + # https://github.com/pydata/xarray/issues/10387 + d_indexer = DataArray([0, 1, 1, 2], dims=("b",)) + if mode == "set": + expected_result[:, [0, 1, 2, 2], :, d_indexer] = y_test + else: + expected_result[:, [0, 1, 2, 2], :, d_indexer] += y_test + xr_assert_allclose(result, expected_result) + + +def test_indexing_renames_into_update_variable(): + x = xtensor("x", shape=(5, 5), dims=("a", "b")) + y = xtensor("y", shape=(3,), dims=("d",)) + idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",)) + + # define "d" dimension by slicing the "a" dimension so we can set y into x + orthogonal_update1 = x[idx].set(y) + fn = xr_function([x, idx, y], orthogonal_update1) + + x_test = np.abs(xr_random_like(x)) + y_test = -np.abs(xr_random_like(y)) + idx_test = DataArray([0, 2, 3], dims=("d",)) + + result = fn(x_test, idx_test, y_test) + expected_result = x_test.copy() + expected_result[idx_test] = y_test + xr_assert_allclose(result, expected_result) From 9aa755ad89ec61ef17d7b60f8f58f32665b7e88a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 26 May 2025 17:38:05 +0200 Subject: [PATCH 12/13] Implement diff for XTensorVariables --- pytensor/xtensor/type.py | 9 +++++++++ tests/xtensor/test_indexing.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index bfe2ca63c4..bed71d19b3 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -529,6 +529,15 @@ def cumsum(self, dim): def cumprod(self, dim): return px.reduction.cumprod(self, dim) + def diff(self, dim, n=1): + """Compute the n-th discrete difference along the given dimension.""" + slice1 = {dim: slice(1, None)} + slice2 = {dim: slice(None, -1)} + x = self + for _ in range(n): + x = x[slice1] - x[slice2] + return x + # Reshaping and reorganizing # https://docs.xarray.dev/en/latest/api.html#id8 def transpose( diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py index c7d8572bdc..fdafd14220 100644 --- a/tests/xtensor/test_indexing.py +++ b/tests/xtensor/test_indexing.py @@ -486,3 +486,22 @@ def test_indexing_renames_into_update_variable(): expected_result = x_test.copy() expected_result[idx_test] = y_test xr_assert_allclose(result, expected_result) + + +@pytest.mark.parametrize("n", ["implicit", 1, 2]) +@pytest.mark.parametrize("dim", ["a", "b"]) +def test_diff(dim, n): + x = xtensor(dims=("a", "b"), shape=(7, 11)) + if n == "implicit": + out = x.diff(dim) + else: + out = x.diff(dim, n=n) + + fn = xr_function([x], out) + x_test = xr_arange_like(x) + res = fn(x_test) + if n == "implicit": + expected_res = x_test.diff(dim) + else: + expected_res = x_test.diff(dim, n=n) + xr_assert_allclose(res, expected_res) From d196fc136337d1884a8ebf0b0f98c792ad680084 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Fri, 6 Jun 2025 11:35:21 -0400 Subject: [PATCH 13/13] Adding xdot for labeled tensors Adding rewrite Lint --- pytensor/xtensor/math.py | 118 ++++++++++++++++++++++++- pytensor/xtensor/rewriting/__init__.py | 1 + pytensor/xtensor/rewriting/math.py | 40 
+++++++++ pytensor/xtensor/type.py | 4 + tests/xtensor/test_math.py | 31 +++++++ 5 files changed, 192 insertions(+), 2 deletions(-) create mode 100644 pytensor/xtensor/rewriting/math.py diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 0e9d66f232..d6af4ea378 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -5,9 +5,11 @@ import pytensor.scalar as ps from pytensor import config +from pytensor.graph.basic import Apply from pytensor.scalar import ScalarOp -from pytensor.scalar.basic import _cast_mapping -from pytensor.xtensor.basic import as_xtensor +from pytensor.scalar.basic import _cast_mapping, upcast +from pytensor.xtensor.basic import XOp, as_xtensor +from pytensor.xtensor.type import xtensor from pytensor.xtensor.vectorization import XElemwise @@ -57,3 +59,115 @@ def cast(x, dtype): if dtype not in _xelemwise_cast_op: _xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype]) return _xelemwise_cast_op[dtype](x) + + +class XDot(XOp): + """Matrix multiplication between two XTensorVariables. + + This operation performs matrix multiplication between two tensors, automatically + aligning and contracting dimensions. The behavior matches xarray's dot operation. + + Parameters + ---------- + dims : tuple of str + The dimensions to contract over. If None, will contract over all matching dimensions. + """ + + __props__ = ("dims",) + + def __init__(self, dims: tuple[str, ...] | None = None): + self.dims = dims + super().__init__() + + def make_node(self, x, y): + x = as_xtensor(x) + y = as_xtensor(y) + + # Get dimensions to contract + if self.dims is None: + # Contract over all matching dimensions + x_dims = set(x.type.dims) + y_dims = set(y.type.dims) + contract_dims = tuple(x_dims & y_dims) + else: + contract_dims = self.dims + + # Determine output dimensions and shapes + x_dims = list(x.type.dims) + y_dims = list(y.type.dims) + x_shape = list(x.type.shape) + y_shape = list(y.type.shape) + + # Remove contracted dimensions + for dim in contract_dims: + x_idx = x_dims.index(dim) + y_idx = y_dims.index(dim) + x_dims.pop(x_idx) + y_dims.pop(y_idx) + x_shape.pop(x_idx) + y_shape.pop(y_idx) + + # Combine remaining dimensions + out_dims = tuple(x_dims + y_dims) + out_shape = tuple(x_shape + y_shape) + + # Determine output dtype + out_dtype = upcast(x.type.dtype, y.type.dtype) + + out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, y], [out]) + + +def dot(x, y, dims: tuple[str, ...] | None = None): + """Matrix multiplication between two XTensorVariables. + + This operation performs matrix multiplication between two tensors, automatically + aligning and contracting dimensions. The behavior matches xarray's dot operation. + + Parameters + ---------- + x : XTensorVariable + First input tensor + y : XTensorVariable + Second input tensor + dims : tuple of str, optional + The dimensions to contract over. If None, will contract over all matching dimensions. + + Returns + ------- + XTensorVariable + The result of the matrix multiplication. 
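+        Its dimensions are the non-contracted dims of `x` followed by the non-contracted dims of `y`.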
+ + Examples + -------- + >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) + >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4)) + >>> z = dot(x, y) # Result has dimensions ("a", "c") + """ + x = as_xtensor(x) + y = as_xtensor(y) + + # Validate dimensions if specified + if dims is not None: + if not isinstance(dims, tuple): + dims = tuple(dims) + for dim in dims: + if dim not in x.type.dims: + raise ValueError( + f"Dimension {dim} not found in first input {x.type.dims}" + ) + if dim not in y.type.dims: + raise ValueError( + f"Dimension {dim} not found in second input {y.type.dims}" + ) + # Check for compatible shapes in contracted dimensions + x_idx = x.type.dims.index(dim) + y_idx = y.type.dims.index(dim) + x_size = x.type.shape[x_idx] + y_size = y.type.shape[y_idx] + if x_size is not None and y_size is not None and x_size != y_size: + raise ValueError( + f"Dimension {dim} has incompatible shapes: {x_size} and {y_size}" + ) + + return XDot(dims=dims)(x, y) diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index a65ad0db85..bdbb30f147 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,5 +1,6 @@ import pytensor.xtensor.rewriting.basic import pytensor.xtensor.rewriting.indexing +import pytensor.xtensor.rewriting.math import pytensor.xtensor.rewriting.reduction import pytensor.xtensor.rewriting.shape import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py new file mode 100644 index 0000000000..069078bd61 --- /dev/null +++ b/pytensor/xtensor/rewriting/math.py @@ -0,0 +1,40 @@ +from pytensor.graph import node_rewriter +from pytensor.tensor import tensordot +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.math import XDot +from pytensor.xtensor.rewriting.utils import register_xcanonicalize + + +@register_xcanonicalize +@node_rewriter(tracks=[XDot]) +def lower_dot(fgraph, node): + """Rewrite XDot to tensor.dot. + + This rewrite converts an XDot operation to a tensor-based dot operation, + handling dimension alignment and contraction. 
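+    The contraction itself is delegated to `tensordot`, over the axes that correspond to the
+    contracted dims of each input.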
+ """ + [x, y] = node.inputs + [out] = node.outputs + + # Convert inputs to tensors + x_tensor = tensor_from_xtensor(x) + y_tensor = tensor_from_xtensor(y) + + # Get dimensions to contract + if node.op.dims is None: + # Contract over all matching dimensions + x_dims = set(x.type.dims) + y_dims = set(y.type.dims) + contract_dims = tuple(x_dims & y_dims) + else: + contract_dims = node.op.dims + + # Get axes to contract for each input + x_axes = [x.type.dims.index(dim) for dim in contract_dims] + y_axes = [y.type.dims.index(dim) for dim in contract_dims] + + # Perform dot product + out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes)) + + # Convert back to xtensor + return [xtensor_from_tensor(out_tensor, out.type.dims)] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index bed71d19b3..158d6a7e0c 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -577,6 +577,10 @@ def stack(self, dim, **dims): def unstack(self, dim, **dims): return px.shape.unstack(self, dim, **dims) + def dot(self, other, dims=None): + """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" + return px.math.dot(self, other, dims=dims) + class XTensorConstantSignature(tuple): def __eq__(self, other): diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index f8a601f4a9..13d5a25a71 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -125,3 +125,34 @@ def test_cast(): yc64 = x.astype("complex64") with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"): yc64.astype("float64") + + +def test_dot(): + """Test basic dot product operations.""" + # Test matrix-matrix dot product + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + z = x.dot(y) + assert z.type.dims == ("a", "c") + assert z.type.shape == (2, 4) + + fn = xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones((3, 4)), dims=("b", "c")) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected) + + # Test matrix-vector dot product + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b",), shape=(3,)) + z = x.dot(y) + assert z.type.dims == ("a",) + assert z.type.shape == (2,) + + fn = xr_function([x, y], z) + x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) + y_test = DataArray(np.ones(3), dims=("b",)) + z_test = fn(x_test, y_test) + expected = x_test.dot(y_test) + xr_assert_allclose(z_test, expected)