@@ -59,11 +59,17 @@ function solve_call(_prob,args...;merge_callbacks = true, kwargs...)
59
59
else
60
60
__solve (_prob,args... ; kwargs... )# ::T
61
61
end
62
+ end
62
63
64
+ function solve (prob:: DEProblem ,args... ;sensealg= nothing ,
65
+ u0 = nothing , p = nothing ,kwargs... )
66
+ u0 = u0 != = nothing ? u0 : prob. u0
67
+ p = p != = nothing ? p : prob. p
68
+ solve_up (prob,sensealg,u0,p,args... ;kwargs... )
63
69
end
64
70
65
- function solve (prob:: DEProblem ,args... ;kwargs... )
66
- _prob = get_concrete_problem (prob, kwargs)
71
+ function solve_up (prob:: DEProblem ,sensealg,u0,p ,args... ;kwargs... )
72
+ _prob = get_concrete_problem (prob;u0 = u0,p = p, kwargs... )
67
73
if haskey (kwargs,:alg ) && (isempty (args) || args[1 ] === nothing )
68
74
alg = kwargs[:alg ]
69
75
isadaptive (alg) &&
@@ -93,21 +99,21 @@ function solve(prob::EnsembleProblem,args...;kwargs...)
93
99
end
94
100
end
95
101
96
- function solve (prob:: AbstractNoiseProblem ,args... ;kwargs... )
102
+ function solve (prob:: AbstractNoiseProblem ,args... ; kwargs... )
97
103
__solve (prob,args... ;kwargs... )
98
104
end
99
105
100
- function get_concrete_problem (prob:: AbstractJumpProblem , kwargs)
106
+ function get_concrete_problem (prob:: AbstractJumpProblem ; kwargs... )
101
107
prob
102
108
end
103
109
104
- function get_concrete_problem (prob:: AbstractSteadyStateProblem , kwargs)
110
+ function get_concrete_problem (prob:: AbstractSteadyStateProblem ; kwargs... )
105
111
u0 = get_concrete_u0 (prob, Inf , kwargs)
106
112
u0 = promote_u0 (u0, prob. p, nothing )
107
113
remake (prob; u0 = u0)
108
114
end
109
115
110
- function get_concrete_problem (prob:: AbstractEnsembleProblem , kwargs)
116
+ function get_concrete_problem (prob:: AbstractEnsembleProblem ; kwargs... )
111
117
prob
112
118
end
113
119
@@ -118,45 +124,45 @@ end
118
124
119
125
function discretize end
120
126
121
- function get_concrete_problem (prob, kwargs)
122
- tspan = get_concrete_tspan (prob, kwargs)
127
+ function get_concrete_problem (prob; kwargs... )
128
+ p = get_concrete_p (prob, kwargs)
129
+ tspan = get_concrete_tspan (prob, kwargs, p)
123
130
u0 = get_concrete_u0 (prob, tspan[1 ], kwargs)
124
- u0_promote = promote_u0 (u0, prob . p, tspan[1 ])
125
- tspan_promote = promote_tspan (u0, prob . p, tspan, prob, kwargs)
131
+ u0_promote = promote_u0 (u0, p, tspan[1 ])
132
+ tspan_promote = promote_tspan (u0, p, tspan, prob, kwargs)
126
133
if isconcreteu0 (prob, tspan[1 ], kwargs) && typeof (u0_promote) === typeof (u0) &&
127
134
prob. tspan == tspan && typeof (tspan) === typeof (tspan_promote)
128
135
return prob
129
136
else
130
- return remake (prob; u0 = u0_promote, tspan = tspan_promote)
137
+ return remake (prob; u0 = u0_promote, p = p, tspan = tspan_promote)
131
138
end
132
139
end
133
140
134
- function get_concrete_problem (prob:: DDEProblem , kwargs)
135
- tspan = get_concrete_tspan (prob, kwargs)
141
+ function get_concrete_problem (prob:: DDEProblem ; kwargs... )
142
+ p = get_concrete_p (prob, kwargs)
143
+ tspan = get_concrete_tspan (prob, kwargs, p)
136
144
137
145
u0 = get_concrete_u0 (prob, tspan[1 ], kwargs)
138
146
139
147
if prob. constant_lags isa Function
140
- constant_lags = prob. constant_lags (prob . p)
148
+ constant_lags = prob. constant_lags (p)
141
149
else
142
150
constant_lags = prob. constant_lags
143
151
end
144
152
145
- u0 = promote_u0 (u0, prob . p, tspan[1 ])
146
- tspan = promote_tspan (u0, prob . p, tspan, prob, kwargs)
153
+ u0 = promote_u0 (u0, p, tspan[1 ])
154
+ tspan = promote_tspan (u0, p, tspan, prob, kwargs)
147
155
148
- remake (prob; u0 = u0, tspan = tspan, constant_lags = constant_lags)
156
+ remake (prob; u0 = u0, tspan = tspan, p = p, constant_lags = constant_lags)
149
157
end
150
158
151
- function get_concrete_tspan (prob, kwargs)
159
+ function get_concrete_tspan (prob, kwargs, p )
152
160
if prob. tspan isa Function
153
- tspan = prob. tspan (prob. p)
154
- elseif prob. tspan === (nothing , nothing )
155
- if haskey (kwargs, :tspan )
161
+ tspan = prob. tspan (p)
162
+ elseif haskey (kwargs, :tspan )
156
163
tspan = kwargs[:tspan ]
157
- else
158
- error (" No tspan is set in the problem or chosen in the init/solve call" )
159
- end
164
+ elseif prob. tspan === (nothing , nothing )
165
+ error (" No tspan is set in the problem or chosen in the init/solve call" )
160
166
else
161
167
tspan = prob. tspan
162
168
end
171
177
function get_concrete_u0 (prob, t0, kwargs)
172
178
if eval_u0 (prob. u0)
173
179
u0 = prob. u0 (prob. p, t0)
174
- elseif prob . u0 === nothing
180
+ elseif haskey (kwargs, :u0 )
175
181
u0 = kwargs[:u0 ]
176
182
else
177
183
u0 = prob. u0
@@ -180,6 +186,14 @@ function get_concrete_u0(prob, t0, kwargs)
180
186
handle_distribution_u0 (u0)
181
187
end
182
188
189
+ function get_concrete_p (prob, kwargs)
190
+ if haskey (kwargs,:p )
191
+ p = kwargs[:p ]
192
+ else
193
+ p = prob. p
194
+ end
195
+ end
196
+
183
197
handle_distribution_u0 (_u0) = _u0
184
198
eval_u0 (u0:: Function ) = true
185
199
eval_u0 (u0) = false
@@ -218,38 +232,49 @@ end
218
232
219
233
# ################## Concrete Solve
220
234
221
- function _concrete_solve end
235
+ @deprecate concrete_solve (prob:: DiffEqBase.DEProblem ,alg:: Union{DiffEqBase.DEAlgorithm,Nothing} ,
236
+ u0= prob. u0,p= prob. p,args... ;kwargs... ) solve (prob,alg,args... ;u0= u0,p= p,kwargs... )
222
237
223
- function concrete_solve (prob:: DiffEqBase.DEProblem ,alg:: Union{DiffEqBase.DEAlgorithm,Nothing} ,
224
- u0= prob. u0,p= prob. p,args... ;kwargs... )
225
- _concrete_solve (prob,alg,u0,p,args... ;kwargs... )
226
- end
238
+ struct SensitivityADPassThrough <: DiffEqBase.DEAlgorithm end
227
239
228
- function _concrete_solve (prob:: DiffEqBase.DEProblem ,alg :: Union{DiffEqBase.DEAlgorithm, Nothing} ,
229
- u0 = prob . u0,p = prob . p ,args... ;kwargs ... )
230
- sol = solve ( remake (prob,u0 = u0,p = p),alg,args ... ; kwargs... )
231
- RecursiveArrayTools . DiffEqArray (sol . u,sol . t )
240
+ ZygoteRules . @adjoint function solve_up (prob,sensealg :: Union{Nothing,AbstractSensitivityAlgorithm } ,
241
+ u0,p ,args... ;
242
+ kwargs... )
243
+ _solve_adjoint (prob,sensealg,u0,p,args ... ;kwargs ... )
232
244
end
233
245
234
- function _concrete_solve (prob:: DiffEqBase.SteadyStateProblem ,alg:: Union{DiffEqBase.DEAlgorithm,Nothing} ,
235
- u0= prob. u0,p= prob. p,args... ;kwargs... )
236
- sol = solve (remake (prob,u0= u0,p= p),alg,args... ;kwargs... )
237
- RecursiveArrayTools. VectorOfArray (sol. u)
246
+ function ChainRulesCore. frule (:: typeof (solve_up),prob,
247
+ sensealg:: Union{Nothing,AbstractSensitivityAlgorithm} ,
248
+ u0,p,args... ;
249
+ kwargs... )
250
+ _solve_forward (prob,sensealg,u0,p,args... ;kwargs... )
238
251
end
239
252
240
- function ChainRulesCore. frule (:: typeof (concrete_solve),prob,alg,u0,p,args... ;
241
- sensealg= nothing ,kwargs... )
242
- _concrete_solve_forward (prob,alg,sensealg,u0,p,args... ;kwargs... )
253
+ function ChainRulesCore. rrule (:: typeof (solve_up),prob,
254
+ sensealg:: Union{Nothing,AbstractSensitivityAlgorithm} ,
255
+ u0,p,args... ;
256
+ kwargs... )
257
+ _solve_adjoint (prob,sensealg,u0,p,args... ;kwargs... )
243
258
end
244
259
245
- function ChainRulesCore. rrule (:: typeof (concrete_solve),prob,alg,u0,p,args... ;
246
- sensealg= nothing ,kwargs... )
247
- _concrete_solve_adjoint (prob,alg,sensealg,u0,p,args... ;kwargs... )
260
+ # ##
261
+ # ## Legacy Dispatches to be Non-Breaking
262
+ # ##
263
+
264
+ function _solve_adjoint (prob,sensealg,u0,p,args... ;kwargs... )
265
+ if isempty (args)
266
+ _concrete_solve_adjoint (prob,nothing ,sensealg,u0,p;kwargs... )
267
+ else
268
+ _concrete_solve_adjoint (prob,args[1 ],sensealg,u0,p,Base. tail (args)... ;kwargs... )
269
+ end
248
270
end
249
271
250
- ZygoteRules. @adjoint function concrete_solve (prob,alg,u0,p,args... ;
251
- sensealg= nothing ,kwargs... )
252
- _concrete_solve_adjoint (prob,alg,sensealg,u0,p,args... ;kwargs... )
272
+ function _solve_forward (prob,sensealg,u0,p,args... ;kwargs... )
273
+ if isempty (args)
274
+ _concrete_solve_forward (prob,nothing ,sensealg,u0,p;kwargs... )
275
+ else
276
+ _concrete_solve_forward (prob,args[1 ],sensealg,u0,p,Base. tail (args)... ;kwargs... )
277
+ end
253
278
end
254
279
255
280
function _concrete_solve_adjoint (args... ;kwargs... )
0 commit comments