File tree Expand file tree Collapse file tree 1 file changed +10
-11
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +10
-11
lines changed Original file line number Diff line number Diff line change @@ -557,19 +557,18 @@ def local_sqrt_sqr(fgraph, node):
557
557
def local_log_sqrt (fgraph , node ):
558
558
x = node .inputs [0 ]
559
559
560
- if not (x .owner and isinstance (x .owner .op , Elemwise )):
560
+ if not (x .owner and isinstance (x .owner .op , Elemwise )) or not isinstance (
561
+ x .owner .op .scalar_op , ps .Sqrt
562
+ ):
561
563
return
562
564
563
- prev_op = x .owner .op .scalar_op
564
-
565
- if isinstance (prev_op , ps .Sqrt ):
566
- # Case for log(sqrt(x)) -> 0.5 * log(x)
567
- x = x .owner .inputs [0 ]
568
- old_out = node .outputs [0 ]
569
- new_out = mul (0.5 , log (x ))
570
- if new_out .dtype != old_out .dtype :
571
- new_out = cast (new_out , old_out .dtype )
572
- return [new_out ]
565
+ # Case for log(sqrt(x)) -> 0.5 * log(x)
566
+ x = x .owner .inputs [0 ]
567
+ old_out = node .outputs [0 ]
568
+ new_out = mul (0.5 , log (x ))
569
+ if new_out .dtype != old_out .dtype :
570
+ new_out = cast (new_out , old_out .dtype )
571
+ return [new_out ]
573
572
574
573
575
574
@register_specialize
You can’t perform that action at this time.
0 commit comments