Merged
121 changes: 119 additions & 2 deletions pytensor/xtensor/math.py
@@ -1,12 +1,16 @@
import sys
from collections.abc import Iterable
from types import EllipsisType

import numpy as np

import pytensor.scalar as ps
from pytensor import config
from pytensor.graph.basic import Apply
from pytensor.scalar import ScalarOp
from pytensor.scalar.basic import _cast_mapping
from pytensor.xtensor.basic import as_xtensor
from pytensor.scalar.basic import _cast_mapping, upcast
from pytensor.xtensor.basic import XOp, as_xtensor
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import XElemwise


@@ -134,3 +138,116 @@ def cast(x, dtype):
if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)


class XDot(XOp):
"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.

Parameters
----------
dims : tuple of str
    The dimensions to contract over. ``None``/``...`` resolution is handled by the ``dot`` helper below, which always passes an explicit tuple here.
"""

__props__ = ("dims",)

def __init__(self, dims: Iterable[str]):
self.dims = dims
super().__init__()

def make_node(self, x, y):
x = as_xtensor(x)
y = as_xtensor(y)

# Filter out contracted dimensions
x_dims = [dim for dim in x.type.dims if dim not in self.dims]
y_dims = [dim for dim in y.type.dims if dim not in self.dims]
x_shape = [
size for dim, size in zip(x.type.dims, x.type.shape) if dim not in self.dims
]
y_shape = [
size for dim, size in zip(y.type.dims, y.type.shape) if dim not in self.dims
]

# Combine remaining dimensions
out_dims = tuple(x_dims + y_dims)
out_shape = tuple(x_shape + y_shape)

# Determine output dtype
out_dtype = upcast(x.type.dtype, y.type.dtype)

out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, y], [out])

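To make the dim/shape bookkeeping above concrete, here is a minimal standalone sketch (plain Python with illustrative values, not part of the diff) of the inference ``make_node`` performs when contracting over ``("b",)``:

# Illustrative only: mirrors XDot.make_node's dim/shape inference.
x_dims, x_shape = ("a", "b"), (2, 3)
y_dims, y_shape = ("b", "c"), (3, 4)
contracted = ("b",)

# Drop contracted dims from each input, then concatenate what remains.
out_dims = tuple(d for d in x_dims if d not in contracted) + tuple(
    d for d in y_dims if d not in contracted
)
out_shape = tuple(
    s for d, s in zip(x_dims, x_shape) if d not in contracted
) + tuple(s for d, s in zip(y_dims, y_shape) if d not in contracted)

assert out_dims == ("a", "c")
assert out_shape == (2, 4)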

def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
"""Matrix multiplication between two XTensorVariables.

This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.

Parameters
----------
x : XTensorVariable
First input tensor
y : XTensorVariable
Second input tensor
dims : str, Iterable[str], EllipsisType, or None, optional
The dimensions to contract over. If None, will contract over all matching dimensions.
If Ellipsis (...), will contract over all dimensions.

Returns
-------
XTensorVariable
The result of the matrix multiplication.

Examples
--------
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
>>> z = dot(x, y) # Result has dimensions ("a", "c")
>>> z = dot(x, y, dims=...)  # Contract over all dimensions
"""
x = as_xtensor(x)
y = as_xtensor(y)

# Canonicalize dims
if isinstance(dims, str):
dims = (dims,)
elif isinstance(dims, Iterable):
dims = tuple(dims)

# Validate provided dims
if isinstance(dims, Iterable):
for dim in dims:
if dim not in x.type.dims:
raise ValueError(
f"Dimension {dim} not found in first input {x.type.dims}"
)
if dim not in y.type.dims:
raise ValueError(
f"Dimension {dim} not found in second input {y.type.dims}"
)

# If dims is ..., we also have to sum over all remaining axes afterwards
sum_result = dims is ...

# Handle None and ... cases
if dims is None or dims is ...:
# Contract over all matching dimensions
x_dims = set(x.type.dims)
y_dims = set(y.type.dims)
dims = tuple(x_dims & y_dims)

result = XDot(dims=dims)(x, y)

if sum_result:
from pytensor.xtensor.reduction import sum as xtensor_sum

# Sum over all remaining axes
result = xtensor_sum(result, dim=...)

return result
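A short usage sketch of the ``dims=...`` path (hedged: plain numpy standing in for the lowered graph, values illustrative): ``dot`` first contracts the shared dimension, and the trailing ``xtensor_sum(result, dim=...)`` then collapses every remaining dimension to a scalar.

import numpy as np

x = np.arange(6.0).reshape(2, 3)   # stands in for an xtensor with dims ("a", "b")
y = np.arange(12.0).reshape(3, 4)  # stands in for an xtensor with dims ("b", "c")

partial = np.tensordot(x, y, axes=([1], [0]))  # contract "b" -> dims ("a", "c")
full = partial.sum()                           # the dim=... reduction step

assert np.isclose(full, np.einsum("ab,bc->", x, y))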
1 change: 1 addition & 0 deletions pytensor/xtensor/rewriting/__init__.py
@@ -1,5 +1,6 @@
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.indexing
import pytensor.xtensor.rewriting.math
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
41 changes: 41 additions & 0 deletions pytensor/xtensor/rewriting/math.py
@@ -0,0 +1,41 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import tensordot
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor


@register_lower_xtensor
@node_rewriter(tracks=[XDot])
def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot.

This rewrite converts an XDot operation to a tensor-based dot operation,
handling dimension alignment and contraction.
"""
[x, y] = node.inputs
[out] = node.outputs

# Convert inputs to tensors
x_tensor = tensor_from_xtensor(x)
y_tensor = tensor_from_xtensor(y)

# Get the axes for contraction
x_axes = [x.type.dims.index(dim) for dim in node.op.dims]
y_axes = [y.type.dims.index(dim) for dim in node.op.dims]

# Check that shapes match along contracted dimensions
for dim in node.op.dims:
x_idx = x.type.dims.index(dim)
y_idx = y.type.dims.index(dim)
if x.type.shape[x_idx] != y.type.shape[y_idx]:
raise ValueError(
"Input arrays have inconsistent type shape along the axes "
f"that are to be reduced with tensordot: {x.type.shape[x_idx]} != {y.type.shape[y_idx]}"
)

# Perform the tensordot operation
out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))

# Convert back to xtensor
return [xtensor_from_tensor(out_tensor, out.type.dims)]
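As a sanity check on the axes mapping, the pairing computed above is exactly what ``np.tensordot`` expects; a hedged numpy sketch using the shapes from the tests below:

import numpy as np

# Mirrors lower_dot's axis lookup for dims=("b", "c"); illustrative only.
x_dims, y_dims = ("a", "b", "c"), ("b", "c", "d")
contract = ("b", "c")
x_axes = [x_dims.index(d) for d in contract]  # [1, 2]
y_axes = [y_dims.index(d) for d in contract]  # [0, 1]

x = np.arange(24.0).reshape(2, 3, 4)
y = np.arange(60.0).reshape(3, 4, 5)
out = np.tensordot(x, y, axes=(x_axes, y_axes))  # dims ("a", "d"), shape (2, 5)
assert np.allclose(out, np.einsum("abc,bcd->ad", x, y))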
4 changes: 4 additions & 0 deletions pytensor/xtensor/type.py
@@ -650,6 +650,10 @@ def stack(self, dim, **dims):
def unstack(self, dim, **dims):
return px.shape.unstack(self, dim, **dims)

def dot(self, other, dims=None):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return px.math.dot(self, other, dims=dims)


class XTensorConstantSignature(tuple):
def __eq__(self, other):
108 changes: 108 additions & 0 deletions tests/xtensor/test_math.py
@@ -151,3 +151,111 @@ def test_cast():
yc64 = x.astype("complex64")
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
yc64.astype("float64")


def test_dot():
"""Test basic dot product operations."""
# Test matrix-vector dot product
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b",), shape=(3,))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones(3), dims=("b",))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-vector dot product with ellipsis
z = x.dot(y, dims=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.add.outer(np.arange(2.0), np.arange(3.0)), dims=("a", "b"))
y_test = DataArray(np.add.outer(np.arange(3.0), np.arange(4.0)), dims=("b", "c"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with string dims
z = x.dot(y, dims="b")
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim="b")
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with list of dims
z = x.dot(y, dims=["b"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=["b"])
xr_assert_allclose(z_test, expected)

# Test matrix-matrix dot product with ellipsis
z = x.dot(y, dims=...)
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)

# Test a case where there are two dimensions to sum over
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
z = x.dot(y)
fn = xr_function([x, y], z)

x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test)
xr_assert_allclose(z_test, expected)

# Same but with explicit dimensions
z = x.dot(y, dims=["b", "c"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=["b", "c"])
xr_assert_allclose(z_test, expected)

# Same but with ellipsis
z = x.dot(y, dims=...)
fn = xr_function([x, y], z)

z_test = fn(x_test, y_test)
expected = x_test.dot(y_test, dim=...)
xr_assert_allclose(z_test, expected)


def test_dot_errors():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
with pytest.raises(ValueError, match="Dimension c not found in first input"):
x.dot(y, dims=["c"])
with pytest.raises(ValueError, match="Dimension a not found in second input"):
x.dot(y, dims=["a"])

# Test a case where there are no matching dimensions
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
with pytest.raises(ValueError, match="cannot reindex or align along dimension"):
x_test.dot(y_test)

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
with pytest.raises(
ValueError, match="Input arrays have inconsistent type shape along the axes"
):
z = x.dot(y)
fn = function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
fn(x_test, y_test)