Skip to content

Commit 7e3eff9

Browse files
Merge pull request #927 from wsmoses/master
Add Enzyme extension
2 parents eedbeef + 4f72bb1 commit 7e3eff9

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1010
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
13+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1314
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
1415
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1516
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
@@ -38,6 +39,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3839

3940
[weakdeps]
4041
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
42+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4143
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
4244
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4345
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
@@ -57,6 +59,7 @@ DiffEqBaseReverseDiffExt = "ReverseDiff"
5759
DiffEqBaseTrackerExt = "Tracker"
5860
DiffEqBaseUnitfulExt = "Unitful"
5961
DiffEqBaseZygoteExt = "Zygote"
62+
DiffEqBaseEnzymeExt = "Enzyme"
6063

6164
[compat]
6265
ArrayInterface = "7"

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module DiffEqBaseEnzymeExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: value
5+
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
6+
7+
using ChainRulesCore
8+
using EnzymeCore
9+
10+
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
11+
@inline function copy_or_reuse(val, idx)
12+
if EnzymeCore.EnzymeRules.overwritten(config)[idx] && ismutable(val)
13+
return deepcopy(val)
14+
else
15+
return val
16+
end
17+
end
18+
19+
@inline function arg_copy(i)
20+
copy_or_reuse(args[i].val, i+5)
21+
end
22+
23+
res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...;
24+
kwargs...)
25+
26+
dres = deepcopy(res[1])::RT
27+
for v in dres.u
28+
v.= 0
29+
end
30+
tup = (dres, res[2])
31+
return EnzymeCore.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
32+
end
33+
34+
function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{<:Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
35+
dres, clos = tape
36+
dres = dres::RT
37+
dargs = clos(dres)
38+
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
39+
if ptr isa EnzymeCore.Const
40+
continue
41+
end
42+
if darg == ChainRulesCore.NoTangent()
43+
continue
44+
end
45+
ptr.dval .+= darg
46+
end
47+
for v in dres.u
48+
v.= 0
49+
end
50+
return ntuple(_ -> nothing, Val(length(args)+4))
51+
end
52+
53+
end

src/init.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,9 @@ end
4747
@require MPI="da04e1cc-30fd-572f-bb4f-1f8673147195" begin
4848
include("../ext/DiffEqBaseMPIExt.jl")
4949
end
50+
51+
@require Enzyme="7da242da-08ed-463a-9acd-ee780be4f1d9" begin
52+
include("../ext/DiffEqBaseEnzymeExt.jl")
53+
end
5054
end
5155
end

0 commit comments

Comments
 (0)