Skip to content

Commit 45f4520

Browse files
author
oscarddssmith
committed
add nlfunc to ODEFunction
1 parent 06864fd commit 45f4520

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

src/scimlfunctions.jl

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292+
- `nlfunc`: a `NonlinearFunction`
293+
- `nl_state_compres`: maps u->nlfunc_u
294+
- `nl_state_decompres`: maps nlfunc_u->u
292295
293296
## iip: In-Place vs Out-Of-Place
294297
@@ -401,8 +404,8 @@ automatically symbolically generating the Jacobian and more from the
401404
numerically-defined functions.
402405
"""
403406
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404-
O, TCV,
405-
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
407+
O, TCV, SYS, IProb, IProbMap,
408+
NLF<:Union{Nothing, NonlinearFunction}, NLSC, NLISC} <: AbstractODEFunction{iip}
406409
f::F
407410
mass_matrix::TMM
408411
analytic::Ta
@@ -421,6 +424,9 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
421424
sys::SYS
422425
initializeprob::IProb
423426
initializeprobmap::IProbMap
427+
nlfunc::NLF
428+
nl_state_compres::NLSC
429+
nl_state_decompres::NLISC
424430
end
425431

426432
@doc doc"""
@@ -517,8 +523,8 @@ information on generating the SplitFunction from this symbolic engine.
517523
"""
518524
struct SplitFunction{
519525
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
520-
TPJ, O,
521-
TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
526+
TPJ, O, TCV, SYS, IProb, IProbMap,
527+
NLF<:Union{Nothing, NonlinearFunction}, NLSC, NLISC} <: AbstractODEFunction{iip}
522528
f1::F1
523529
f2::F2
524530
mass_matrix::TMM
@@ -538,6 +544,9 @@ struct SplitFunction{
538544
sys::SYS
539545
initializeprob::IProb
540546
initializeprobmap::IProbMap
547+
nlfunc::NLF
548+
nl_state_compres::NLSC
549+
nl_state_decompres::NLISC
541550
end
542551

543552
@doc doc"""
@@ -2416,6 +2425,9 @@ function ODEFunction{iip, specialize}(f;
24162425
sys = __has_sys(f) ? f.sys : nothing,
24172426
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
24182427
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
2428+
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing
2429+
nl_state_compres = __has_nl_state_compres(f) ? f.nl_state_compres : identity
2430+
nl_state_decompres = __has_nl_state_decompres(f) ? f.nl_state_decompres : identity
24192431
) where {iip,
24202432
specialize
24212433
}
@@ -2471,12 +2483,13 @@ function ODEFunction{iip, specialize}(f;
24712483
Any, Any, Any, Any,
24722484
Any, Any, Any, typeof(jac_prototype),
24732485
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
2474-
Any,
2475-
typeof(_colorvec),
2476-
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2486+
Any,typeof(_colorvec),
2487+
typeof(sys), Any, Any,
2488+
Union{Nothing, NonlinearFunction}, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24772489
jvp, vjp, jac_prototype, sparsity, Wfact,
24782490
Wfact_t, W_prototype, paramjac,
2479-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2491+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2492+
nlfunc, nl_state_compres, nl_state_decompres)
24802493
elseif specialize === false
24812494
ODEFunction{iip, FunctionWrapperSpecialize,
24822495
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2486,10 +2499,14 @@ function ODEFunction{iip, specialize}(f;
24862499
typeof(observed),
24872500
typeof(_colorvec),
24882501
typeof(sys), typeof(initializeprob),
2489-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2502+
typeof(initializeprobmap),
2503+
typeof(nlfunc),
2504+
typeof(nl_state_compres),
2505+
typeof(nl_state_decompres)}(_f, mass_matrix, analytic, tgrad, jac,
24902506
jvp, vjp, jac_prototype, sparsity, Wfact,
24912507
Wfact_t, W_prototype, paramjac,
2492-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2508+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2509+
nlfunc, nl_state_compres, nl_state_decompres)
24932510
else
24942511
ODEFunction{iip, specialize,
24952512
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2499,10 +2516,14 @@ function ODEFunction{iip, specialize}(f;
24992516
typeof(observed),
25002517
typeof(_colorvec),
25012518
typeof(sys), typeof(initializeprob),
2502-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2519+
typeof(initializeprobmap),
2520+
typeof(nlfunc),
2521+
typeof(nl_state_compres),
2522+
typeof(nl_state_decompres)}(_f, mass_matrix, analytic, tgrad, jac,
25032523
jvp, vjp, jac_prototype, sparsity, Wfact,
25042524
Wfact_t, W_prototype, paramjac,
2505-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2525+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2526+
nlfunc, nl_state_compres, nl_state_decompres)
25062527
end
25072528
end
25082529

@@ -2519,10 +2540,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25192540
Any, Any, Any, Any, typeof(f.jac_prototype),
25202541
typeof(f.sparsity), Any, Any, Any,
25212542
Any, typeof(f.colorvec),
2522-
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2543+
typeof(f.sys), Any, Any
2544+
Union{Nothing, NonlinearFunction}, Any, Any
2545+
}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25232546
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25242547
f.Wfact_t, f.W_prototype, f.paramjac,
2525-
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
2548+
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
2549+
f.nlfunc, f.nl_state_compres, f.nl_state_decompres)
25262550
else
25272551
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25282552
typeof(f.analytic), typeof(f.tgrad),
@@ -2531,11 +2555,15 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25312555
typeof(f.paramjac),
25322556
typeof(f.observed), typeof(f.colorvec),
25332557
typeof(f.sys), typeof(f.initializeprob),
2534-
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2558+
typeof(f.initializeprobmap),
2559+
typeof(nlfunc),
2560+
typeof(nl_state_compres),
2561+
typeof(nl_state_decompres)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25352562
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25362563
f.Wfact_t, f.W_prototype, f.paramjac,
25372564
f.observed, f.colorvec, f.sys, f.initializeprob,
2538-
f.initializeprobmap)
2565+
f.initializeprobmap,
2566+
f.nlfunc, f.nl_state_compres, f.nl_state_decompres)
25392567
end
25402568
end
25412569

@@ -4336,6 +4364,9 @@ __has_analytic_full(f) = isdefined(f, :analytic_full)
43364364
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
43374365
__has_initializeprob(f) = isdefined(f, :initializeprob)
43384366
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
4367+
__has_nl_state_compres(f) = isdefined(f, :nl_state_compres)
4368+
__has_nl_state_decompres(f) = isdefined(f, :nl_state_decompres)
4369+
43394370

43404371
# compatibility
43414372
has_invW(f::AbstractSciMLFunction) = false

0 commit comments

Comments
 (0)