Commit 085e653

Simplify local_dot_to_mul and extend it to core dot
1 parent e83fe3a commit 085e653

File tree

2 files changed: 16 additions (+), 46 deletions (-)

pytensor/tensor/rewriting/math.py

Lines changed: 13 additions & 44 deletions
@@ -344,57 +344,26 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):

 @register_canonicalize
 @register_specialize
-@node_rewriter([_matmul])
-def local_blockwise_dot_to_mul(fgraph, node):
-    """Rewrite blockwise dots that correspond to multiplication without summation.
-
-    We don't touch the regular dot, to not interfere with the BLAS optimizations.
-    """
+@node_rewriter([_matmul, _dot])
+def local_dot_to_mul(fgraph, node):
+    """Rewrite blockwise dots that correspond to multiplication without summation."""
     a, b = node.inputs
     a_static_shape = a.type.shape
     b_static_shape = b.type.shape
-    core_a_ndim = len(node.op.inputs_sig[0])
-    core_b_ndim = len(node.op.inputs_sig[1])

-    if core_a_ndim > 2 or core_b_ndim > 2:
-        # Shouldn't happen, but here just in case
+    if isinstance(node.op, Dot) and (
+        len(a_static_shape) != 2 or len(b_static_shape) != 2
+    ):
+        # For now, we only support matrix-matrix multiplication
+        # We should eventually canonicalize all dots to this form
         return None

-    if core_b_ndim == 1:
-        if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
-            if core_a_ndim == 1:
-                # inner product: (..., 1) * (..., 1) -> (...)
-                # just squeeze the last dimensions of a and b
-                new_a = a.squeeze(-1)
-                new_b = b.squeeze(-1)
-            else:
-                # matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
-                # the last dimension of b is already aligned for the elemwise multiplication
-                # after we squeeze the last dimension of a
-                new_a = a.squeeze(-1)
-                new_b = b
-        else:
-            return None
-
-    else:
-        if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
-            if core_a_ndim == 1:
-                # vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
-                # the last dimension of a is already aligned for the elemwise multiplication
-                # after we squeeze the one to last dimension of b
-                new_a = a
-                new_b = b.squeeze(-2)
-            else:
-                # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
-                # the dimensions of a and b are already aligned for the elemwise multiplication
-                new_a = a
-                new_b = b
-        else:
-            return None
+    # Check if we have matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+    if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1):
+        return None

-    new_a = copy_stack_trace(a, new_a)
-    new_b = copy_stack_trace(b, new_b)
-    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    new_out = mul(a, b)
+    copy_stack_trace(node.out, new_out)
     return [new_out]


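Why the simplified rewrite is valid: when the contracted dimension has static length 1, the sum in the dot product collapses to a single term, so the matrix product equals a broadcasted elementwise product. A minimal NumPy sketch of the matrix-matrix case the rewrite now targets (shapes chosen only for illustration):

import numpy as np

m, n = 4, 5
a = np.random.normal(size=(m, 1))  # contracted (inner) dimension has length 1
b = np.random.normal(size=(1, n))

# (m, 1) @ (1, n) sums a single term per output element, so it equals
# the broadcasted elementwise product (m, 1) * (1, n) -> (m, n).
np.testing.assert_allclose(a @ b, a * b)
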
tests/tensor/rewriting/test_math.py

Lines changed: 3 additions & 2 deletions
@@ -4714,14 +4714,15 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         == 1
     )

-    # For now rewrite only applies to Batched Dots
     rewritten_out = rewrite_graph(out)
     assert rewritten_out.type.shape == out.type.shape
+    # For now the rewrite doesn't apply to non matrix-matrix dots
+    applies = batched or (len(a_shape) == 2 and len(b_shape) == 2)
     assert sum(
         isinstance(var.owner.op, (Blockwise | Dot))
         for var in ancestors([rewritten_out])
         if var.owner
-    ) == (0 if batched else 1)
+    ) == (0 if applies else 1)

     a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
     b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
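To watch the extended rewrite fire on a core (non-batched) Dot outside the test suite, a sketch along the following lines should work. It is not part of the commit; it assumes the public pt.tensor/pt.dot entry points and that rewrite_graph is importable from pytensor.graph.rewriting.utils, mirroring its use in the test above:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

# Static shapes expose the size-1 inner dimension to the rewriter.
a = pt.tensor("a", shape=(4, 1))
b = pt.tensor("b", shape=(1, 5))
out = pt.dot(a, b)  # core Dot, now handled alongside _matmul

rewritten = rewrite_graph(out)
pytensor.dprint(rewritten)  # expect an elemwise Mul graph, no Dot/Blockwise
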
