From 7bbf7225e684011901ae61196353c4a66813a014 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 1 Jul 2025 16:21:17 +0200
Subject: [PATCH 1/3] Refactor `lower_aligned` helper

---
 pytensor/xtensor/rewriting/shape.py         | 11 +-----
 pytensor/xtensor/rewriting/utils.py         | 12 ++++++
 pytensor/xtensor/rewriting/vectorization.py | 44 ++++++---------------
 3 files changed, 25 insertions(+), 42 deletions(-)

diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py
index c0b1a5fe88..a80cca8b96 100644
--- a/pytensor/xtensor/rewriting/shape.py
+++ b/pytensor/xtensor/rewriting/shape.py
@@ -9,6 +9,7 @@
 )
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
+from pytensor.xtensor.rewriting.utils import lower_aligned
 from pytensor.xtensor.shape import (
     Concat,
     ExpandDims,
@@ -70,15 +71,7 @@ def lower_concat(fgraph, node):
     concat_axis = out_dims.index(concat_dim)
 
     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
 
     # Broadcast non-concatenated dimensions of each input
     non_concat_shape = [None] * len(out_dims)
diff --git a/pytensor/xtensor/rewriting/utils.py b/pytensor/xtensor/rewriting/utils.py
index f21747c2e6..43c60df370 100644
--- a/pytensor/xtensor/rewriting/utils.py
+++ b/pytensor/xtensor/rewriting/utils.py
@@ -1,7 +1,12 @@
+import typing
+from collections.abc import Sequence
+
 from pytensor.compile import optdb
 from pytensor.graph.rewriting.basic import NodeRewriter, in2out
 from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
 from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
+from pytensor.tensor.variable import TensorVariable
+from pytensor.xtensor.type import XTensorVariable
 
 
 lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -49,3 +54,10 @@ def register(inner_rewriter: RewriteDatabase | NodeRewriter):
             **kwargs,
         )
         return node_rewriter
+
+
+def lower_aligned(x: XTensorVariable, out_dims: Sequence[str]) -> TensorVariable:
+    """Lower an XTensorVariable to a TensorVariable so that its dimensions are aligned with "out_dims"."""
+    inp_dims = {d: i for i, d in enumerate(x.type.dims)}
+    ds_order = tuple(inp_dims.get(dim, "x") for dim in out_dims)
+    return typing.cast(TensorVariable, x.values.dimshuffle(ds_order))
diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py
index bed7da564b..2450d09358 100644
--- a/pytensor/xtensor/rewriting/vectorization.py
+++ b/pytensor/xtensor/rewriting/vectorization.py
@@ -2,8 +2,8 @@
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import compute_batch_shape
-from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.rewriting.utils import register_lower_xtensor
+from pytensor.xtensor.basic import xtensor_from_tensor
+from pytensor.xtensor.rewriting.utils import lower_aligned, register_lower_xtensor
 from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
 
 
@@ -13,15 +13,7 @@ def lower_elemwise(fgraph, node):
     out_dims = node.outputs[0].type.dims
 
     # Convert input XTensors to Tensors and align batch dimensions
-    tensor_inputs = []
-    for inp in node.inputs:
-        inp_dims = inp.type.dims
-        order = [
-            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
-            for out_dim in out_dims
-        ]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
     tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
         *tensor_inputs, return_list=True
     )
@@ -42,17 +34,10 @@ def lower_blockwise(fgraph, node):
     batch_ndim = op.batch_ndim(node)
     batch_dims = node.outputs[0].type.dims[:batch_ndim]
 
     # Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end
-    tensor_inputs = []
-    for inp, core_dims in zip(node.inputs, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_inputs.append(tensor_inp)
+    tensor_inputs = [
+        lower_aligned(inp, batch_dims + core_dims)
+        for inp, core_dims in zip(node.inputs, op.core_dims[0], strict=True)
+    ]
     signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
     if signature is None:
@@ -92,17 +77,10 @@ def lower_rv(fgraph, node):
     param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]
 
     # Convert params Tensors to XTensors, align batch dimensions and place core dimension at the end
-    tensor_params = []
-    for inp, core_dims in zip(params, op.core_dims[0]):
-        inp_dims = inp.type.dims
-        # Align the batch dims of the input, and place the core dims on the right
-        batch_order = [
-            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
-            for batch_dim in param_batch_dims
-        ]
-        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
-        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
-        tensor_params.append(tensor_inp)
+    tensor_params = [
+        lower_aligned(inp, param_batch_dims + core_dims)
+        for inp, core_dims in zip(params, op.core_dims[0], strict=True)
+    ]
 
     size = None
     if op.extra_dims:

From d0f2d0d43474eec62bde1c7f2ea735629eb7ed6a Mon Sep 17 00:00:00 2001
From: Allen Downey
Date: Sun, 29 Jun 2025 20:42:32 +0200
Subject: [PATCH 2/3] Implement broadcast for XTensorVariables

Co-authored-by: Ricardo
---
 pytensor/xtensor/__init__.py        |   2 +-
 pytensor/xtensor/rewriting/shape.py |  60 ++++++++++
 pytensor/xtensor/shape.py           |  63 ++++++++++-
 pytensor/xtensor/type.py            |   9 ++
 pytensor/xtensor/vectorization.py   |  14 ++-
 tests/xtensor/test_shape.py         | 167 ++++++++++++++++++++++++++++
 6 files changed, 311 insertions(+), 4 deletions(-)

diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py
index 7f1b9ecddb..c9c19be54d 100644
--- a/pytensor/xtensor/__init__.py
+++ b/pytensor/xtensor/__init__.py
@@ -3,7 +3,7 @@
 import pytensor.xtensor.rewriting
 from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
-from pytensor.xtensor.shape import concat
+from pytensor.xtensor.shape import broadcast, concat
 from pytensor.xtensor.type import (
     as_xtensor,
     xtensor,
diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py
index a80cca8b96..9f6238ae40 100644
--- a/pytensor/xtensor/rewriting/shape.py
+++ b/pytensor/xtensor/rewriting/shape.py
@@ -1,3 +1,4 @@
+import pytensor.tensor as pt
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
@@ -11,6 +12,7 @@
 from pytensor.xtensor.rewriting.basic import register_lower_xtensor
 from pytensor.xtensor.rewriting.utils import lower_aligned
 from pytensor.xtensor.shape import (
+    Broadcast,
     Concat,
     ExpandDims,
     Squeeze,
@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node):
     # Convert result back to xtensor
     result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
     return [result]
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[Broadcast])
+def lower_broadcast(fgraph, node):
+    """Rewrite XBroadcast using tensor operations."""
+
+    excluded_dims = node.op.exclude
+
+    tensor_inputs = [
+        lower_aligned(inp, out.type.dims)
+        for inp, out in zip(node.inputs, node.outputs, strict=True)
+    ]
+
+    if not excluded_dims:
+        # Simple case: All dimensions are broadcasted
+        tensor_outputs = pt.broadcast_arrays(*tensor_inputs)
+
+    else:
+        # Complex case: Some dimensions are excluded from broadcasting
+        # Pick the first dimension_length for each dim
+        broadcast_dims = {
+            d: None for d in node.outputs[0].type.dims if d not in excluded_dims
+        }
+        for xtensor_inp in node.inputs:
+            for dim, dim_length in xtensor_inp.sizes.items():
+                if dim in broadcast_dims and broadcast_dims[dim] is None:
+                    # If the dimension is not excluded, set its shape
+                    broadcast_dims[dim] = dim_length
+        assert not any(
+            value is None for value in broadcast_dims.values()
+        ), "All dimensions must have a length"
+
+        # Create zeros with the broadcast dimensions, to then broadcast each input against
+        # PyTensor will rewrite into using only the shapes of the zeros tensor
+        broadcast_dims = pt.zeros(
+            tuple(broadcast_dims.values()),
+            dtype=node.outputs[0].type.dtype,
+        )
+        n_broadcast_dims = broadcast_dims.ndim
+
+        tensor_outputs = []
+        for tensor_inp, xtensor_out in zip(tensor_inputs, node.outputs, strict=True):
+            n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
+            # Excluded dimensions are on the right side of the output tensor so we padright the broadcast_dims
+            # second is equivalent to `np.broadcast_arrays(x, y)[1]` in PyTensor
+            tensor_outputs.append(
+                pt.second(
+                    pt.shape_padright(broadcast_dims, n_excluded_dims),
+                    tensor_inp,
+                )
+            )
+
+    new_outs = [
+        xtensor_from_tensor(out_tensor, dims=out.type.dims)
+        for out_tensor, out in zip(tensor_outputs, node.outputs)
+    ]
+    return new_outs
diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py
index bb528fcf26..bbe021a7df 100644
--- a/pytensor/xtensor/shape.py
+++ b/pytensor/xtensor/shape.py
@@ -13,7 +13,8 @@
 from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.utils import get_static_shape_from_size_variables
 from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
+from pytensor.xtensor.vectorization import combine_dims_and_shape
 
 
 class Stack(XOp):
@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
         x = Transpose(dims=tuple(target_dims))(x)
 
     return x
+
+
+class Broadcast(XOp):
+    """Broadcast multiple XTensorVariables against each other."""
+
+    __props__ = ("exclude",)
+
+    def __init__(self, exclude: Sequence[str] = ()):
+        self.exclude = tuple(exclude)
+
+    def make_node(self, *inputs):
+        inputs = [as_xtensor(x) for x in inputs]
+
+        exclude = self.exclude
+        dims_and_shape = combine_dims_and_shape(inputs, exclude=exclude)
+
+        broadcast_dims = tuple(dims_and_shape.keys())
+        broadcast_shape = tuple(dims_and_shape.values())
+        dtype = upcast(*[x.type.dtype for x in inputs])
+
+        outputs = []
+        for x in inputs:
+            x_dims = x.type.dims
+            x_shape = x.type.shape
+            # The output has excluded dimensions in the order they appear in the op argument
+            excluded_dims = tuple(d for d in exclude if d in x_dims)
+            excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims)
+
+            output = xtensor(
+                dtype=dtype,
+                shape=broadcast_shape + excluded_shape,
+                dims=broadcast_dims + excluded_dims,
+            )
+            outputs.append(output)
+
+        return Apply(self, inputs, outputs)
+
+
+def broadcast(
+    *args, exclude: str | Sequence[str] | None = None
+) -> tuple[XTensorVariable, ...]:
+    """Broadcast any number of XTensorVariables against each other.
+
+    Parameters
+    ----------
+    *args : XTensorVariable
+        The tensors to broadcast against each other.
+    exclude : str or Sequence[str] or None, optional
+    """
+    if not args:
+        return ()
+
+    if exclude is None:
+        exclude = ()
+    elif isinstance(exclude, str):
+        exclude = (exclude,)
+    elif not isinstance(exclude, Sequence):
+        raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")
+    # xarray broadcast always returns a tuple, even if there's only one tensor
+    return tuple(Broadcast(exclude=exclude)(*args, return_list=True))  # type: ignore
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index c5f345e45a..0c8ca0914e 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -736,6 +736,15 @@ def dot(self, other, dim=None):
         """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
         return px.math.dot(self, other, dim=dim)
 
+    def broadcast(self, *others, exclude=None):
+        """Broadcast this tensor against other XTensorVariables."""
+        return px.shape.broadcast(self, *others, exclude=exclude)
+
+    def broadcast_like(self, other, exclude=None):
+        """Broadcast this tensor against another XTensorVariable."""
+        _, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
+        return self_bcast
+
 
 class XTensorConstantSignature(TensorConstantSignature):
     pass
diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py
index 8243e78170..a6cbb2b5c3 100644
--- a/pytensor/xtensor/vectorization.py
+++ b/pytensor/xtensor/vectorization.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from itertools import chain
 
 import numpy as np
@@ -13,13 +14,22 @@
     get_static_shape_from_size_variables,
 )
 from pytensor.xtensor.basic import XOp
-from pytensor.xtensor.type import as_xtensor, xtensor
+from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
 
 
-def combine_dims_and_shape(inputs):
+def combine_dims_and_shape(
+    inputs: Sequence[XTensorVariable], exclude: Sequence[str] | None = None
+) -> dict[str, int | None]:
+    """Combine information of static dimensions and shapes from multiple xtensor inputs.
+
+    Dims listed in `exclude` are ignored.
+    """
+    exclude_set: set[str] = set() if exclude is None else set(exclude)
     dims_and_shape: dict[str, int | None] = {}
     for inp in inputs:
         for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+            if dim in exclude_set:
+                continue
             if dim not in dims_and_shape:
                 dims_and_shape[dim] = dim_length
             elif dim_length is not None:
diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py
index 6abd7b5103..a397d07552 100644
--- a/tests/xtensor/test_shape.py
+++ b/tests/xtensor/test_shape.py
@@ -9,10 +9,12 @@
 
 import numpy as np
 from xarray import DataArray
+from xarray import broadcast as xr_broadcast
 from xarray import concat as xr_concat
 
 from pytensor.tensor import scalar
 from pytensor.xtensor.shape import (
+    broadcast,
     concat,
     stack,
     unstack,
@@ -466,3 +468,168 @@ def test_expand_dims_errors():
     # Test with a numpy array as dim (not supported)
     with pytest.raises(TypeError, match="unhashable type"):
         y.expand_dims(np.array([1, 2]))
+
+
+class TestBroadcast:
+    @pytest.mark.parametrize(
+        "exclude",
+        [
+            None,
+            [],
+            ["b"],
+            ["b", "d"],
+            ["a", "d"],
+            ["b", "c", "d"],
+            ["a", "b", "c", "d"],
+        ],
+    )
+    def test_compatible_excluded_shapes(self, exclude):
+        # Create test data
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+
+        # Test with excluded dims
+        x2_expected, y2_expected, z2_expected = xr_broadcast(
+            x_test, y_test, z_test, exclude=exclude
+        )
+        x2, y2, z2 = broadcast(x, y, z, exclude=exclude)
+        fn = xr_function([x, y, z], [x2, y2, z2])
+        x2_result, y2_result, z2_result = fn(x_test, y_test, z_test)
+
+        xr_assert_allclose(x2_result, x2_expected)
+        xr_assert_allclose(y2_result, y2_expected)
+        xr_assert_allclose(z2_result, z2_expected)
+
+    def test_incompatible_excluded_shapes(self):
+        # Test that excluded dims are allowed to be different sizes
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
+        out = broadcast(x, y, z, exclude=["d"])
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+        fn = xr_function([x, y, z], out)
+        results = fn(x_test, y_test, z_test)
+        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"])
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]])
+    def test_runtime_shapes(self, exclude):
+        x = xtensor("x", dims=("a", "b"), shape=(None, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, None))
+        z = xtensor("z", dims=("b", "d"), shape=(None, None))
+        out = broadcast(x, y, z, exclude=exclude)
+
+        x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
+        y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
+        z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))
+        fn = xr_function([x, y, z], out)
+        results = fn(x_test, y_test, z_test)
+        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude)
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+        # Test invalid shape raises an error
+        # Note: We might decide not to raise an error in the lowered graphs for performance reasons
+        if "d" not in exclude:
+            z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7)))
+            with pytest.raises(Exception):
+                fn(x_test, y_test, z_test_bad)
+
+    def test_broadcast_excluded_dims_in_different_order(self):
+        """Test that broadcast excluded dims are aligned with user input."""
+        x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5))
+        y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4))
+        out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"])
+        assert out_x.type.dims == ("a", "c", "b")
+        assert out_y.type.dims == ("a", "c", "b")
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        fn = xr_function([x, y], out)
+        results = fn(x_test, y_test)
+        expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"])
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    def test_broadcast_errors(self):
+        """Test error handling in broadcast."""
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
+            broadcast(x, y, z, exclude=1)
+
+        # Test with conflicting shapes
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
+
+        with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
+            broadcast(x, y, z)
+
+    def test_broadcast_no_input(self):
+        assert broadcast() == xr_broadcast()
+        assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",))
+
+    def test_broadcast_single_input(self):
+        """Test broadcasting a single input."""
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        # Broadcast with a single input can still imply a transpose via the exclude parameter
+        outs = [
+            *broadcast(x),
+            *broadcast(x, exclude=("a", "b")),
+            *broadcast(x, exclude=("b", "a")),
+            *broadcast(x, exclude=("b",)),
+        ]
+
+        fn = xr_function([x], outs)
+        x_test = xr_arange_like(x)
+        results = fn(x_test)
+        expected_results = [
+            *xr_broadcast(x_test),
+            *xr_broadcast(x_test, exclude=("a", "b")),
+            *xr_broadcast(x_test, exclude=("b", "a")),
+            *xr_broadcast(x_test, exclude=("b",)),
+        ]
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)
+
+    @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]])
+    def test_broadcast_like(self, exclude):
+        """Test broadcast_like method"""
+        # Create test data
+        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
+        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
+        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
+
+        # Order matters so we test both orders
+        outs = [
+            x.broadcast_like(y, exclude=exclude),
+            y.broadcast_like(x, exclude=exclude),
+            y.broadcast_like(z, exclude=exclude),
+            z.broadcast_like(y, exclude=exclude),
+        ]
+
+        x_test = xr_arange_like(x)
+        y_test = xr_arange_like(y)
+        z_test = xr_arange_like(z)
+        fn = xr_function([x, y, z], outs)
+        results = fn(x_test, y_test, z_test)
+        expected_results = [
+            x_test.broadcast_like(y_test, exclude=exclude),
+            y_test.broadcast_like(x_test, exclude=exclude),
+            y_test.broadcast_like(z_test, exclude=exclude),
+            z_test.broadcast_like(y_test, exclude=exclude),
+        ]
+        for res, expected_res in zip(results, expected_results, strict=True):
+            xr_assert_allclose(res, expected_res)

From 5b49005c6c840f065a76f609b9949efe44c31bf8 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 2 Jul 2025 17:13:35 +0200
Subject: [PATCH 3/3] Rename XDot Op to Dot

---
 pytensor/xtensor/math.py           | 4 ++--
 pytensor/xtensor/rewriting/math.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py
index ad6f22bf51..687d7220d7 100644
--- a/pytensor/xtensor/math.py
+++ b/pytensor/xtensor/math.py
@@ -145,7 +145,7 @@ def softmax(x, dim=None):
     return exp_x / exp_x.sum(dim=dim)
 
 
-class XDot(XOp):
+class Dot(XOp):
     """Matrix multiplication between two XTensorVariables.
 
     This operation performs matrix multiplication between two tensors, automatically
@@ -247,6 +247,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
         if d not in union:
             raise ValueError(f"Dimension {d} not found in either input")
 
-    result = XDot(dims=tuple(dim_set))(x, y)
+    result = Dot(dims=tuple(dim_set))(x, y)
 
     return result
diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py
index 850d91fad3..c767ec490e 100644
--- a/pytensor/xtensor/rewriting/math.py
+++ b/pytensor/xtensor/rewriting/math.py
@@ -4,12 +4,12 @@
 from pytensor.tensor import einsum
 from pytensor.tensor.shape import specify_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.math import XDot
+from pytensor.xtensor.math import Dot
 from pytensor.xtensor.rewriting.utils import register_lower_xtensor
 
 
 @register_lower_xtensor
-@node_rewriter(tracks=[XDot])
+@node_rewriter(tracks=[Dot])
 def lower_dot(fgraph, node):
     """Rewrite XDot to tensor.dot.