1 file changed in pytensor/tensor/rewriting: +5 −0 lines.

```diff
@@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     """
     x, y = node.inputs
 
+    if x.type.ndim < 3:
+        # This doesn't actually have a batch dimension
+        return None
+
     # BatchedDot does not allow implicit broadcasting of the batch dimensions
     # We do not want to explicitly broadcast as it may result in huge arrays
     if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
```
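The added guard at the top of the rewrite bails out early when the inputs are plain matrices: with `ndim < 3` there is no batch dimension to specialize, so the rewrite returns `None` and leaves the graph unchanged. A minimal sketch of the distinction the guard draws, assuming only a standard PyTensor install (the variable names here are illustrative, not taken from the PR):

```python
import pytensor.tensor as pt

x2d = pt.matrix("x")    # ndim == 2: a plain matmul, no batch dimension
xb = pt.tensor3("xb")   # ndim == 3: one leading batch dimension

# The same condition the new guard checks on the node's first input
print(x2d.type.ndim < 3)  # True  -> rewrite returns None (not applicable)
print(xb.type.ndim < 3)   # False -> rewrite may specialize to BatchedDot
```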
```diff
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     if len(x_shape) > 3:
         # If we have more than one batch dim, ravel it
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
+    if len(y_shape) > 3:
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))
 
     new_out = _batched_dot(x, y)
```
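The second hunk gives `y` its own guard: before this change the `y` reshape sat inside the `if len(x_shape) > 3:` branch, so `y`'s extra batch dimensions were only raveled when `x` happened to have more than one batch dimension. With the added line, each operand is collapsed to a single batch dimension based on its own rank. A rough NumPy illustration of why raveling the leading batch dimensions is safe (illustrative only, not PyTensor code):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 3, 4))  # two batch dims: (2, 5)
y = rng.normal(size=(2, 5, 4, 6))

direct = np.matmul(x, y)  # shape (2, 5, 3, 6)

# Collapse the batch dims to one, do a single batched matmul, reshape back
x_flat = x.reshape(-1, 3, 4)  # (10, 3, 4)
y_flat = y.reshape(-1, 4, 6)  # (10, 4, 6)
raveled = np.matmul(x_flat, y_flat).reshape(2, 5, 3, 6)

print(np.allclose(direct, raveled))  # True
```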