@@ -7,54 +7,83 @@ module DiffEqBaseEnzymeExt
77 import Enzyme: Const
88 using ChainRulesCore
99
10+
11+ @inline function copy_or_reuse (config, val, idx)
12+ if Enzyme. EnzymeRules. overwritten (config)[idx] && ismutable (val)
13+ return deepcopy (val)
14+ else
15+ return val
16+ end
17+ end
18+
19+ @inline function arg_copy (data, i)
20+ config, args = data
21+ copy_or_reuse (config, args[i]. val, i + 5 )
22+ end
23+
24+ # Note these following functions are generally not considered user facing from within Enzyme.
25+ # They enable additional performance/usability here (e.g. inactive kwargs).
26+ # Contact wsmoses@ before modifying (and beware their semantics may change without semver).
27+
28+ Enzyme. EnzymeRules. inactive_kwarg (:: typeof (DiffEqBase. solve_up), prob, sensalg:: Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm} , u0, p, args... ; kwargs... ) = nothing
29+
30+ Enzyme. EnzymeRules. has_easy_rule (:: typeof (DiffEqBase. solve_up), prob, sensalg:: Union{Nothing, DiffEqBase.AbstractSensitivityAlgorithm} , u0, p, args... ; kwargs... ) = nothing
31+
1032 function Enzyme. EnzymeRules. augmented_primal (
1133 config:: Enzyme.EnzymeRules.RevConfigWidth{1} ,
12- func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT}} , prob,
34+ func:: Const{typeof(DiffEqBase.solve_up)} , RTA :: Type{Duplicated{RT}} , prob,
1335 sensealg:: Union {
1436 Const{Nothing}, Const{<: DiffEqBase.AbstractSensitivityAlgorithm }},
1537 u0, p, args... ; kwargs... ) where {RT}
16- @inline function copy_or_reuse (val, idx)
17- if Enzyme. EnzymeRules. overwritten (config)[idx] && ismutable (val)
18- return deepcopy (val)
19- else
20- return val
21- end
22- end
23-
24- @inline function arg_copy (i)
25- copy_or_reuse (args[i]. val, i + 5 )
26- end
2738
2839 res = DiffEqBase. _solve_adjoint (
29- copy_or_reuse (prob. val, 2 ), copy_or_reuse (sensealg. val, 3 ),
30- copy_or_reuse (u0. val, 4 ), copy_or_reuse (p. val, 5 ),
31- SciMLBase. EnzymeOriginator (), ntuple (arg_copy, Val (length (args)))... ;
40+ copy_or_reuse (config, prob. val, 2 ), copy_or_reuse (config, sensealg. val, 3 ),
41+ copy_or_reuse (config, u0. val, 4 ), copy_or_reuse (config, p. val, 5 ),
42+ SciMLBase. EnzymeOriginator (), ntuple (Base . Fix1 ( arg_copy, (config, args)) , Val (length (args)))... ;
3243 kwargs... )
3344
34- dres = Enzyme. make_zero (res[1 ]):: RT
35- tup = (dres, res[2 ])
36- return Enzyme. EnzymeRules. AugmentedReturn {RT, RT, Any} (res[1 ], dres, tup:: Any )
45+ primal = if Enzyme. EnzymeRules. needs_primal (config)
46+ res[1 ]
47+ else
48+ nothing
49+ end
50+
51+ shadow = if Enzyme. EnzymeRules. needs_shadow (config)
52+ Enzyme. make_zero (res[1 ]):: RT
53+ else
54+ nothing
55+ end
56+ tup = if Enzyme. EnzymeRules. needs_shadow (config)
57+ (shadow, res[2 ])
58+ else
59+ nothing
60+ end
61+ return Enzyme. EnzymeRules. augmented_rule_return_type (config, RTA)(primal, shadow, tup)
3762 end
3863
3964 function Enzyme. EnzymeRules. reverse (config:: Enzyme.EnzymeRules.RevConfigWidth{1} ,
4065 func:: Const{typeof(DiffEqBase.solve_up)} , :: Type{Duplicated{RT}} , tape, prob,
4166 sensealg:: Union {
4267 Const{Nothing}, Const{<: DiffEqBase.AbstractSensitivityAlgorithm }},
4368 u0, p, args... ; kwargs... ) where {RT}
44- dres, clos = tape
45- dres = dres:: RT
46- dargs = clos (dres)
47- for (darg, ptr) in zip (dargs, (func, prob, sensealg, u0, p, args... ))
48- if ptr isa Enzyme. Const
49- continue
50- end
51- if darg == ChainRulesCore. NoTangent ()
52- continue
69+
70+ if Enzyme. EnzymeRules. needs_shadow (config)
71+ dres, clos = tape
72+ dres = dres:: RT
73+ dargs = clos (dres)
74+ for (darg, ptr) in zip (dargs, (func, prob, sensealg, u0, p, args... ))
75+ if ptr isa Enzyme. Const
76+ continue
77+ end
78+ if darg == ChainRulesCore. NoTangent ()
79+ continue
80+ end
81+ ptr. dval .+ = darg
5382 end
54- ptr . dval .+ = darg
83+ Enzyme . make_zero! (dres . u)
5584 end
56- Enzyme . make_zero! (dres . u)
57- return ntuple (_ -> nothing , Val (length (args) + 4 ))
85+
86+ return ntuple (Returns ( nothing ) , Val (length (args) + 4 ))
5887 end
5988end
6089
0 commit comments