Skip to content

Commit bec8539

Browse files
Disable Enzyme extension on prerelease
This triggers failures here and downstream because if Enzyme is in the environment this tries to precompile. Thus because we know Enzyme won't work on prereleases, we are disabling the extension
1 parent a7c72ca commit bec8539

File tree

1 file changed

+47
-44
lines changed

1 file changed

+47
-44
lines changed

ext/DiffEqBaseEnzymeExt.jl

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,59 @@
11
module DiffEqBaseEnzymeExt
22

3-
using DiffEqBase
4-
import DiffEqBase: value
5-
using Enzyme
6-
import Enzyme: Const
7-
using ChainRulesCore
8-
9-
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1},
10-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
11-
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
12-
u0, p, args...; kwargs...) where {RT}
13-
@inline function copy_or_reuse(val, idx)
14-
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
15-
return deepcopy(val)
16-
else
17-
return val
3+
@static if isempty(VERSION.prerelease)
4+
using DiffEqBase
5+
import DiffEqBase: value
6+
using Enzyme
7+
import Enzyme: Const
8+
using ChainRulesCore
9+
10+
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1},
11+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
12+
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
13+
u0, p, args...; kwargs...) where {RT}
14+
@inline function copy_or_reuse(val, idx)
15+
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
16+
return deepcopy(val)
17+
else
18+
return val
19+
end
1820
end
19-
end
2021

21-
@inline function arg_copy(i)
22-
copy_or_reuse(args[i].val, i + 5)
23-
end
22+
@inline function arg_copy(i)
23+
copy_or_reuse(args[i].val, i + 5)
24+
end
2425

25-
res = DiffEqBase._solve_adjoint(
26-
copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
27-
copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
28-
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
29-
kwargs...)
26+
res = DiffEqBase._solve_adjoint(
27+
copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
28+
copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
29+
SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...;
30+
kwargs...)
3031

31-
dres = Enzyme.make_zero(res[1])::RT
32-
tup = (dres, res[2])
33-
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
34-
end
32+
dres = Enzyme.make_zero(res[1])::RT
33+
tup = (dres, res[2])
34+
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
35+
end
3536

36-
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
37-
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
38-
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
39-
u0, p, args...; kwargs...) where {RT}
40-
dres, clos = tape
41-
dres = dres::RT
42-
dargs = clos(dres)
43-
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
44-
if ptr isa Enzyme.Const
45-
continue
37+
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1},
38+
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
39+
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
40+
u0, p, args...; kwargs...) where {RT}
41+
dres, clos = tape
42+
dres = dres::RT
43+
dargs = clos(dres)
44+
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
45+
if ptr isa Enzyme.Const
46+
continue
47+
end
48+
if darg == ChainRulesCore.NoTangent()
49+
continue
50+
end
51+
ptr.dval .+= darg
4652
end
47-
if darg == ChainRulesCore.NoTangent()
48-
continue
49-
end
50-
ptr.dval .+= darg
53+
Enzyme.make_zero!(dres.u)
54+
return ntuple(_ -> nothing, Val(length(args) + 4))
5155
end
52-
Enzyme.make_zero!(dres.u)
53-
return ntuple(_ -> nothing, Val(length(args) + 4))
56+
5457
end
5558

5659
end

0 commit comments

Comments
 (0)