Skip to content

Commit 3ded92b

Browse files
author
oscarddssmith
committed
add nlfunc to ODEFunction
1 parent 17f4548 commit 3ded92b

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

src/scimlfunctions.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292+
- `nlfunc`: a `NonlinearFunction`
292293
293294
## iip: In-Place vs Out-Of-Place
294295
@@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the
401402
numerically-defined functions.
402403
"""
403404
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404-
O, TCV,
405-
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
405+
O, TCV, SYS,
406+
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
406407
f::F
407408
mass_matrix::TMM
408409
analytic::Ta
@@ -423,6 +424,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
423424
update_initializeprob!::UIProb
424425
initializeprobmap::IProbMap
425426
initializeprobpmap::IProbPmap
427+
nlfunc::NLF
426428
end
427429

428430
@doc doc"""
@@ -519,8 +521,8 @@ information on generating the SplitFunction from this symbolic engine.
519521
"""
520522
struct SplitFunction{
521523
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
522-
TPJ, O,
523-
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
524+
TPJ, O, TCV, SYS,
525+
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
524526
f1::F1
525527
f2::F2
526528
mass_matrix::TMM
@@ -543,6 +545,7 @@ struct SplitFunction{
543545
update_initializeprob!::UIProb
544546
initializeprobmap::IProbMap
545547
initializeprobpmap::IProbPmap
548+
nlfunc::NLF
546549
end
547550

548551
@doc doc"""
@@ -2420,7 +2423,8 @@ function ODEFunction{iip, specialize}(f;
24202423
update_initializeprob! = __has_update_initializeprob!(f) ?
24212424
f.update_initializeprob! : nothing,
24222425
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2423-
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
2426+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
2427+
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing,
24242428
) where {iip,
24252429
specialize
24262430
}
@@ -2478,11 +2482,11 @@ function ODEFunction{iip, specialize}(f;
24782482
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24792483
Any,
24802484
typeof(_colorvec),
2481-
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2485+
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24822486
jvp, vjp, jac_prototype, sparsity, Wfact,
24832487
Wfact_t, W_prototype, paramjac,
24842488
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2485-
initializeprobpmap)
2489+
initializeprobpmap, nlfunc)
24862490
elseif specialize === false
24872491
ODEFunction{iip, FunctionWrapperSpecialize,
24882492
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2491,13 +2495,15 @@ function ODEFunction{iip, specialize}(f;
24912495
typeof(paramjac),
24922496
typeof(observed),
24932497
typeof(_colorvec),
2494-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2495-
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
2498+
typeof(sys), typeof(initializeprob),
2499+
typeof(update_initializeprob!),
2500+
typeof(initializeprobmap),
2501+
typeof(initializeprobpmap),
2502+
typeof(nlfunc)}(_f, mass_matrix,
24962503
analytic, tgrad, jac,
24972504
jvp, vjp, jac_prototype, sparsity, Wfact,
24982505
Wfact_t, W_prototype, paramjac,
2499-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2500-
initializeprobpmap)
2506+
observed, _colorvec, sys, initializeprob, initializeprobmap, nlfunc)
25012507
else
25022508
ODEFunction{iip, specialize,
25032509
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2508,11 +2514,12 @@ function ODEFunction{iip, specialize}(f;
25082514
typeof(_colorvec),
25092515
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
25102516
typeof(initializeprobmap),
2511-
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
2517+
typeof(initializeprobpmap)
2518+
typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac,
25122519
jvp, vjp, jac_prototype, sparsity, Wfact,
25132520
Wfact_t, W_prototype, paramjac,
25142521
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2515-
initializeprobpmap)
2522+
initializeprobpmap, nlfunc)
25162523
end
25172524
end
25182525

@@ -2529,13 +2536,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25292536
Any, Any, Any, Any, typeof(f.jac_prototype),
25302537
typeof(f.sparsity), Any, Any, Any,
25312538
Any, typeof(f.colorvec),
2532-
typeof(f.sys), Any, Any, Any, Any}(
2539+
typeof(f.sys), Any, Any, Any, Any, Any}(
25332540
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25342541
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25352542
f.Wfact_t, f.W_prototype, f.paramjac,
25362543
f.observed, f.colorvec, f.sys, f.initializeprob,
25372544
f.update_initializeprob!, f.initializeprobmap,
2538-
f.initializeprobpmap)
2545+
f.initializeprobpmap, f.nlfunc)
25392546
else
25402547
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25412548
typeof(f.analytic), typeof(f.tgrad),
@@ -2545,11 +2552,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25452552
typeof(f.observed), typeof(f.colorvec),
25462553
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
25472554
typeof(f.initializeprobmap),
2548-
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2555+
typeof(f.initializeprobpmap),
2556+
typof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25492557
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25502558
f.Wfact_t, f.W_prototype, f.paramjac,
25512559
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
2552-
f.initializeprobmap, f.initializeprobpmap)
2560+
f.initializeprobmap, f.initializeprobpmap, f.nlfunc)
25532561
end
25542562
end
25552563

@@ -4370,6 +4378,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
43704378
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
43714379
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
43724380
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
4381+
__has_nlfunc(f) = isdefined(f, :nl_func)
43734382

43744383
# compatibility
43754384
has_invW(f::AbstractSciMLFunction) = false

0 commit comments

Comments
 (0)