Skip to content

Commit d873bff

Browse files
Fold new rewrite into local_pow_canonicalize and simplify docstring
1 parent 97bd1c4 commit d873bff

File tree

1 file changed

+17
-30
lines changed

1 file changed

+17
-30
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
import pytensor.scalar.basic as ps
1111
import pytensor.scalar.math as ps_math
12-
from pytensor.graph import FunctionGraph
13-
from pytensor.graph.basic import Apply, Constant, Variable
12+
from pytensor.graph.basic import Constant, Variable
1413
from pytensor.graph.rewriting.basic import (
1514
NodeRewriter,
1615
PatternNodeRewriter,
@@ -1906,39 +1905,27 @@ def local_reciprocal_canon(fgraph, node):
19061905
@register_canonicalize
19071906
@node_rewriter([pt_pow])
19081907
def local_pow_canonicalize(fgraph, node):
1909-
cst = get_underlying_scalar_constant_value(
1910-
node.inputs[1], only_process_constants=True, raise_not_constant=False
1911-
)
1912-
if cst == 0:
1913-
return [alloc_like(1, node.outputs[0], fgraph)]
1914-
if cst == 1:
1915-
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
1916-
1917-
1918-
@register_canonicalize
1919-
@node_rewriter([pt_pow])
1920-
def local_pow_canonicalize_base_1(
1921-
fgraph: FunctionGraph, node: Apply
1922-
) -> list[TensorVariable] | None:
19231908
"""
1924-
Replace `1 ** x` with 1, broadcast to the shape of the output.
1909+
Rewrites for exponential functions with straight-forward simplifications:
1910+
1. x ** 0 -> 1
1911+
2. x ** 1 -> x
1912+
3. 1 ** x -> 1
19251913
1926-
Parameters
1927-
----------
1928-
fgraph: FunctionGraph
1929-
Full function graph being rewritten
1930-
node: Apply
1931-
Specific node being rewritten
1932-
1933-
Returns
1934-
-------
1935-
rewritten_output: list[TensorVariable] | None
1936-
Rewritten output of node, or None if no rewrite is possible
1914+
In all cases, the shape of the output is the result of broadcasting the shapes of the inputs.
19371915
"""
1938-
cst = get_underlying_scalar_constant_value(
1916+
1917+
cst_base = get_underlying_scalar_constant_value(
19391918
node.inputs[0], only_process_constants=True, raise_not_constant=False
19401919
)
1941-
if cst == 1:
1920+
if cst_base == 1:
1921+
return [alloc_like(1, node.outputs[0], fgraph)]
1922+
1923+
cst_exponent = get_underlying_scalar_constant_value(
1924+
node.inputs[1], only_process_constants=True, raise_not_constant=False
1925+
)
1926+
if cst_exponent == 0:
1927+
return [alloc_like(1, node.outputs[0], fgraph)]
1928+
if cst_exponent == 1:
19421929
return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
19431930

19441931

0 commit comments

Comments
 (0)