Skip to content

Commit c341819

Browse files
Merge pull request #800 from oscardssmith/os/ode-nlfunc-support
Add nlprob to ODEFunction
2 parents 4be9585 + a91d8b3 commit c341819

File tree

1 file changed

+41
-27
lines changed

1 file changed

+41
-27
lines changed

src/scimlfunctions.jl

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292+
- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp`
293+
where the nonlinear parameters are the tuple `(t, u_tmp, p)`.
294+
This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t`
295+
such that solving this function produces a solution to the implicit step of your solver.
292296
293297
## iip: In-Place vs Out-Of-Place
294298
@@ -401,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the
401405
numerically-defined functions.
402406
"""
403407
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404-
O, TCV,
405-
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
408+
O, TCV, SYS,
409+
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
406410
f::F
407411
mass_matrix::TMM
408412
analytic::Ta
@@ -423,6 +427,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
423427
update_initializeprob!::UIProb
424428
initializeprobmap::IProbMap
425429
initializeprobpmap::IProbPmap
430+
nlprob::NLP
426431
end
427432

428433
@doc doc"""
@@ -525,8 +530,8 @@ information on generating the SplitFunction from this symbolic engine.
525530
"""
526531
struct SplitFunction{
527532
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
528-
TPJ, O,
529-
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
533+
TPJ, O, TCV, SYS,
534+
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
530535
f1::F1
531536
f2::F2
532537
mass_matrix::TMM
@@ -549,6 +554,7 @@ struct SplitFunction{
549554
update_initializeprob!::UIProb
550555
initializeprobmap::IProbMap
551556
initializeprobpmap::IProbPmap
557+
nlprob::NLP
552558
end
553559

554560
@doc doc"""
@@ -2432,7 +2438,8 @@ function ODEFunction{iip, specialize}(f;
24322438
update_initializeprob! = __has_update_initializeprob!(f) ?
24332439
f.update_initializeprob! : nothing,
24342440
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2435-
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
2441+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
2442+
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
24362443
) where {iip,
24372444
specialize
24382445
}
@@ -2490,11 +2497,11 @@ function ODEFunction{iip, specialize}(f;
24902497
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24912498
Any,
24922499
typeof(_colorvec),
2493-
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2500+
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24942501
jvp, vjp, jac_prototype, sparsity, Wfact,
24952502
Wfact_t, W_prototype, paramjac,
24962503
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2497-
initializeprobpmap)
2504+
initializeprobpmap, nlprob)
24982505
elseif specialize === false
24992506
ODEFunction{iip, FunctionWrapperSpecialize,
25002507
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2503,13 +2510,16 @@ function ODEFunction{iip, specialize}(f;
25032510
typeof(paramjac),
25042511
typeof(observed),
25052512
typeof(_colorvec),
2506-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2507-
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
2513+
typeof(sys), typeof(initializeprob),
2514+
typeof(update_initializeprob!),
2515+
typeof(initializeprobmap),
2516+
typeof(initializeprobpmap),
2517+
typeof(nlprob)}(_f, mass_matrix,
25082518
analytic, tgrad, jac,
25092519
jvp, vjp, jac_prototype, sparsity, Wfact,
25102520
Wfact_t, W_prototype, paramjac,
2511-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2512-
initializeprobpmap)
2521+
observed, _colorvec, sys, initializeprob, update_initializeprob!,
2522+
initializeprobmap, initializeprobpmap, nlprob)
25132523
else
25142524
ODEFunction{iip, specialize,
25152525
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2520,11 +2530,12 @@ function ODEFunction{iip, specialize}(f;
25202530
typeof(_colorvec),
25212531
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
25222532
typeof(initializeprobmap),
2523-
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
2533+
typeof(initializeprobpmap),
2534+
typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac,
25242535
jvp, vjp, jac_prototype, sparsity, Wfact,
25252536
Wfact_t, W_prototype, paramjac,
25262537
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2527-
initializeprobpmap)
2538+
initializeprobpmap, nlprob)
25282539
end
25292540
end
25302541

@@ -2541,13 +2552,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25412552
Any, Any, Any, Any, typeof(f.jac_prototype),
25422553
typeof(f.sparsity), Any, Any, Any,
25432554
Any, typeof(f.colorvec),
2544-
typeof(f.sys), Any, Any, Any, Any}(
2555+
typeof(f.sys), Any, Any, Any, Any, Any}(
25452556
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25462557
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25472558
f.Wfact_t, f.W_prototype, f.paramjac,
25482559
f.observed, f.colorvec, f.sys, f.initializeprob,
25492560
f.update_initializeprob!, f.initializeprobmap,
2550-
f.initializeprobpmap)
2561+
f.initializeprobpmap, f.nlprob)
25512562
else
25522563
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25532564
typeof(f.analytic), typeof(f.tgrad),
@@ -2557,11 +2568,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25572568
typeof(f.observed), typeof(f.colorvec),
25582569
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
25592570
typeof(f.initializeprobmap),
2560-
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2571+
typeof(f.initializeprobpmap),
2572+
typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25612573
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25622574
f.Wfact_t, f.W_prototype, f.paramjac,
25632575
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
2564-
f.initializeprobmap, f.initializeprobpmap)
2576+
f.initializeprobmap, f.initializeprobpmap, f.nlprob)
25652577
end
25662578
end
25672579

@@ -2693,7 +2705,7 @@ end
26932705
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
26942706
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
26952707
observed, colorvec, sys, initializeprob, update_initializeprob!,
2696-
initializeprobmap, initializeprobpmap)
2708+
initializeprobmap, initializeprobpmap, nlprob)
26972709
f1 = ODEFunction(f1)
26982710
f2 = ODEFunction(f2)
26992711

@@ -2708,11 +2720,11 @@ end
27082720
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27092721
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
27102722
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
2711-
typeof(initializeprobpmap)}(
2723+
typeof(initializeprobpmap), typeof(nlprob)}(
27122724
f1, f2, mass_matrix,
27132725
cache, analytic, tgrad, jac, jvp, vjp,
27142726
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2715-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
2727+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
27162728
end
27172729
function SplitFunction{iip, specialize}(f1, f2;
27182730
mass_matrix = __has_mass_matrix(f1) ?
@@ -2748,7 +2760,8 @@ function SplitFunction{iip, specialize}(f1, f2;
27482760
update_initializeprob! = __has_update_initializeprob!(f1) ?
27492761
f1.update_initializeprob! : nothing,
27502762
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
2751-
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
2763+
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
2764+
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing
27522765
) where {iip,
27532766
specialize
27542767
}
@@ -2759,12 +2772,12 @@ function SplitFunction{iip, specialize}(f1, f2;
27592772
if specialize === NoSpecialize
27602773
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
27612774
Any, Any, Any, Any, Any, Any, Any,
2762-
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
2775+
Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
27632776
analytic,
27642777
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27652778
sparsity, Wfact, Wfact_t, paramjac,
27662779
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
2767-
initializeprobpmap, initializeprobpmap)
2780+
initializeprobpmap, initializeprobpmap, nlprob)
27682781
else
27692782
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27702783
typeof(_func_cache), typeof(analytic),
@@ -2774,11 +2787,11 @@ function SplitFunction{iip, specialize}(f1, f2;
27742787
typeof(colorvec),
27752788
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
27762789
typeof(initializeprobmap),
2777-
typeof(initializeprobpmap)}(f1, f2,
2790+
typeof(initializeprobpmap), typeof(nlprob)}(f1, f2,
27782791
mass_matrix, _func_cache, analytic, tgrad, jac,
27792792
jvp, vjp, jac_prototype, W_prototype,
27802793
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2781-
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
2794+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
27822795
end
27832796
end
27842797

@@ -3121,7 +3134,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f
31213134

31223135
@add_kwonly function SplitSDEFunction(f1, f2, g, mass_matrix, cache, analytic, tgrad, jac,
31233136
jvp, vjp,
3124-
jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed,
3137+
jac_prototype, Wfact, Wfact_t, paramjac, observed,
31253138
colorvec, sys)
31263139
f1 = f1 isa AbstractSciMLOperator ? f1 : SDEFunction(f1)
31273140
f2 = SDEFunction(f2)
@@ -3132,7 +3145,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f
31323145
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
31333146
typeof(colorvec),
31343147
typeof(sys)}(f1, f2, mass_matrix, cache, analytic, tgrad, jac,
3135-
jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys)
3148+
jac_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys)
31363149
end
31373150

31383151
function SplitSDEFunction{iip, specialize}(f1, f2, g;
@@ -4411,6 +4424,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
44114424
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
44124425
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
44134426
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
4427+
__has_nlprob(f) = isdefined(f, :nlprob)
44144428

44154429
# compatibility
44164430
has_invW(f::AbstractSciMLFunction) = false

0 commit comments

Comments
 (0)