diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index acc8becbfb..327e14951e 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -552,6 +552,29 @@ def local_sqrt_sqr(fgraph, node): return [new_out] +@register_specialize +@node_rewriter([log]) +def local_log_sqrt(fgraph, node): + x = node.inputs[0] + + if ( + not x.owner + or not isinstance(x.owner.op, Elemwise) + or not isinstance(x.owner.op.scalar_op, ps.Sqrt) + ): + return + + # Case for log(sqrt(x)) -> 0.5 * log(x) + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = mul(as_tensor_variable(0.5, dtype=x.dtype), log(x)) + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, old_out.dtype) + + copy_stack_trace(node.out, new_out) + return [new_out] + + @register_specialize @node_rewriter([exp, expm1]) def local_exp_log_nan_switch(fgraph, node): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a6e734ae82..43f65c2282 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1989,6 +1989,18 @@ def test_exp_log_nested(self, nested_expression, expected_switches): assert len(ops_graph) == expected_switches +def test_log_sqrt() -> None: + x = pt.tensor("x", shape=(None, None)) + out = log(sqrt(x)) + + out = rewrite_graph(out, include=["specialize"]) + + assert utt.assert_equal_computations( + [out], + [mul(pt.as_tensor_variable([[0.5]], dtype=x.dtype), log(x))], + ) + + class TestSqrSqrt: def setup_method(self): mode = get_default_mode()