Skip to content

Commit 77692d6

Browse files
Merge pull request #1171 from SciML/enzyme_prerelease
Disable Enzyme extension on prerelease
2 parents a7c72ca + 6c321fa commit 77692d6

File tree

2 files changed

+48
-45
lines changed

2 files changed

+48
-45
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,59 @@
11
module DiffEqBaseEnzymeExt
22

3-
using DiffEqBase
4-
import DiffEqBase: value
5-
using Enzyme
6-
import Enzyme: Const
7-
using ChainRulesCore
8-
9-
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1},
10-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
11-
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
12-
u0, p, args...; kwargs...) where {RT}
13-
@inline function copy_or_reuse(val, idx)
14-
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
15-
return deepcopy(val)
16-
else
17-
return val
3+
@static if isempty(VERSION.prerelease)
4+
using DiffEqBase
5+
import DiffEqBase: value
6+
using Enzyme
7+
import Enzyme: Const
8+
using ChainRulesCore
9+
10+
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1},
11+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
12+
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
13+
u0, p, args...; kwargs...) where {RT}
14+
@inline function copy_or_reuse(val, idx)
15+
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
16+
return deepcopy(val)
17+
else
18+
return val
19+
end
1820
end
19-
end
2021

21-
@inline function arg_copy(i)
22-
copy_or_reuse(args[i].val, i + 5)
23-
end
22+
@inline function arg_copy(i)
23+
copy_or_reuse(args[i].val, i + 5)
24+
end
2425

25-
res = DiffEqBase._solve_adjoint(
26-
copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
27-
copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
28-
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
29-
kwargs...)
26+
res = DiffEqBase._solve_adjoint(
27+
copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
28+
copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
29+
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
30+
kwargs...)
3031

31-
dres = Enzyme.make_zero(res[1])::RT
32-
tup = (dres, res[2])
33-
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
34-
end
32+
dres = Enzyme.make_zero(res[1])::RT
33+
tup = (dres, res[2])
34+
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
35+
end
3536

36-
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
37-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
38-
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
39-
u0, p, args...; kwargs...) where {RT}
40-
dres, clos = tape
41-
dres = dres::RT
42-
dargs = clos(dres)
43-
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
44-
if ptr isa Enzyme.Const
45-
continue
37+
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
38+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
39+
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
40+
u0, p, args...; kwargs...) where {RT}
41+
dres, clos = tape
42+
dres = dres::RT
43+
dargs = clos(dres)
44+
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
45+
if ptr isa Enzyme.Const
46+
continue
47+
end
48+
if darg == ChainRulesCore.NoTangent()
49+
continue
50+
end
51+
ptr.dval .+= darg
4652
end
47-
if darg == ChainRulesCore.NoTangent()
48-
continue
49-
end
50-
ptr.dval .+= darg
53+
Enzyme.make_zero!(dres.u)
54+
return ntuple(_ -> nothing, Val(length(args) + 4))
5155
end
52-
Enzyme.make_zero!(dres.u)
53-
return ntuple(_ -> nothing, Val(length(args) + 4))
56+
5457
end
5558

5659
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
@time @safetestset "Unwrapping" include("downstream/unwrapping.jl")
7373
@time @safetestset "Callback BigFloats" include("downstream/bigfloat_events.jl")
7474
@time @safetestset "DE stats" include("downstream/stats_tests.jl")
75-
@time @safetestset "Ensemble AD Tests" include("downstream/ensemble_ad.jl")
75+
isempty(VERSION.prerelease) && @time @safetestset "Ensemble AD Tests" include("downstream/ensemble_ad.jl")
7676
@time @safetestset "Community Callback Tests" include("downstream/community_callback_tests.jl")
7777
@time @safetestset "AD via ode with complex numbers" include("downstream/complex_number_ad.jl")
7878
@time @testset "Distributed Ensemble Tests" include("downstream/distributed_ensemble.jl")

0 commit comments

Comments
 (0)