@@ -53,22 +53,31 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}
5353 return ntuple (_ -> nothing , Val (length (args) + 4 ))
5454end
5555
56- function EnzymeRules. forward (func:: Const{typeof(fastpow)} ,
57- RT:: Type {<: Union {Const,DuplicatedNoNeed,Duplicated,
58- BatchDuplicated,BatchDuplicatedNoNeed}},
59- _x:: Annotation , _y:: Annotation )
56+ function Enzyme. EnzymeRules. forward (func:: Const{typeof(DiffEqBase.fastpow)} ,
57+ RT:: Type{<:Union{Duplicated, DuplicatedNoNeed}} ,
58+ _x:: Union{Const, Duplicated} , _y:: Union{Const, Duplicated} )
6059 x = _x. val
6160 y = _y. val
62- ret = func. val (x. val, y. val)
63- dxval = x. dval * y * (fastpow (x,y - 1 ))
64- dyval = x isa Real && x<= 0 ? Base. oftype (float (x), NaN ) : y. dval* (fastpow (x,y))* log (x)
65- return Duplicated (ret, dxval + dyval)
61+ ret = func. val (x, y)
62+ if ! (_x isa Const)
63+ dxval = _x. dval * y * (fastpow (x,y - 1 ))
64+ else
65+ dxval = make_zero (_x. val)
66+ end
67+ if ! (_y isa Const)
68+ dyval = x isa Real && x<= 0 ? Base. oftype (float (x), NaN ) : _y. dval* (fastpow (x,y))* log (x)
69+ else
70+ dyval = make_zero (_y. val)
71+ end
72+ if RT <: DuplicatedNoNeed
73+ return Float32 (dxval + dyval)
74+ else
75+ return Duplicated (ret, Float32 (dxval + dyval))
76+ end
6677end
6778
68- function EnzymeRules. augmented_primal (config:: ConfigWidth{1} ,
69- func:: Const{typeof(fastpow)} ,
70- :: Type{<:Active} ,
71- x:: Active , x:: Active )
79+ function EnzymeRules. augmented_primal (config:: Enzyme.EnzymeRules.ConfigWidth{1} ,
80+ func:: Const{typeof(fastpow)} , :: Type{<:Active} , x:: Active , y:: Active )
7281 if EnzymeRules. needs_primal (config)
7382 primal = func. val (x. val, y. val)
7483 else
@@ -77,14 +86,13 @@ function EnzymeRules.augmented_primal(config::ConfigWidth{1},
7786 return EnzymeRules. AugmentedReturn (primal, nothing , nothing )
7887end
7988
80- function EnzymeRules. reverse (config:: EnzymeRules.ConfigWidth{1} ,
81- func:: Const{DiffEqBase.fastpow} , dret, tape:: Nothing ,
82- _x, _y)
89+ function EnzymeRules. reverse (config:: Enzyme.EnzymeRules.ConfigWidth{1} ,
90+ func:: Const{typeof(DiffEqBase.fastpow)} , dret:: Active , tape, _x:: Active , _y:: Active )
8391 x = _x. val
8492 y = _y. val
85- dxval = x . dval * y * (fastpow (x,y - 1 ))
86- dyval = x isa Real && x<= 0 ? Base. oftype (float (x), NaN ) : y . dval * (fastpow (x,y))* log (x)
93+ dxval = y * (fastpow (x,y - 1 ))
94+ dyval = x isa Real && x<= 0 ? Base. oftype (float (x), NaN ) : (fastpow (x,y))* log (x)
8795 return (dxval, dyval)
8896end
8997
90- end
98+ end
0 commit comments