diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index aa2d279f43..03fa1ae094 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -1905,13 +1905,40 @@ def local_reciprocal_canon(fgraph, node):
 @register_canonicalize
 @node_rewriter([pt_pow])
 def local_pow_canonicalize(fgraph, node):
-    cst = get_underlying_scalar_constant_value(
+    """
+    Rewrites for power functions with straightforward simplifications:
+    1. x ** 0 -> 1
+    2. x ** 1 -> x
+    3. 1 ** x -> 1
+
+    In all cases, the shape of the output is the result of broadcasting the shapes of the inputs.
+    """
+    cst_base = get_underlying_scalar_constant_value(
+        node.inputs[0], only_process_constants=True, raise_not_constant=False
+    )
+    cst_exponent = get_underlying_scalar_constant_value(
         node.inputs[1], only_process_constants=True, raise_not_constant=False
     )
-    if cst == 0:
-        return [alloc_like(1, node.outputs[0], fgraph)]
-    if cst == 1:
-        return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
+
+    new_out = None
+
+    if cst_base == 1:
+        # 1 ** x = 1
+        new_out = broadcast_arrays(*node.inputs)[0]
+    elif cst_exponent == 0:
+        # x ** 0 = 1
+        new_out = broadcast_arrays(ones_like(node.inputs[0]), node.inputs[1])[0]
+    elif cst_exponent == 1:
+        # x ** 1 = x
+        new_out = broadcast_arrays(*node.inputs)[0]
+
+    if new_out is None:
+        return
+
+    if new_out.dtype != node.out.dtype:
+        new_out = cast(new_out, dtype=node.out.dtype)
+
+    return [new_out]
 
 
 @register_specialize
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index debcf44c64..d344d29dad 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4571,3 +4571,22 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
+def test_pow_1_rewrite(shape):
+    x = pt.tensor("x", shape=shape)
+    z = 1**x
+
+    assert isinstance(z.owner.op, Elemwise) and isinstance(
+        z.owner.op.scalar_op, ps.basic.Pow
+    )
+
+    f = pytensor.function([x], z)
+    assert not any(
+        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.basic.Pow)
+        for node in f.maker.fgraph.toposort()
+    )
+
+    x_val = np.random.random(shape).astype(config.floatX)
+    np.testing.assert_allclose(z.eval({x: x_val}), f(x_val))
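
For reference, here is a minimal sketch (not part of the diff) of what the new canonicalization does at the graph level, using the x ** 0 case; it assumes a PyTensor build where local_pow_canonicalize is registered in the default rewrite database, and mirrors the checks in the new test:

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
z = x**0  # the unrewritten graph still contains an Elemwise Pow node

f = pytensor.function([x], z)

# After canonicalization the compiled graph should contain no Pow node;
# the output is simply ones broadcast to the shape of x.
print(f.maker.fgraph.toposort())
print(f(np.array([2.0, 3.0, 4.0], dtype=pytensor.config.floatX)))  # -> [1. 1. 1.]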