Skip to content

Commit 1effe56

Browse files
Merge pull request #1222 from wsmoses/cer
Correct enzymerules
2 parents 222eb95 + d739f7e commit 1effe56

File tree

2 files changed

+60
-31
lines changed

2 files changed

+60
-31
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ ChainRulesCore = "1"
7171
ConcreteStructs = "0.2.3"
7272
Distributions = "0.25"
7373
DocStringExtensions = "0.9"
74-
Enzyme = "0.13"
74+
Enzyme = "0.13.100"
7575
EnzymeCore = "0.7, 0.8"
7676
FastBroadcast = "0.3.5"
7777
FastClosures = "0.3.2"

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,54 +7,83 @@ module DiffEqBaseEnzymeExt
77
import Enzyme: Const
88
using ChainRulesCore
99

10+
11+
@inline function copy_or_reuse(config, val, idx)
12+
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
13+
return deepcopy(val)
14+
else
15+
return val
16+
end
17+
end
18+
19+
@inline function arg_copy(data, i)
20+
config, args = data
21+
copy_or_reuse(config, args[i].val, i + 5)
22+
end
23+
24+
# Note these following functions are generally not considered user facing from within Enzyme.
25+
# They enable additional performance/usability here (e.g. inactive kwargs).
26+
# Contact wsmoses@ before modifying (and beware their semantics may change without semver).
27+
28+
Enzyme.EnzymeRules.inactive_kwarg(::typeof(DiffEqBase.solve_up), prob, sensalg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, p, args...; kwargs...) = nothing
29+
30+
Enzyme.EnzymeRules.has_easy_rule(::typeof(DiffEqBase.solve_up), prob, sensalg::Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm}, u0, p, args...; kwargs...) = nothing
31+
1032
function Enzyme.EnzymeRules.augmented_primal(
1133
config::Enzyme.EnzymeRules.RevConfigWidth{1},
12-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
34+
func::Const{typeof(DiffEqBase.solve_up)}, RTA::Type{Duplicated{RT}}, prob,
1335
sensealg::Union{
1436
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
1537
u0, p, args...; kwargs...) where {RT}
16-
@inline function copy_or_reuse(val, idx)
17-
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
18-
return deepcopy(val)
19-
else
20-
return val
21-
end
22-
end
23-
24-
@inline function arg_copy(i)
25-
copy_or_reuse(args[i].val, i + 5)
26-
end
2738

2839
res = DiffEqBase._solve_adjoint(
29-
copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
30-
copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
31-
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
40+
copy_or_reuse(config, prob.val, 2), copy_or_reuse(config, sensealg.val, 3),
41+
copy_or_reuse(config, u0.val, 4), copy_or_reuse(config, p.val, 5),
42+
SciMLBase.EnzymeOriginator(), ntuple(Base.Fix1(arg_copy, (config, args)), Val(length(args)))...;
3243
kwargs...)
3344

34-
dres = Enzyme.make_zero(res[1])::RT
35-
tup = (dres, res[2])
36-
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
45+
primal = if Enzyme.EnzymeRules.needs_primal(config)
46+
res[1]
47+
else
48+
nothing
49+
end
50+
51+
shadow = if Enzyme.EnzymeRules.needs_shadow(config)
52+
Enzyme.make_zero(res[1])::RT
53+
else
54+
nothing
55+
end
56+
tup = if Enzyme.EnzymeRules.needs_shadow(config)
57+
(shadow, res[2])
58+
else
59+
nothing
60+
end
61+
return Enzyme.EnzymeRules.augmented_rule_return_type(config, RTA)(primal, shadow, tup)
3762
end
3863

3964
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
4065
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
4166
sensealg::Union{
4267
Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
4368
u0, p, args...; kwargs...) where {RT}
44-
dres, clos = tape
45-
dres = dres::RT
46-
dargs = clos(dres)
47-
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
48-
if ptr isa Enzyme.Const
49-
continue
50-
end
51-
if darg == ChainRulesCore.NoTangent()
52-
continue
69+
70+
if Enzyme.EnzymeRules.needs_shadow(config)
71+
dres, clos = tape
72+
dres = dres::RT
73+
dargs = clos(dres)
74+
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
75+
if ptr isa Enzyme.Const
76+
continue
77+
end
78+
if darg == ChainRulesCore.NoTangent()
79+
continue
80+
end
81+
ptr.dval .+= darg
5382
end
54-
ptr.dval .+= darg
83+
Enzyme.make_zero!(dres.u)
5584
end
56-
Enzyme.make_zero!(dres.u)
57-
return ntuple(_ -> nothing, Val(length(args) + 4))
85+
86+
return ntuple(Returns(nothing), Val(length(args) + 4))
5887
end
5988
end
6089

0 commit comments

Comments
 (0)