Skip to content

Commit 7b6f864

Browse files
committed
Don't try to create invalid BatchedDot in specialize_matmul_to_batched_dot rewrite
1 parent 3c4bf02 commit 7b6f864

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

pytensor/tensor/rewriting/blas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
916916
"""
917917
x, y = node.inputs
918918

919+
if x.type.ndim < 3:
920+
# This doesn't actually have a batch dimension
921+
return None
922+
919923
# BatchedDot does not allow implicit broadcasting of the batch dimensions
920924
# We do not want to explicitly broadcast as it may result in huge arrays
921925
if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
926930
if len(x_shape) > 3:
927931
# If we have more than one batch dim, ravel it
928932
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
933+
if len(y_shape) > 3:
929934
y = y.reshape((-1, y_shape[-2], y_shape[-1]))
930935

931936
new_out = _batched_dot(x, y)

0 commit comments

Comments
 (0)