Skip to content

Commit ed6bbd6

Browse files
committed
Handle symbolic shapes
1 parent 6ce20d0 commit ed6bbd6

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

pytensor/xtensor/rewriting/math.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ def lower_dot(fgraph, node):
4141
# Perform the einsum operation
4242
out_tensor = einsum(einsum_str, x_tensor, y_tensor)
4343

44-
# Reshape to match the expected output shape
45-
try:
44+
# Check if we have symbolic shapes
45+
sym_shape = any(not isinstance(s, int) for s in out.type.shape)
46+
47+
# If we have concrete shapes, reshape to match them
48+
if not sym_shape:
4649
out_tensor = reshape(out_tensor, out.type.shape)
47-
except (TypeError, ValueError):
48-
# Skip reshaping if symbolic shapes are present
49-
pass
5050

5151
return [xtensor_from_tensor(out_tensor, out.type.dims)]

tests/xtensor/test_math.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,17 @@ def test_dot():
277277
z_test = fn(x_test, y_test)
278278
xr_assert_allclose(z_test, expected)
279279

280+
# Test symbolic shapes
281+
x = xtensor("x", dims=("a", "b"), shape=(None, 3)) # First dimension is symbolic
282+
y = xtensor("y", dims=("b", "c"), shape=(3, None)) # Second dimension is symbolic
283+
z = x.dot(y)
284+
fn = xr_function([x, y], z)
285+
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
286+
y_test = DataArray(np.ones((3, 4)), dims=("b", "c"))
287+
z_test = fn(x_test, y_test)
288+
expected = x_test.dot(y_test)
289+
xr_assert_allclose(z_test, expected)
290+
280291

281292
def test_dot_errors():
282293
x = xtensor("x", dims=("a", "b"), shape=(2, 3))

0 commit comments

Comments
 (0)