@@ -194,16 +194,15 @@ def local_lift_transpose_through_dot(fgraph, node):
194194 return ret
195195
196196
197- @register_canonicalize
198- @register_specialize
199- @node_rewriter (tracks = [_matmul ])
200- def local_batched_matmul_to_core_matmul (fgraph , node ):
197+ def _batched_matmul_to_core_matmul (fgraph , node , allow_reshape : bool ):
201198 """Move batch dimensions of matmul operands to core matmul
202199
203200 Example, if x has batch dimensions that don't overlap with batch dimensions of y
204201 x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
205202
206203 It also works for batch dimensions of y that don't overlap with batch dimensions of x
204+
205+ The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False`
207206 """
208207
209208 x , y = node .inputs
@@ -247,6 +246,9 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
247246 # x was a row matrix, squeeze it to clean up the graph
248247 x_stacked = x_stacked .squeeze (- 2 )
249248 if n_x_axis_to_merge > 1 or not x_is_row :
249+ if not allow_reshape :
250+ return None
251+
250252 # Ravel moved batch dims together with (m) if needed
251253 x_stacked_shape = tuple (x_stacked .shape )
252254 x_stacked = x_stacked .reshape (
@@ -262,6 +264,8 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
262264 # y was a column matrix, squeeze it to clean up the graph
263265 y_stacked = y_stacked .squeeze (- 1 )
264266 if n_y_axis_to_merge > 1 or not y_is_col :
267+ if not allow_reshape :
268+ return False
265269 # Ravel moved batch dims together with (n) if needed
266270 y_stacked_shape = tuple (y_stacked .shape )
267271 y_stacked = y_stacked .reshape (
@@ -319,6 +323,21 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
319323 return [out ]
320324
321325
@register_canonicalize
@node_rewriter(tracks=[_matmul])
def local_batched_matmul_to_core_matmul(fgraph, node):
    """Pass batch dimensions of a matmul to the core matmul, without reshapes.

    Delegates to ``_batched_matmul_to_core_matmul`` with ``allow_reshape=False``,
    so only the cases involving core vector / column matrices are handled here.
    """
    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
331+
332+
@register_specialize
@node_rewriter(tracks=[_matmul])
def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
    """Move batch dimensions of matmul operands to the core matmul, using reshape.

    Delegates to ``_batched_matmul_to_core_matmul`` with ``allow_reshape=True``,
    allowing batch dimensions to be stacked with core dimensions via a reshape
    operation. Registered only in the specialize phase, because graphs with
    reshape are hard to work with in earlier rewrite phases.
    """
    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
339+
340+
322341@register_canonicalize
323342@register_specialize
324343@node_rewriter ([_matmul ])
0 commit comments