|
1 | 1 | module DiffEqBaseEnzymeExt |
2 | 2 |
|
3 | | -using DiffEqBase |
4 | | -import DiffEqBase: value |
5 | | -using Enzyme |
6 | | -import Enzyme: Const |
7 | | -using ChainRulesCore |
8 | | - |
9 | | -function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
10 | | - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, |
11 | | - sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
12 | | - u0, p, args...; kwargs...) where {RT} |
13 | | - @inline function copy_or_reuse(val, idx) |
14 | | - if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) |
15 | | - return deepcopy(val) |
16 | | - else |
17 | | - return val |
| 3 | +@static if isempty(VERSION.prerelease) |
| 4 | + using DiffEqBase |
| 5 | + import DiffEqBase: value |
| 6 | + using Enzyme |
| 7 | + import Enzyme: Const |
| 8 | + using ChainRulesCore |
| 9 | + |
| 10 | + function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
| 11 | + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, |
| 12 | + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
| 13 | + u0, p, args...; kwargs...) where {RT} |
| 14 | + @inline function copy_or_reuse(val, idx) |
| 15 | + if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) |
| 16 | + return deepcopy(val) |
| 17 | + else |
| 18 | + return val |
| 19 | + end |
18 | 20 | end |
19 | | - end |
20 | 21 |
|
21 | | - @inline function arg_copy(i) |
22 | | - copy_or_reuse(args[i].val, i + 5) |
23 | | - end |
| 22 | + @inline function arg_copy(i) |
| 23 | + copy_or_reuse(args[i].val, i + 5) |
| 24 | + end |
24 | 25 |
|
25 | | - res = DiffEqBase._solve_adjoint( |
26 | | - copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), |
27 | | - copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), |
28 | | - SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; |
29 | | - kwargs...) |
| 26 | + res = DiffEqBase._solve_adjoint( |
| 27 | + copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), |
| 28 | + copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), |
| 29 | + SciMLBase.EnzymeOriginator(), ntuple(arg_copy, Val(length(args)))...; |
| 30 | + kwargs...) |
30 | 31 |
|
31 | | - dres = Enzyme.make_zero(res[1])::RT |
32 | | - tup = (dres, res[2]) |
33 | | - return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) |
34 | | -end |
| 32 | + dres = Enzyme.make_zero(res[1])::RT |
| 33 | + tup = (dres, res[2]) |
| 34 | + return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) |
| 35 | + end |
35 | 36 |
|
36 | | -function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
37 | | - func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, |
38 | | - sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
39 | | - u0, p, args...; kwargs...) where {RT} |
40 | | - dres, clos = tape |
41 | | - dres = dres::RT |
42 | | - dargs = clos(dres) |
43 | | - for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) |
44 | | - if ptr isa Enzyme.Const |
45 | | - continue |
| 37 | + function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, |
| 38 | + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, |
| 39 | + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, |
| 40 | + u0, p, args...; kwargs...) where {RT} |
| 41 | + dres, clos = tape |
| 42 | + dres = dres::RT |
| 43 | + dargs = clos(dres) |
| 44 | + for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) |
| 45 | + if ptr isa Enzyme.Const |
| 46 | + continue |
| 47 | + end |
| 48 | + if darg == ChainRulesCore.NoTangent() |
| 49 | + continue |
| 50 | + end |
| 51 | + ptr.dval .+= darg |
46 | 52 | end |
47 | | - if darg == ChainRulesCore.NoTangent() |
48 | | - continue |
49 | | - end |
50 | | - ptr.dval .+= darg |
| 53 | + Enzyme.make_zero!(dres.u) |
| 54 | + return ntuple(_ -> nothing, Val(length(args) + 4)) |
51 | 55 | end |
52 | | - Enzyme.make_zero!(dres.u) |
53 | | - return ntuple(_ -> nothing, Val(length(args) + 4)) |
| 56 | + |
54 | 57 | end |
55 | 58 |
|
56 | 59 | end |
0 commit comments