Skip to content
Merged
9 changes: 1 addition & 8 deletions pytensor/xtensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
if d not in union:
raise ValueError(f"Dimension {d} not found in either input {y.type.dims}")

dotted_dims = tuple(dim_set & intersection)
summed_dims = tuple(dim_set.difference(dotted_dims))

result = XDot(dims=dotted_dims)(x, y)

if summed_dims:
# Sum over all remaining axes
result = result.sum(dim=summed_dims)
result = XDot(dims=tuple(dim_set))(x, y)

return result
42 changes: 26 additions & 16 deletions pytensor/xtensor/rewriting/math.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from string import ascii_lowercase

from pytensor.graph import node_rewriter
from pytensor.tensor import tensordot
from pytensor.tensor import einsum
from pytensor.tensor.shape import reshape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
Expand All @@ -20,22 +23,29 @@ def lower_dot(fgraph, node):
x_tensor = tensor_from_xtensor(x)
y_tensor = tensor_from_xtensor(y)

# Get the axes for contraction
x_axes = [x.type.dims.index(dim) for dim in node.op.dims]
y_axes = [y.type.dims.index(dim) for dim in node.op.dims]
# Collect all dimension names across inputs and output
all_dims = list(
dict.fromkeys(x.type.dims + y.type.dims + out.type.dims)
) # preserve order
if len(all_dims) > len(ascii_lowercase):
raise ValueError("Too many dimensions to map to einsum subscripts")

dim_to_char = dict(zip(all_dims, ascii_lowercase))

# Build einsum string
x_subs = "".join(dim_to_char[d] for d in x.type.dims)
y_subs = "".join(dim_to_char[d] for d in y.type.dims)
out_subs = "".join(dim_to_char[d] for d in out.type.dims)
einsum_str = f"{x_subs},{y_subs}->{out_subs}"

# Check that shapes match along contracted dimensions
for dim in node.op.dims:
x_idx = x.type.dims.index(dim)
y_idx = y.type.dims.index(dim)
if x.type.shape[x_idx] != y.type.shape[y_idx]:
raise ValueError(
"Input arrays have inconsistent type shape along the axes "
f"that are to be reduced with tensordot: {x.type.shape[x_idx]} != {y.type.shape[y_idx]}"
)
# Perform the einsum operation
out_tensor = einsum(einsum_str, x_tensor, y_tensor)

# Perform the tensordot operation
out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))
# Reshape to match the expected output shape
try:
out_tensor = reshape(out_tensor, out.type.shape)
except (TypeError, ValueError):
# Skip reshaping if symbolic shapes are present
pass

# Convert back to xtensor
return [xtensor_from_tensor(out_tensor, out.type.dims)]
29 changes: 25 additions & 4 deletions tests/xtensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,10 @@ def test_dot():
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

return
# Dot product with sum in the middle
# This is not supported yet
x_test = DataArray(np.arange(120).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
y_test = DataArray(np.arange(360).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
expected = x_test.dot(y_test, dim=("b", "d"))
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
Expand All @@ -259,6 +258,27 @@ def test_dot():
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with the first two dims
expected = x_test.dot(y_test, dim=["a", "b"])
z = x.dot(y, dim=["a", "b"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with the last two dims
expected = x_test.dot(y_test, dim=["d", "e"])
z = x.dot(y, dim=["d", "e"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)

# Same but with every other dim
expected = x_test.dot(y_test, dim=["a", "c", "e"])
z = x.dot(y, dim=["a", "c", "e"])
fn = xr_function([x, y], z)
z_test = fn(x_test, y_test)
xr_assert_allclose(z_test, expected)


def test_dot_errors():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
Expand All @@ -273,7 +293,8 @@ def test_dot_errors():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
with pytest.raises(
ValueError, match="Input arrays have inconsistent type shape along the axes"
ValueError,
match="Size of label 'b' for operand 1.*does not match previous terms",
):
z = x.dot(y)
fn = function([x, y], z)
Expand Down