| 
1 | 1 | module DiffEqBaseEnzymeExt  | 
2 | 2 | 
 
  | 
3 | 3 | using DiffEqBase  | 
4 |  | -import DiffEqBase: value  | 
 | 4 | +import DiffEqBase: value, fastpow  | 
5 | 5 | using Enzyme  | 
6 | 6 | import Enzyme: Const  | 
7 | 7 | using ChainRulesCore  | 
@@ -53,4 +53,38 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},  | 
53 | 53 |     return ntuple(_ -> nothing, Val(length(args) + 4))  | 
54 | 54 | end  | 
55 | 55 | 
 
  | 
 | 56 | +function EnzymeRules.forward(func::Const{typeof(fastpow)},  | 
 | 57 | +            RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,  | 
 | 58 | +            BatchDuplicated,BatchDuplicatedNoNeed}},  | 
 | 59 | +            _x::Annotation, _y::Annotation)  | 
 | 60 | +    x = _x.val  | 
 | 61 | +    y = _y.val  | 
 | 62 | +    ret = func.val(x.val, y.val)  | 
 | 63 | +    dxval = x.dval * y * (fastpow(x,y - 1))  | 
 | 64 | +    dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)  | 
 | 65 | +    return Duplicated(ret, dxval + dyval)  | 
 | 66 | +end  | 
 | 67 | + | 
 | 68 | +function EnzymeRules.augmented_primal(config::ConfigWidth{1},   | 
 | 69 | +                          func::Const{typeof(fastpow)},   | 
 | 70 | +                          ::Type{<:Active},  | 
 | 71 | +                          x::Active, x::Active)  | 
 | 72 | +    if EnzymeRules.needs_primal(config)  | 
 | 73 | +        primal = func.val(x.val, y.val)  | 
 | 74 | +    else  | 
 | 75 | +        primal = nothing  | 
 | 76 | +    end  | 
 | 77 | +    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)  | 
 | 78 | +end  | 
 | 79 | + | 
 | 80 | +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},   | 
 | 81 | +                            func::Const{DiffEqBase.fastpow}, dret, tape::Nothing,  | 
 | 82 | +                            _x, _y)  | 
 | 83 | +    x = _x.val  | 
 | 84 | +    y = _y.val  | 
 | 85 | +    dxval = x.dval * y * (fastpow(x,y - 1))  | 
 | 86 | +    dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)  | 
 | 87 | +    return (dxval, dyval)  | 
 | 88 | +end  | 
 | 89 | + | 
56 | 90 | end  | 
0 commit comments