Skip to content

Commit f86a41f

Browse files
committed
Rewrite specifically for Sum and Prod to remove Join
1 parent 981688c commit f86a41f

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
from pytensor.tensor.math import max as pt_max
8787
from pytensor.tensor.math import pow as pt_pow
8888
from pytensor.tensor.math import sum as pt_sum
89+
from pytensor.tensor.math import prod as pt_prod
8990
from pytensor.tensor.rewriting.basic import (
9091
alloc_like,
9192
broadcasted_by,
@@ -1754,6 +1755,30 @@ def local_reduce_broadcastable(fgraph, node):
17541755
# -- in this case we can remove the reduction completely
17551756
return [new_reduced.astype(odtype)]
17561757

1758+
@register_canonicalize
1759+
@register_uncanonicalize
1760+
@register_specialize
1761+
@node_rewriter([Sum, Prod])
1762+
def local_useless_join_(fgraph, node):
1763+
"""
1764+
sum(join(tensor1, tensor2...)) => sum(sum(tensor) for tensor in tensors)
1765+
or
1766+
prod(join(tensor1, tensor2...)) => prod(prod(tensor) for tensor in tensors)
1767+
1768+
"""
1769+
(node_inps,) = node.inputs
1770+
if node_inps.owner and isinstance(node_inps.owner.op, Join):
1771+
inpts = node_inps.owner.inputs[1:]
1772+
# This specific implementation would introduce a
1773+
# `MakeVector` into the graph, which would then
1774+
# be rewritten again with
1775+
# pytensor/tensor/rewriting/basic.py:local_sum_make_vector
1776+
# A similar rewrite must be created for `prod`
1777+
if isinstance(node.op, Sum):
1778+
return [pt_sum([pt_sum(inp) for inp in inpts])]
1779+
elif isinstance(node.op, Prod):
1780+
return [pt_prod([pt_prod(inp) for inp in inpts])]
1781+
17571782

17581783
@register_specialize
17591784
@node_rewriter([Sum, Prod])

0 commit comments

Comments
 (0)