Skip to content

Commit ac6f89e

Browse files
committed
add rewrite for log(sqrt(x))
1 parent 892a8f0 commit ac6f89e

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,27 @@ def local_sqrt_sqr(fgraph, node):
552552
return [new_out]
553553

554554

555+
@register_canonicalize
556+
@node_rewriter([log])
557+
def local_log_sqrt(fgraph, node):
558+
x = node.inputs[0]
559+
560+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
561+
return
562+
563+
prev_op = x.owner.op.scalar_op
564+
node_op = node.op.scalar_op
565+
566+
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Log):
567+
# Case for log(sqrt(x)) -> 0.5 * log(x)
568+
x = x.owner.inputs[0]
569+
old_out = node.outputs[0]
570+
new_out = mul(0.5, log(x))
571+
if new_out.dtype != old_out.dtype:
572+
new_out = cast(new_out, old_out.dtype)
573+
return [new_out]
574+
575+
555576
@register_specialize
556577
@node_rewriter([exp, expm1])
557578
def local_exp_log_nan_switch(fgraph, node):

tests/tensor/rewriting/test_math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,6 +1989,18 @@ def test_exp_log_nested(self, nested_expression, expected_switches):
19891989
assert len(ops_graph) == expected_switches
19901990

19911991

1992+
def test_log_sqrt() -> None:
1993+
x = pt.tensor("x", shape=(None, None))
1994+
out = log(sqrt(x))
1995+
1996+
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
1997+
1998+
assert equal_computations(
1999+
[out],
2000+
[mul(np.array([[0.5]]), log(x))],
2001+
)
2002+
2003+
19922004
class TestSqrSqrt:
19932005
def setup_method(self):
19942006
mode = get_default_mode()

0 commit comments

Comments
 (0)