Skip to content

Commit a170fd9

Browse files
Setup solve for adjoints to deprecate concrete_solve
Fixes SciML/DifferentialEquations.jl#610 Is non-breaking
1 parent 00ad2e3 commit a170fd9

File tree

4 files changed

+92
-71
lines changed

4 files changed

+92
-71
lines changed

src/reversediff.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;
2-
sensealg=nothing,kwargs...)
3-
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
1+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p::ReverseDiff.TrackedArray,args...;kwargs...)
2+
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
43
end
54

6-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;
7-
sensealg=nothing,kwargs...)
8-
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
5+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::ReverseDiff.TrackedArray,args...;kwargs...)
6+
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
97
end
108

11-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;
12-
sensealg=nothing,kwargs...)
13-
ReverseDiff.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
9+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::ReverseDiff.TrackedArray,p,args...;kwargs...)
10+
ReverseDiff.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
1411
end
1512

16-
ReverseDiff.@grad function concrete_solve(prob,alg,u0,p,args...;
17-
sensealg=nothing,kwargs...)
18-
out = _concrete_solve_adjoint(prob,alg,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
13+
ReverseDiff.@grad function solve_up(prob,sensealg,u0,p,args...;kwargs...)
14+
out = _solve_adjoint(prob,sensealg,ReverseDiff.value(u0),ReverseDiff.value(p),args...;kwargs...)
1915
Array(out[1]),out[2]
2016
end

src/solve.jl

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,17 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
5959
else
6060
__solve(_prob,args...; kwargs...)#::T
6161
end
62+
end
6263

64+
function solve(prob::DEProblem,args...;sensealg=nothing,
65+
u0 = nothing, p = nothing,kwargs...)
66+
u0 = u0 !== nothing ? u0 : prob.u0
67+
p = p !== nothing ? p : prob.p
68+
solve_up(prob,sensealg,u0,p,args...;kwargs...)
6369
end
6470

65-
function solve(prob::DEProblem,args...;kwargs...)
66-
_prob = get_concrete_problem(prob,kwargs)
71+
function solve_up(prob::DEProblem,sensealg,u0,p,args...;kwargs...)
72+
_prob = get_concrete_problem(prob;u0=u0,p=p,kwargs...)
6773
if haskey(kwargs,:alg) && (isempty(args) || args[1] === nothing)
6874
alg = kwargs[:alg]
6975
isadaptive(alg) &&
@@ -93,21 +99,21 @@ function solve(prob::EnsembleProblem,args...;kwargs...)
9399
end
94100
end
95101

96-
function solve(prob::AbstractNoiseProblem,args...;kwargs...)
102+
function solve(prob::AbstractNoiseProblem,args...; kwargs...)
97103
__solve(prob,args...;kwargs...)
98104
end
99105

100-
function get_concrete_problem(prob::AbstractJumpProblem,kwargs)
106+
function get_concrete_problem(prob::AbstractJumpProblem; kwargs...)
101107
prob
102108
end
103109

104-
function get_concrete_problem(prob::AbstractSteadyStateProblem, kwargs)
110+
function get_concrete_problem(prob::AbstractSteadyStateProblem; kwargs...)
105111
u0 = get_concrete_u0(prob, Inf, kwargs)
106112
u0 = promote_u0(u0, prob.p, nothing)
107113
remake(prob; u0 = u0)
108114
end
109115

110-
function get_concrete_problem(prob::AbstractEnsembleProblem, kwargs)
116+
function get_concrete_problem(prob::AbstractEnsembleProblem; kwargs...)
111117
prob
112118
end
113119

@@ -118,45 +124,45 @@ end
118124

119125
function discretize end
120126

121-
function get_concrete_problem(prob, kwargs)
122-
tspan = get_concrete_tspan(prob, kwargs)
127+
function get_concrete_problem(prob; kwargs...)
128+
p = get_concrete_p(prob, kwargs)
129+
tspan = get_concrete_tspan(prob, kwargs, p)
123130
u0 = get_concrete_u0(prob, tspan[1], kwargs)
124-
u0_promote = promote_u0(u0, prob.p, tspan[1])
125-
tspan_promote = promote_tspan(u0, prob.p, tspan, prob, kwargs)
131+
u0_promote = promote_u0(u0, p, tspan[1])
132+
tspan_promote = promote_tspan(u0, p, tspan, prob, kwargs)
126133
if isconcreteu0(prob, tspan[1], kwargs) && typeof(u0_promote) === typeof(u0) &&
127134
prob.tspan == tspan && typeof(tspan) === typeof(tspan_promote)
128135
return prob
129136
else
130-
return remake(prob; u0 = u0_promote, tspan = tspan_promote)
137+
return remake(prob; u0 = u0_promote, p=p, tspan = tspan_promote)
131138
end
132139
end
133140

134-
function get_concrete_problem(prob::DDEProblem, kwargs)
135-
tspan = get_concrete_tspan(prob, kwargs)
141+
function get_concrete_problem(prob::DDEProblem; kwargs...)
142+
p = get_concrete_p(prob, kwargs)
143+
tspan = get_concrete_tspan(prob, kwargs, p)
136144

137145
u0 = get_concrete_u0(prob, tspan[1], kwargs)
138146

139147
if prob.constant_lags isa Function
140-
constant_lags = prob.constant_lags(prob.p)
148+
constant_lags = prob.constant_lags(p)
141149
else
142150
constant_lags = prob.constant_lags
143151
end
144152

145-
u0 = promote_u0(u0, prob.p, tspan[1])
146-
tspan = promote_tspan(u0, prob.p, tspan, prob, kwargs)
153+
u0 = promote_u0(u0, p, tspan[1])
154+
tspan = promote_tspan(u0, p, tspan, prob, kwargs)
147155

148-
remake(prob; u0 = u0, tspan = tspan, constant_lags = constant_lags)
156+
remake(prob; u0 = u0, tspan = tspan, p=p, constant_lags = constant_lags)
149157
end
150158

151-
function get_concrete_tspan(prob, kwargs)
159+
function get_concrete_tspan(prob, kwargs, p)
152160
if prob.tspan isa Function
153-
tspan = prob.tspan(prob.p)
154-
elseif prob.tspan === (nothing, nothing)
155-
if haskey(kwargs, :tspan)
161+
tspan = prob.tspan(p)
162+
elseif haskey(kwargs, :tspan)
156163
tspan = kwargs[:tspan]
157-
else
158-
error("No tspan is set in the problem or chosen in the init/solve call")
159-
end
164+
elseif prob.tspan === (nothing, nothing)
165+
error("No tspan is set in the problem or chosen in the init/solve call")
160166
else
161167
tspan = prob.tspan
162168
end
@@ -171,7 +177,7 @@ end
171177
function get_concrete_u0(prob, t0, kwargs)
172178
if eval_u0(prob.u0)
173179
u0 = prob.u0(prob.p, t0)
174-
elseif prob.u0 === nothing
180+
elseif haskey(kwargs,:u0)
175181
u0 = kwargs[:u0]
176182
else
177183
u0 = prob.u0
@@ -180,6 +186,14 @@ function get_concrete_u0(prob, t0, kwargs)
180186
handle_distribution_u0(u0)
181187
end
182188

189+
function get_concrete_p(prob, kwargs)
190+
if haskey(kwargs,:p)
191+
p = kwargs[:p]
192+
else
193+
p = prob.p
194+
end
195+
end
196+
183197
handle_distribution_u0(_u0) = _u0
184198
eval_u0(u0::Function) = true
185199
eval_u0(u0) = false
@@ -218,38 +232,49 @@ end
218232

219233
################### Concrete Solve
220234

221-
function _concrete_solve end
235+
@deprecate concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
236+
u0=prob.u0,p=prob.p,args...;kwargs...) solve(prob,alg,args...;u0=u0,p=p,kwargs...)
222237

223-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
224-
u0=prob.u0,p=prob.p,args...;kwargs...)
225-
_concrete_solve(prob,alg,u0,p,args...;kwargs...)
226-
end
238+
struct SensitivityADPassThrough <: DiffEqBase.DEAlgorithm end
227239

228-
function _concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
229-
u0=prob.u0,p=prob.p,args...;kwargs...)
230-
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
231-
RecursiveArrayTools.DiffEqArray(sol.u,sol.t)
240+
ZygoteRules.@adjoint function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
241+
u0,p,args...;
242+
kwargs...)
243+
_solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
232244
end
233245

234-
function _concrete_solve(prob::DiffEqBase.SteadyStateProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},
235-
u0=prob.u0,p=prob.p,args...;kwargs...)
236-
sol = solve(remake(prob,u0=u0,p=p),alg,args...;kwargs...)
237-
RecursiveArrayTools.VectorOfArray(sol.u)
246+
function ChainRulesCore.frule(::typeof(solve_up),prob,
247+
sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
248+
u0,p,args...;
249+
kwargs...)
250+
_solve_forward(prob,sensealg,u0,p,args...;kwargs...)
238251
end
239252

240-
function ChainRulesCore.frule(::typeof(concrete_solve),prob,alg,u0,p,args...;
241-
sensealg=nothing,kwargs...)
242-
_concrete_solve_forward(prob,alg,sensealg,u0,p,args...;kwargs...)
253+
function ChainRulesCore.rrule(::typeof(solve_up),prob,
254+
sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
255+
u0,p,args...;
256+
kwargs...)
257+
_solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
243258
end
244259

245-
function ChainRulesCore.rrule(::typeof(concrete_solve),prob,alg,u0,p,args...;
246-
sensealg=nothing,kwargs...)
247-
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
260+
###
261+
### Legacy Dispatches to be Non-Breaking
262+
###
263+
264+
function _solve_adjoint(prob,sensealg,u0,p,args...;kwargs...)
265+
if isempty(args)
266+
_concrete_solve_adjoint(prob,nothing,sensealg,u0,p;kwargs...)
267+
else
268+
_concrete_solve_adjoint(prob,args[1],sensealg,u0,p,Base.tail(args)...;kwargs...)
269+
end
248270
end
249271

250-
ZygoteRules.@adjoint function concrete_solve(prob,alg,u0,p,args...;
251-
sensealg=nothing,kwargs...)
252-
_concrete_solve_adjoint(prob,alg,sensealg,u0,p,args...;kwargs...)
272+
function _solve_forward(prob,sensealg,u0,p,args...;kwargs...)
273+
if isempty(args)
274+
_concrete_solve_forward(prob,nothing,sensealg,u0,p;kwargs...)
275+
else
276+
_concrete_solve_forward(prob,args[1],sensealg,u0,p,Base.tail(args)...;kwargs...)
277+
end
253278
end
254279

255280
function _concrete_solve_adjoint(args...;kwargs...)

src/tracker.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,20 @@ end
3030
end
3131
@inline ODE_DEFAULT_NORM(u::Tracker.TrackedReal,t::Tracker.TrackedReal) = abs(u)
3232

33-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;
34-
sensealg=nothing,kwargs...)
35-
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
33+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::Tracker.TrackedArray,p::Tracker.TrackedArray,args...;kwargs...)
34+
Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
3635
end
3736

38-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;
39-
sensealg=nothing,kwargs...)
40-
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
37+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;kwargs...)
38+
Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
4139
end
4240

43-
function concrete_solve(prob::DiffEqBase.DEProblem,alg::Union{DiffEqBase.DEAlgorithm,Nothing},u0::Tracker.TrackedArray,p,args...;
44-
sensealg=nothing,kwargs...)
45-
Tracker.track(concrete_solve,prob,alg,u0,p,args...;sensealg=sensealg,kwargs...)
41+
function solve_up(prob::DiffEqBase.DEProblem,sensealg::Union{AbstractSensitivityAlgorithm,Nothing},u0,p::Tracker.TrackedArray,args...;kwargs...)
42+
Tracker.track(solve_up,prob,sensealg,u0,p,args...;kwargs...)
4643
end
4744

48-
Tracker.@grad function concrete_solve(prob,alg,u0,p,args...;
49-
sensealg=nothing,kwargs...)
50-
_concrete_solve_adjoint(prob,alg,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...)
45+
Tracker.@grad function solve_up(prob,sensealg::Union{Nothing,AbstractSensitivityAlgorithm},
46+
u0,p,args...;
47+
kwargs...)
48+
_solve_adjoint(prob,sensealg,Tracker.data(u0),Tracker.data(p),args...;kwargs...)
5149
end

src/zygote.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#=
12
ZygoteRules.@adjoint function ODESolution(u,args...)
23
function ODESolutionAdjoint(ȳ)
34
(ȳ,ntuple(_->nothing, length(args))...)
@@ -32,6 +33,7 @@ ZygoteRules.@adjoint function getindex(sol::DESolution, i, j...)
3233
end
3334
sol[i,j...],DESolution_getindex_adjoint
3435
end
36+
=#
3537

3638
ZygoteRules.@adjoint function (f::ODEFunction)(u,p,t)
3739
if f.vjp === nothing

0 commit comments

Comments
 (0)