Skip to content

Commit 8c7b8f1

Browse files
WIP: add Enzyme support for fastpow
Straightforward since fastpow is simply ^. Still needs: - [ ] Tests - [ ] Generalize to batchduplicated
1 parent 84cbb9d commit 8c7b8f1

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DiffEqBaseEnzymeExt
22

33
using DiffEqBase
4-
import DiffEqBase: value
4+
import DiffEqBase: value, fastpow
55
using Enzyme
66
import Enzyme: Const
77
using ChainRulesCore
@@ -53,4 +53,38 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
5353
return ntuple(_ -> nothing, Val(length(args) + 4))
5454
end
5555

56+
function EnzymeRules.forward(func::Const{typeof(fastpow)},
57+
RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,
58+
BatchDuplicated,BatchDuplicatedNoNeed}},
59+
_x::Annotation, _y::Annotation)
60+
x = _x.val
61+
y = _y.val
62+
ret = func.val(x.val, y.val)
63+
dxval = x.dval * y * (fastpow(x,y - 1))
64+
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
65+
return Duplicated(ret, dxval + dyval)
66+
end
67+
68+
function EnzymeRules.augmented_primal(config::ConfigWidth{1},
69+
func::Const{typeof(fastpow)},
70+
::Type{<:Active},
71+
x::Active, x::Active)
72+
if EnzymeRules.needs_primal(config)
73+
primal = func.val(x.val, y.val)
74+
else
75+
primal = nothing
76+
end
77+
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
78+
end
79+
80+
function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1},
81+
func::Const{DiffEqBase.fastpow}, dret, tape::Nothing,
82+
_x, _y)
83+
x = _x.val
84+
y = _y.val
85+
dxval = x.dval * y * (fastpow(x,y - 1))
86+
dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : y.dval*(fastpow(x,y))*log(x)
87+
return (dxval, dyval)
88+
end
89+
5690
end

0 commit comments

Comments
 (0)