|
1 | 1 | module DiffEqBaseEnzymeExt |
2 | 2 |
|
3 | 3 | using DiffEqBase |
4 | | -import DiffEqBase: value |
| 4 | +import DiffEqBase: value, fastpow |
5 | 5 | using Enzyme |
6 | 6 | import Enzyme: Const |
7 | 7 | using ChainRulesCore |
@@ -53,4 +53,38 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1} |
53 | 53 | return ntuple(_ -> nothing, Val(length(args) + 4)) |
54 | 54 | end |
55 | 55 |
|
| 56 | +function EnzymeRules.forward(func::Const{typeof(fastpow)}, |
| 57 | + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, |
| 58 | + BatchDuplicated,BatchDuplicatedNoNeed}}, |
| 59 | + _x::Annotation, _y::Annotation) |
| 60 | + x = _x.val |
| 61 | + y = _y.val |
| 62 | + ret = func.val(x.val, y.val) |
| 63 | + dxval = x.dval * y * (fastpow(x,y - 1)) |
| 64 | + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) |
| 65 | + return Duplicated(ret, dxval + dyval) |
| 66 | +end |
| 67 | + |
| 68 | +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, |
| 69 | + func::Const{typeof(fastpow)}, |
| 70 | + ::Type{<:Active}, |
| 71 | + x::Active, x::Active) |
| 72 | + if EnzymeRules.needs_primal(config) |
| 73 | + primal = func.val(x.val, y.val) |
| 74 | + else |
| 75 | + primal = nothing |
| 76 | + end |
| 77 | + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) |
| 78 | +end |
| 79 | + |
| 80 | +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, |
| 81 | + func::Const{DiffEqBase.fastpow}, dret, tape::Nothing, |
| 82 | + _x, _y) |
| 83 | + x = _x.val |
| 84 | + y = _y.val |
| 85 | + dxval = x.dval * y * (fastpow(x,y - 1)) |
| 86 | + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x) |
| 87 | + return (dxval, dyval) |
| 88 | +end |
| 89 | + |
56 | 90 | end |
0 commit comments