Skip to content

Commit 18501bf

Browse files
committed
Adding shape checking at rewrite time
1 parent 7c9214d commit 18501bf

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

pytensor/xtensor/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from collections.abc import Hashable, Iterable
2+
from collections.abc import Iterable
33
from types import EllipsisType
44

55
import numpy as np
@@ -190,7 +190,7 @@ def make_node(self, x, y):
190190
return Apply(self, [x, y], [out])
191191

192192

193-
def dot(x, y, dims: str | Iterable[Hashable] | EllipsisType | None = None):
193+
def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
194194
"""Matrix multiplication between two XTensorVariables.
195195
196196
This operation performs matrix multiplication between two tensors, automatically

pytensor/xtensor/rewriting/math.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,21 @@ def lower_dot(fgraph, node):
2020
x_tensor = tensor_from_xtensor(x)
2121
y_tensor = tensor_from_xtensor(y)
2222

23-
# Get axes to contract for each input
23+
# Get the axes for contraction
2424
x_axes = [x.type.dims.index(dim) for dim in node.op.dims]
2525
y_axes = [y.type.dims.index(dim) for dim in node.op.dims]
2626

27-
# Perform dot product
27+
# Check that shapes match along contracted dimensions
28+
for dim in node.op.dims:
29+
x_idx = x.type.dims.index(dim)
30+
y_idx = y.type.dims.index(dim)
31+
if x.type.shape[x_idx] != y.type.shape[y_idx]:
32+
raise ValueError(
33+
"Input arrays have inconsistent type shape along the axes "
34+
f"that are to be reduced with tensordot: {x.type.shape[x_idx]} != {y.type.shape[y_idx]}"
35+
)
36+
37+
# Perform the tensordot operation
2838
out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))
2939

3040
# Sum over all remaining axes if needed

tests/xtensor/test_math.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,29 @@ def test_dot():
229229
z_test = fn(x_test, y_test)
230230
expected = x_test.dot(y_test, dim=...)
231231
xr_assert_allclose(z_test, expected)
232+
233+
234+
def test_dot_errors():
235+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
236+
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
237+
with pytest.raises(ValueError, match="Dimension c not found in first input"):
238+
x.dot(y, dims=["c"])
239+
with pytest.raises(ValueError, match="Dimension a not found in second input"):
240+
x.dot(y, dims=["a"])
241+
242+
# Test a case where there are no matching dimensions
243+
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
244+
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
245+
with pytest.raises(ValueError, match="cannot reindex or align along dimension"):
246+
x_test.dot(y_test)
247+
248+
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
249+
y = xtensor("y", dims=("b", "c"), shape=(4, 5))
250+
with pytest.raises(
251+
ValueError, match="Input arrays have inconsistent type shape along the axes"
252+
):
253+
z = x.dot(y)
254+
fn = function([x, y], z)
255+
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
256+
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
257+
fn(x_test, y_test)

0 commit comments

Comments
 (0)