Skip to content

Commit 67390dc

Browse files
committed
combine if block
1 parent 49cab31 commit 67390dc

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -557,19 +557,18 @@ def local_sqrt_sqr(fgraph, node):
557557
def local_log_sqrt(fgraph, node):
558558
x = node.inputs[0]
559559

560-
if not (x.owner and isinstance(x.owner.op, Elemwise)):
560+
if not (x.owner and isinstance(x.owner.op, Elemwise)) or not isinstance(
561+
x.owner.op.scalar_op, ps.Sqrt
562+
):
561563
return
562564

563-
prev_op = x.owner.op.scalar_op
564-
565-
if isinstance(prev_op, ps.Sqrt):
566-
# Case for log(sqrt(x)) -> 0.5 * log(x)
567-
x = x.owner.inputs[0]
568-
old_out = node.outputs[0]
569-
new_out = mul(0.5, log(x))
570-
if new_out.dtype != old_out.dtype:
571-
new_out = cast(new_out, old_out.dtype)
572-
return [new_out]
565+
# Case for log(sqrt(x)) -> 0.5 * log(x)
566+
x = x.owner.inputs[0]
567+
old_out = node.outputs[0]
568+
new_out = mul(0.5, log(x))
569+
if new_out.dtype != old_out.dtype:
570+
new_out = cast(new_out, old_out.dtype)
571+
return [new_out]
573572

574573

575574
@register_specialize

0 commit comments

Comments
 (0)