@@ -344,57 +344,26 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
344
344
345
345
@register_canonicalize
346
346
@register_specialize
347
- @node_rewriter ([_matmul ])
348
- def local_blockwise_dot_to_mul (fgraph , node ):
349
- """Rewrite blockwise dots that correspond to multiplication without summation.
350
-
351
- We don't touch the regular dot, to not interfere with the BLAS optimizations.
352
- """
347
+ @node_rewriter ([_matmul , _dot ])
348
+ def local_dot_to_mul (fgraph , node ):
349
+ """Rewrite dots (core or blockwise) that correspond to multiplication without summation."""
353
350
a , b = node .inputs
354
351
a_static_shape = a .type .shape
355
352
b_static_shape = b .type .shape
356
- core_a_ndim = len (node .op .inputs_sig [0 ])
357
- core_b_ndim = len (node .op .inputs_sig [1 ])
358
353
359
- if core_a_ndim > 2 or core_b_ndim > 2 :
360
- # Shouldn't happen, but here just in case
354
+ if isinstance (node .op , Dot ) and (
355
+ len (a_static_shape ) != 2 or len (b_static_shape ) != 2
356
+ ):
357
+ # For now, we only support matrix-matrix multiplication
358
+ # We should eventually canonicalize all dots to this form
361
359
return None
362
360
363
- if core_b_ndim == 1 :
364
- if a_static_shape [- 1 ] == 1 or b_static_shape [- 1 ] == 1 :
365
- if core_a_ndim == 1 :
366
- # inner product: (..., 1) * (..., 1) -> (...)
367
- # just squeeze the last dimensions of a and b
368
- new_a = a .squeeze (- 1 )
369
- new_b = b .squeeze (- 1 )
370
- else :
371
- # matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
372
- # the last dimension of b is already aligned for the elemwise multiplication
373
- # after we squeeze the last dimension of a
374
- new_a = a .squeeze (- 1 )
375
- new_b = b
376
- else :
377
- return None
378
-
379
- else :
380
- if a_static_shape [- 1 ] == 1 or b_static_shape [- 2 ] == 1 :
381
- if core_a_ndim == 1 :
382
- # vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
383
- # the last dimension of a is already aligned for the elemwise multiplication
384
- # after we squeeze the one to last dimension of b
385
- new_a = a
386
- new_b = b .squeeze (- 2 )
387
- else :
388
- # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
389
- # the dimensions of a and b are already aligned for the elemwise multiplication
390
- new_a = a
391
- new_b = b
392
- else :
393
- return None
361
+ # Only proceed for a matrix-matrix product whose contracted dimension is statically 1: (..., m, 1) * (..., 1, n) -> (..., m, n)
362
+ if not (a_static_shape [- 1 ] == 1 or b_static_shape [- 2 ] == 1 ):
363
+ return None
394
364
395
- new_a = copy_stack_trace (a , new_a )
396
- new_b = copy_stack_trace (b , new_b )
397
- new_out = copy_stack_trace (node .out , mul (new_a , new_b ))
365
+ new_out = mul (a , b )
366
+ copy_stack_trace (node .out , new_out )
398
367
return [new_out ]
399
368
400
369
0 commit comments