Skip to content

Commit b55c473

Browse files
authored
Add rewrite for log(sqrt(x)) (#1555)
1 parent 892a8f0 commit b55c473

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,29 @@ def local_sqrt_sqr(fgraph, node):
552552
return [new_out]
553553

554554

555+
@register_specialize
556+
@node_rewriter([log])
557+
def local_log_sqrt(fgraph, node):
558+
x = node.inputs[0]
559+
560+
if (
561+
not x.owner
562+
or not isinstance(x.owner.op, Elemwise)
563+
or not isinstance(x.owner.op.scalar_op, ps.Sqrt)
564+
):
565+
return
566+
567+
# Case for log(sqrt(x)) -> 0.5 * log(x)
568+
x = x.owner.inputs[0]
569+
old_out = node.outputs[0]
570+
new_out = mul(as_tensor_variable(0.5, dtype=x.dtype), log(x))
571+
if new_out.dtype != old_out.dtype:
572+
new_out = cast(new_out, old_out.dtype)
573+
574+
copy_stack_trace(node.out, new_out)
575+
return [new_out]
576+
577+
555578
@register_specialize
556579
@node_rewriter([exp, expm1])
557580
def local_exp_log_nan_switch(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,18 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
19891989
assert len(ops_graph) == expected_switches
19901990

19911991

1992+
def test_log_sqrt() -> None:
1993+
x = pt.tensor("x", shape=(None, None))
1994+
out = log(sqrt(x))
1995+
1996+
out = rewrite_graph(out, include=["specialize"])
1997+
1998+
assert utt.assert_equal_computations(
1999+
[out],
2000+
[mul(pt.as_tensor_variable([[0.5]], dtype=x.dtype), log(x))],
2001+
)
2002+
2003+
19922004
class TestSqrSqrt:
19932005
def setup_method(self):
19942006
mode = get_default_mode()

0 commit comments

Comments
 (0)