@@ -535,30 +535,59 @@ def local_mul_pow_to_pow_add(fgraph, node):
535535@register_stabilize
536536@register_specialize
537537@register_canonicalize
538- @node_rewriter ([sub ])
538+ @node_rewriter ([add , sub ])
539539def local_expm1 (fgraph , node ):
540- """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``."""
541- in1 , in2 = node .inputs
542- out = node .outputs [0 ]
540+ """Detect ``exp(a) - 1`` or ``-1 + exp(a)`` and convert them to ``expm1(a)``."""
541+ if len (node .inputs ) != 2 :
542+ # TODO: handle more than two inputs in add
543+ return None
543544
544- if (
545- in1 .owner
546- and isinstance (in1 .owner .op , Elemwise )
547- and isinstance (in1 .owner .op .scalar_op , ps .Exp )
548- and get_underlying_scalar_constant_value (in2 , raise_not_constant = False ) == 1
549- ):
550- in11 = in1 .owner .inputs [0 ]
551- new_out = expm1 (in11 )
545+ if isinstance (node .op .scalar_op , ps .Sub ):
546+ exp_x , other_inp = node .inputs
547+ if not (
548+ exp_x .owner
549+ and isinstance (exp_x .owner .op , Elemwise )
550+ and isinstance (exp_x .owner .op .scalar_op , ps .Exp )
551+ and get_underlying_scalar_constant_value (
552+ other_inp , raise_not_constant = False
553+ )
554+ == 1
555+ ):
556+ return None
557+ else :
558+ # Try both orders
559+ other_inp , exp_x = node .inputs
560+ for i in range (2 ):
561+ if i == 1 :
562+ other_inp , exp_x = exp_x , other_inp
563+ if (
564+ exp_x .owner
565+ and isinstance (exp_x .owner .op , Elemwise )
566+ and isinstance (exp_x .owner .op .scalar_op , ps .Exp )
567+ and get_underlying_scalar_constant_value (
568+ other_inp , raise_not_constant = False
569+ )
570+ == - 1
571+ ):
572+ break
573+ else : # no break
574+ return None
552575
553- if new_out .type .broadcastable != out .type .broadcastable :
554- new_out = broadcast_arrays (in11 , in2 )[0 ]
576+ [old_out ] = node .outputs
555577
556- if new_out .dtype != out .dtype :
557- new_out = cast (new_out , dtype = out .dtype )
578+ [x ] = exp_x .owner .inputs
579+ if x .type .broadcastable != old_out .type .broadcastable :
580+ x = broadcast_arrays (x , other_inp )[0 ]
558581
559- if not out .type .is_super (new_out .type ):
560- return
561- return [new_out ]
582+ new_out = expm1 (x )
583+
584+ if new_out .dtype != old_out .dtype :
585+ new_out = cast (new_out , dtype = old_out .dtype )
586+
587+ if not old_out .type .is_super (new_out .type ):
588+ return None
589+
590+ return [new_out ]
562591
563592
564593@register_specialize
@@ -1824,15 +1853,6 @@ def local_add_neg_to_sub(fgraph, node):
18241853 new_out = sub (first , pre_neg )
18251854 return [new_out ]
18261855
1827- # Check if it is a negative constant
1828- if (
1829- isinstance (second , TensorConstant )
1830- and second .unique_value is not None
1831- and second .unique_value < 0
1832- ):
1833- new_out = sub (first , np .abs (second .data ))
1834- return [new_out ]
1835-
18361856
18371857@register_canonicalize
18381858@node_rewriter ([mul ])
@@ -2606,9 +2626,9 @@ def local_greedy_distributor(fgraph, node):
26062626register_stabilize (local_one_minus_erfc )
26072627register_specialize (local_one_minus_erfc )
26082628
2609- # erfc(-x)-1 =>erf(x)
2629+ # -1 + erfc(-x)=>erf(x)
26102630local_erf_neg_minus_one = PatternNodeRewriter (
2611- (sub , (erfc , (neg , "x" )), 1 ),
2631+ (add , - 1 , (erfc , (neg , "x" ))),
26122632 (erf , "x" ),
26132633 allow_multiple_clients = True ,
26142634 name = "local_erf_neg_minus_one" ,
0 commit comments