@@ -152,20 +152,32 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
152152
153153 nlequation! = @closure (out, u, p) -> begin
154154 update_coefficients! (M, u, p, t)
155- # M * (u-u0)/dt - f(u,p,t)
155+ # f(u,p,t) + M * (u0 - u)/dt
156156 tmp = isAD ? PreallocationTools. get_tmp (_tmp, u) : _tmp
157- @. tmp = (u - u0 ) / dt
157+ @. tmp = (u0 - u ) / dt
158158 mul! (_vec (out), M, _vec (tmp))
159159 f (tmp, u, p, t)
160- out .- = tmp
160+ out .+ = tmp
161161 nothing
162162 end
163163
164- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
164+ jac = if isnothing (f. jac)
165+ f. jac
166+ else
167+ @closure (J, u, p) -> begin
168+ # f(u,p,t) + M * (u0 - u)/dt
169+ # df(u,p,t)/du - M/dt
170+ f. jac (J, u, p, t)
171+ J .- = M .* inv (dt)
172+ nothing
173+ end
174+ end
165175
166176 nlfunc = NonlinearFunction (nlequation!;
167- jac_prototype = f. jac_prototype)
177+ jac_prototype = f. jac_prototype,
178+ jac = jac)
168179 nlprob = NonlinearProblem (nlfunc, integrator. u, p)
180+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
169181 nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
170182 reltol = integrator. opts. reltol)
171183 integrator. u .= nlsol. u
@@ -227,10 +239,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
227239 M * (u - u0) / dt - f (u, p, t)
228240 end
229241
242+ jac = if isnothing (f. jac)
243+ f. jac
244+ else
245+ @closure (u, p) -> begin
246+ return M * (u .- u0) ./ dt .- f. jac (u, p, t)
247+ end
248+ end
249+
230250 nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0)
231251
232252 nlfunc = NonlinearFunction (nlequation_oop;
233- jac_prototype = f. jac_prototype)
253+ jac_prototype = f. jac_prototype,
254+ jac = jac)
234255 nlprob = NonlinearProblem (nlfunc, u0)
235256 nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
236257 reltol = integrator. opts. reltol)
@@ -281,10 +302,20 @@ function _initialize_dae!(integrator, prob::DAEProblem,
281302 nothing
282303 end
283304
284- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
305+ jac = if isnothing (f. jac)
306+ f. jac
307+ else
308+ @closure (J, u, p) -> begin
309+ f. jac (J, u, p, inv (dt), t)
310+ nothing
311+ end
312+ end
285313
286- nlfunc = NonlinearFunction (nlequation!; jac_prototype = f. jac_prototype)
314+ nlfunc = NonlinearFunction (nlequation!;
315+ jac_prototype = f. jac_prototype,
316+ jac = jac)
287317 nlprob = NonlinearProblem (nlfunc, u0, p)
318+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
288319 nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
289320 reltol = integrator. opts. reltol)
290321
@@ -318,6 +349,16 @@ function _initialize_dae!(integrator, prob::DAEProblem,
318349 resid = f (integrator. du, u0, p, t)
319350 integrator. opts. internalnorm (resid, t) <= integrator. opts. abstol && return
320351
352+ jac = if isnothing (f. jac)
353+ f. jac
354+ else
355+ @closure (u, p) -> begin
356+ return f. jac (u, p, inv (dt), t)
357+ end
358+ end
359+ nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype,
360+ jac = jac)
361+ nlprob = NonlinearProblem (nlfunc, u0)
321362 nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0)
322363
323364 nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype)
0 commit comments