Skip to content

Commit e44ca9f

Browse files
committed
Now with more einsum
1 parent e81657e commit e44ca9f

File tree

3 files changed

+52
-28
lines changed

3 files changed

+52
-28
lines changed

pytensor/xtensor/math.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
231231
if d not in union:
232232
raise ValueError(f"Dimension {d} not found in either input {y.type.dims}")
233233

234-
dotted_dims = tuple(dim_set & intersection)
235-
summed_dims = tuple(dim_set.difference(dotted_dims))
236-
237-
result = XDot(dims=dotted_dims)(x, y)
238-
239-
if summed_dims:
240-
# Sum over all remaining axes
241-
result = result.sum(dim=summed_dims)
234+
result = XDot(dims=tuple(dim_set))(x, y)
242235

243236
return result

pytensor/xtensor/rewriting/math.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from string import ascii_lowercase
2+
13
from pytensor.graph import node_rewriter
2-
from pytensor.tensor import tensordot
4+
from pytensor.tensor import einsum
5+
from pytensor.tensor.shape import reshape
36
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
47
from pytensor.xtensor.math import XDot
58
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@@ -20,22 +23,29 @@ def lower_dot(fgraph, node):
2023
x_tensor = tensor_from_xtensor(x)
2124
y_tensor = tensor_from_xtensor(y)
2225

23-
# Get the axes for contraction
24-
x_axes = [x.type.dims.index(dim) for dim in node.op.dims]
25-
y_axes = [y.type.dims.index(dim) for dim in node.op.dims]
26+
# Collect all dimension names across inputs and output
27+
all_dims = list(
28+
dict.fromkeys(x.type.dims + y.type.dims + out.type.dims)
29+
) # preserve order
30+
if len(all_dims) > len(ascii_lowercase):
31+
raise ValueError("Too many dimensions to map to einsum subscripts")
32+
33+
dim_to_char = dict(zip(all_dims, ascii_lowercase))
34+
35+
# Build einsum string
36+
x_subs = "".join(dim_to_char[d] for d in x.type.dims)
37+
y_subs = "".join(dim_to_char[d] for d in y.type.dims)
38+
out_subs = "".join(dim_to_char[d] for d in out.type.dims)
39+
einsum_str = f"{x_subs},{y_subs}->{out_subs}"
2640

27-
# Check that shapes match along contracted dimensions
28-
for dim in node.op.dims:
29-
x_idx = x.type.dims.index(dim)
30-
y_idx = y.type.dims.index(dim)
31-
if x.type.shape[x_idx] != y.type.shape[y_idx]:
32-
raise ValueError(
33-
"Input arrays have inconsistent type shape along the axes "
34-
f"that are to be reduced with tensordot: {x.type.shape[x_idx]} != {y.type.shape[y_idx]}"
35-
)
41+
# Perform the einsum operation
42+
out_tensor = einsum(einsum_str, x_tensor, y_tensor)
3643

37-
# Perform the tensordot operation
38-
out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))
44+
# Reshape to match the expected output shape
45+
try:
46+
out_tensor = reshape(out_tensor, out.type.shape)
47+
except (TypeError, ValueError):
48+
# Skip reshaping if symbolic shapes are present
49+
pass
3950

40-
# Convert back to xtensor
4151
return [xtensor_from_tensor(out_tensor, out.type.dims)]

tests/xtensor/test_math.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,10 @@ def test_dot():
246246
z_test = fn(x_test, y_test)
247247
xr_assert_allclose(z_test, expected)
248248

249-
return
250249
# Dot product with sum in the middle
251250
# This case (contraction plus summing out extra dims) is now supported via the einsum lowering
252-
x_test = DataArray(np.arange(120).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
253-
y_test = DataArray(np.arange(360).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
251+
x_test = DataArray(np.arange(120.0).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
252+
y_test = DataArray(np.arange(360.0).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
254253
expected = x_test.dot(y_test, dim=("b", "d"))
255254
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
256255
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
@@ -259,6 +258,27 @@ def test_dot():
259258
z_test = fn(x_test, y_test)
260259
xr_assert_allclose(z_test, expected)
261260

261+
# Same but with first two dims
262+
expected = x_test.dot(y_test, dim=["a", "b"])
263+
z = x.dot(y, dim=["a", "b"])
264+
fn = xr_function([x, y], z)
265+
z_test = fn(x_test, y_test)
266+
xr_assert_allclose(z_test, expected)
267+
268+
# Same but with last two
269+
expected = x_test.dot(y_test, dim=["d", "e"])
270+
z = x.dot(y, dim=["d", "e"])
271+
fn = xr_function([x, y], z)
272+
z_test = fn(x_test, y_test)
273+
xr_assert_allclose(z_test, expected)
274+
275+
# Same but with every other dim
276+
expected = x_test.dot(y_test, dim=["a", "c", "e"])
277+
z = x.dot(y, dim=["a", "c", "e"])
278+
fn = xr_function([x, y], z)
279+
z_test = fn(x_test, y_test)
280+
xr_assert_allclose(z_test, expected)
281+
262282

263283
def test_dot_errors():
264284
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
@@ -273,7 +293,8 @@ def test_dot_errors():
273293
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
274294
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
275295
with pytest.raises(
276-
ValueError, match="Input arrays have inconsistent type shape along the axes"
296+
ValueError,
297+
match="Size of label 'b' for operand 1.*does not match previous terms",
277298
):
278299
z = x.dot(y)
279300
fn = function([x, y], z)

0 commit comments

Comments (0)