Commit 97bd1c4

Add rewrite for 1 ** x = 1
1 parent 2f1d25a commit 97bd1c4

2 files changed: 48 additions & 1 deletion

pytensor/tensor/rewriting/math.py

Lines changed: 29 additions & 1 deletion
@@ -9,7 +9,8 @@
 
 import pytensor.scalar.basic as ps
 import pytensor.scalar.math as ps_math
-from pytensor.graph.basic import Constant, Variable
+from pytensor.graph import FunctionGraph
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.rewriting.basic import (
     NodeRewriter,
     PatternNodeRewriter,
@@ -1914,6 +1915,33 @@ def local_pow_canonicalize(fgraph, node):
         return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
 
 
+@register_canonicalize
+@node_rewriter([pt_pow])
+def local_pow_canonicalize_base_1(
+    fgraph: FunctionGraph, node: Apply
+) -> list[TensorVariable] | None:
+    """
+    Replace `1 ** x` with 1, broadcast to the shape of the output.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Full function graph being rewritten
+    node: Apply
+        Specific node being rewritten
+
+    Returns
+    -------
+    rewritten_output: list[TensorVariable] | None
+        Rewritten output of node, or None if no rewrite is possible
+    """
+    cst = get_underlying_scalar_constant_value(
+        node.inputs[0], only_process_constants=True, raise_not_constant=False
+    )
+    if cst == 1:
+        return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
+
+
 @register_specialize
 @node_rewriter([mul])
 def local_mul_to_sqr(fgraph, node):
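For context, here is a minimal sketch (not part of the commit) of what this canonicalization does once it is in place: compiling `1 ** x` with the default rewrite set should leave no `Pow` node in the graph, only ones broadcast to the shape of `x`. It uses only public PyTensor API that also appears in the test below (`pytensor.function`, `debugprint`, `config.floatX`) and assumes a build that includes this commit.

# Sketch only: demonstrate the effect of the 1 ** x canonicalization.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.printing import debugprint

x = pt.vector("x")
z = 1 ** x  # builds a Pow node with a constant base of 1

# The default mode runs the "canonicalize" rewrites, so the Pow node
# should be replaced by an allocation of ones shaped like the output.
f = pytensor.function([x], z)
debugprint(f)  # expect no Pow in the printed graph

x_val = np.array([2.0, 3.0, 4.0], dtype=pytensor.config.floatX)
print(f(x_val))  # -> [1. 1. 1.]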

tests/tensor/rewriting/test_math.py

Lines changed: 19 additions & 0 deletions
@@ -4571,3 +4571,22 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
+def test_pow_1_rewrite(shape):
+    x = pt.tensor("x", shape=shape)
+    z = 1**x
+
+    f1 = pytensor.function([x], z, mode=get_default_mode().excluding("canonicalize"))
+    assert debugprint(f1, file="str").count("Pow") == 1
+
+    x_val = np.random.random(shape).astype(config.floatX)
+    z_val_1 = f1(x_val)
+
+    f2 = pytensor.function([x], z)
+    assert debugprint(f2, file="str").count("Pow") == 0
+
+    z_val_2 = f2(x_val)
+
+    np.testing.assert_allclose(z_val_1, z_val_2)
