|
1 | 1 | module DiffEqBaseEnzymeExt
|
2 | 2 |
|
3 |
| -using DiffEqBase |
4 |
| -import DiffEqBase: value |
5 |
| -using Enzyme |
6 |
| -import Enzyme: Const |
7 |
| -using ChainRulesCore |
8 |
| - |
9 |
| -function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
10 |
| - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, |
11 |
| - sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
12 |
| - u0, p, args...; kwargs...) where {RT} |
13 |
| - @inline function copy_or_reuse(val, idx) |
14 |
| - if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) |
15 |
| - return deepcopy(val) |
16 |
| - else |
17 |
| - return val |
| 3 | +@static if isempty(VERSION.prerelease) |
| 4 | + using DiffEqBase |
| 5 | + import DiffEqBase: value |
| 6 | + using Enzyme |
| 7 | + import Enzyme: Const |
| 8 | + using ChainRulesCore |
| 9 | + |
| 10 | + function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
| 11 | + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, |
| 12 | + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
| 13 | + u0, p, args...; kwargs...) where {RT} |
| 14 | + @inline function copy_or_reuse(val, idx) |
| 15 | + if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) |
| 16 | + return deepcopy(val) |
| 17 | + else |
| 18 | + return val |
| 19 | + end |
18 | 20 | end
|
19 |
| - end |
20 | 21 |
|
21 |
| - @inline function arg_copy(i) |
22 |
| - copy_or_reuse(args[i].val, i + 5) |
23 |
| - end |
| 22 | + @inline function arg_copy(i) |
| 23 | + copy_or_reuse(args[i].val, i + 5) |
| 24 | + end |
24 | 25 |
|
25 |
| - res = DiffEqBase._solve_adjoint( |
26 |
| - copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), |
27 |
| - copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), |
28 |
| - SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; |
29 |
| - kwargs...) |
| 26 | + res = DiffEqBase._solve_adjoint( |
| 27 | + copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), |
| 28 | + copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), |
| 29 | + SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; |
| 30 | + kwargs...) |
30 | 31 |
|
31 |
| - dres = Enzyme.make_zero(res[1])::RT |
32 |
| - tup = (dres, res[2]) |
33 |
| - return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) |
34 |
| -end |
| 32 | + dres = Enzyme.make_zero(res[1])::RT |
| 33 | + tup = (dres, res[2]) |
| 34 | + return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) |
| 35 | + end |
35 | 36 |
|
36 |
| -function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
37 |
| - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, |
38 |
| - sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
39 |
| - u0, p, args...; kwargs...) where {RT} |
40 |
| - dres, clos = tape |
41 |
| - dres = dres::RT |
42 |
| - dargs = clos(dres) |
43 |
| - for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) |
44 |
| - if ptr isa Enzyme.Const |
45 |
| - continue |
| 37 | + function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
| 38 | + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, |
| 39 | + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
| 40 | + u0, p, args...; kwargs...) where {RT} |
| 41 | + dres, clos = tape |
| 42 | + dres = dres::RT |
| 43 | + dargs = clos(dres) |
| 44 | + for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) |
| 45 | + if ptr isa Enzyme.Const |
| 46 | + continue |
| 47 | + end |
| 48 | + if darg == ChainRulesCore.NoTangent() |
| 49 | + continue |
| 50 | + end |
| 51 | + ptr.dval .+= darg |
46 | 52 | end
|
47 |
| - if darg == ChainRulesCore.NoTangent() |
48 |
| - continue |
49 |
| - end |
50 |
| - ptr.dval .+= darg |
| 53 | + Enzyme.make_zero!(dres.u) |
| 54 | + return ntuple(_ -> nothing, Val(length(args) + 4)) |
51 | 55 | end
|
52 |
| - Enzyme.make_zero!(dres.u) |
53 |
| - return ntuple(_ -> nothing, Val(length(args) + 4)) |
| 56 | + |
54 | 57 | end
|
55 | 58 |
|
56 | 59 | end
|
0 commit comments