@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
535
535
@register_stabilize
536
536
@register_specialize
537
537
@register_canonicalize
538
- @node_rewriter ([sub ])
538
+ @node_rewriter ([add , sub ])
539
539
def local_expm1 (fgraph , node ):
540
- """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
541
- in1 , in2 = node .inputs
542
- out = node .outputs [0 ]
540
+ """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
541
+ if len (node .inputs ) != 2 :
542
+ # TODO: handle more than two inputs in add
543
+ return None
543
544
544
- if (
545
- in1 .owner
546
- and isinstance (in1 .owner .op , Elemwise )
547
- and isinstance (in1 .owner .op .scalar_op , ps .Exp )
548
- and get_underlying_scalar_constant_value (in2 , raise_not_constant = False ) == 1
549
- ):
550
- in11 = in1 .owner .inputs [0 ]
551
- new_out = expm1 (in11 )
545
+ if isinstance (node .op .scalar_op , ps .Sub ):
546
+ exp_x , other_inp = node .inputs
547
+ if not (
548
+ exp_x .owner
549
+ and isinstance (exp_x .owner .op , Elemwise )
550
+ and isinstance (exp_x .owner .op .scalar_op , ps .Exp )
551
+ and get_underlying_scalar_constant_value (
552
+ other_inp , raise_not_constant = False
553
+ )
554
+ == 1
555
+ ):
556
+ return None
557
+ else :
558
+ # Try both orders
559
+ other_inp , exp_x = node .inputs
560
+ for i in range (2 ):
561
+ if i == 1 :
562
+ other_inp , exp_x = exp_x , other_inp
563
+ if (
564
+ exp_x .owner
565
+ and isinstance (exp_x .owner .op , Elemwise )
566
+ and isinstance (exp_x .owner .op .scalar_op , ps .Exp )
567
+ and get_underlying_scalar_constant_value (
568
+ other_inp , raise_not_constant = False
569
+ )
570
+ == - 1
571
+ ):
572
+ break
573
+ else : # no break
574
+ return None
552
575
553
- if new_out .type .broadcastable != out .type .broadcastable :
554
- new_out = broadcast_arrays (in11 , in2 )[0 ]
576
+ [old_out ] = node .outputs
555
577
556
- if new_out .dtype != out .dtype :
557
- new_out = cast (new_out , dtype = out .dtype )
578
+ [x ] = exp_x .owner .inputs
579
+ if x .type .broadcastable != old_out .type .broadcastable :
580
+ x = broadcast_arrays (x , other_inp )[0 ]
558
581
559
- if not out .type .is_super (new_out .type ):
560
- return
561
- return [new_out ]
582
+ new_out = expm1 (x )
583
+
584
+ if new_out .dtype != old_out .dtype :
585
+ new_out = cast (new_out , dtype = old_out .dtype )
586
+
587
+ if not old_out .type .is_super (new_out .type ):
588
+ return None
589
+
590
+ return [new_out ]
562
591
563
592
564
593
@register_specialize
@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
1824
1853
new_out = sub (first , pre_neg )
1825
1854
return [new_out ]
1826
1855
1827
- # Check if it is a negative constant
1828
- if (
1829
- isinstance (second , TensorConstant )
1830
- and second .unique_value is not None
1831
- and second .unique_value < 0
1832
- ):
1833
- new_out = sub (first , np .abs (second .data ))
1834
- return [new_out ]
1835
-
1836
1856
1837
1857
@register_canonicalize
1838
1858
@node_rewriter ([mul ])
@@ -2606,9 +2626,9 @@ def local_greedy_distributor(fgraph, node):
2606
2626
register_stabilize (local_one_minus_erfc )
2607
2627
register_specialize (local_one_minus_erfc )
2608
2628
2609
- # erfc(-x)-1 =>erf(x)
2629
+ # -1 + erfc(-x)=>erf(x)
2610
2630
local_erf_neg_minus_one = PatternNodeRewriter (
2611
- (sub , (erfc , (neg , "x" )), 1 ),
2631
+ (add , - 1 , (erfc , (neg , "x" ))),
2612
2632
(erf , "x" ),
2613
2633
allow_multiple_clients = True ,
2614
2634
name = "local_erf_neg_minus_one" ,
0 commit comments