|
86 | 86 | from pytensor.tensor.math import max as pt_max
|
87 | 87 | from pytensor.tensor.math import pow as pt_pow
|
88 | 88 | from pytensor.tensor.math import sum as pt_sum
|
| 89 | +from pytensor.tensor.math import prod as pt_prod |
89 | 90 | from pytensor.tensor.rewriting.basic import (
|
90 | 91 | alloc_like,
|
91 | 92 | broadcasted_by,
|
@@ -1754,6 +1755,30 @@ def local_reduce_broadcastable(fgraph, node):
|
1754 | 1755 | # -- in this case we can remove the reduction completely
|
1755 | 1756 | return [new_reduced.astype(odtype)]
|
1756 | 1757 |
|
| 1758 | +@register_canonicalize |
| 1759 | +@register_uncanonicalize |
| 1760 | +@register_specialize |
| 1761 | +@node_rewriter([Sum, Prod]) |
| 1762 | +def local_useless_join_(fgraph, node): |
| 1763 | + """ |
| 1764 | + sum(join(tensor1, tensor2...)) => sum(sum(tensor) for tensor in tensors) |
| 1765 | + or |
| 1766 | + prod(join(tensor1, tensor2...)) => prod(prod(tensor) for tensor in tensors) |
| 1767 | +
|
| 1768 | + """ |
| 1769 | + (node_inps,) = node.inputs |
| 1770 | + if node_inps.owner and isinstance(node_inps.owner.op, Join): |
| 1771 | + inpts = node_inps.owner.inputs[1:] |
| 1772 | + # This specific implementation would introduce a |
| 1773 | + # `MakeVector` into the graph, which would then |
| 1774 | + # be rewritten again with |
| 1775 | + # pytensor/tensor/rewriting/basic.py:local_sum_make_vector |
| 1776 | + # A similar rewrite must be created for `prod` |
| 1777 | + if isinstance(node.op, Sum): |
| 1778 | + return [pt_sum([pt_sum(inp) for inp in inpts])] |
| 1779 | + elif isinstance(node.op, Prod): |
| 1780 | + return [pt_prod([pt_prod(inp) for inp in inpts])] |
| 1781 | + |
1757 | 1782 |
|
1758 | 1783 | @register_specialize
|
1759 | 1784 | @node_rewriter([Sum, Prod])
|
|
0 commit comments