Skip to content

Commit f4d8e36

Browse files
author
oscarddssmith
committed
add nlfunc to ODEFunction
1 parent 06864fd commit f4d8e36

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

src/scimlfunctions.jl

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292+
- `nlfunc`: a `NonlinearFunction`
292293
293294
## iip: In-Place vs Out-Of-Place
294295
@@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the
401402
numerically-defined functions.
402403
"""
403404
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404-
O, TCV,
405-
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
405+
O, TCV, SYS, IProb, IProbMap,
406+
NLF} <: AbstractODEFunction{iip}
406407
f::F
407408
mass_matrix::TMM
408409
analytic::Ta
@@ -421,6 +422,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
421422
sys::SYS
422423
initializeprob::IProb
423424
initializeprobmap::IProbMap
425+
nlfunc::NLF
424426
end
425427

426428
@doc doc"""
@@ -517,8 +519,8 @@ information on generating the SplitFunction from this symbolic engine.
517519
"""
518520
struct SplitFunction{
519521
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt,
520-
TPJ, O,
521-
TCV, SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
522+
TPJ, O, TCV, SYS, IProb, IProbMap,
523+
NLF} <: AbstractODEFunction{iip}
522524
f1::F1
523525
f2::F2
524526
mass_matrix::TMM
@@ -538,6 +540,7 @@ struct SplitFunction{
538540
sys::SYS
539541
initializeprob::IProb
540542
initializeprobmap::IProbMap
543+
nlfunc::NLF
541544
end
542545

543546
@doc doc"""
@@ -2415,7 +2418,8 @@ function ODEFunction{iip, specialize}(f;
24152418
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
24162419
sys = __has_sys(f) ? f.sys : nothing,
24172420
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
2418-
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing
2421+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2422+
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing,
24192423
) where {iip,
24202424
specialize
24212425
}
@@ -2471,12 +2475,13 @@ function ODEFunction{iip, specialize}(f;
24712475
Any, Any, Any, Any,
24722476
Any, Any, Any, typeof(jac_prototype),
24732477
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
2474-
Any,
2475-
typeof(_colorvec),
2476-
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2478+
Any,typeof(_colorvec),
2479+
typeof(sys), Any, Any,
2480+
Any}(_f, mass_matrix, analytic, tgrad, jac,
24772481
jvp, vjp, jac_prototype, sparsity, Wfact,
24782482
Wfact_t, W_prototype, paramjac,
2479-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2483+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2484+
nlfunc)
24802485
elseif specialize === false
24812486
ODEFunction{iip, FunctionWrapperSpecialize,
24822487
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2486,10 +2491,12 @@ function ODEFunction{iip, specialize}(f;
24862491
typeof(observed),
24872492
typeof(_colorvec),
24882493
typeof(sys), typeof(initializeprob),
2489-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2494+
typeof(initializeprobmap),
2495+
typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac,
24902496
jvp, vjp, jac_prototype, sparsity, Wfact,
24912497
Wfact_t, W_prototype, paramjac,
2492-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2498+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2499+
nlfun)
24932500
else
24942501
ODEFunction{iip, specialize,
24952502
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2499,10 +2506,12 @@ function ODEFunction{iip, specialize}(f;
24992506
typeof(observed),
25002507
typeof(_colorvec),
25012508
typeof(sys), typeof(initializeprob),
2502-
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
2509+
typeof(initializeprobmap),
2510+
typeof(nlfunc))}(_f, mass_matrix, analytic, tgrad, jac,
25032511
jvp, vjp, jac_prototype, sparsity, Wfact,
25042512
Wfact_t, W_prototype, paramjac,
2505-
observed, _colorvec, sys, initializeprob, initializeprobmap)
2513+
observed, _colorvec, sys, initializeprob, initializeprobmap,
2514+
nlfunc)
25062515
end
25072516
end
25082517

@@ -2519,10 +2528,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25192528
Any, Any, Any, Any, typeof(f.jac_prototype),
25202529
typeof(f.sparsity), Any, Any, Any,
25212530
Any, typeof(f.colorvec),
2522-
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2531+
typeof(f.sys), Any, Any,
2532+
Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25232533
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25242534
f.Wfact_t, f.W_prototype, f.paramjac,
2525-
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
2535+
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap,
2536+
f.nlfunc)
25262537
else
25272538
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25282539
typeof(f.analytic), typeof(f.tgrad),
@@ -2531,11 +2542,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25312542
typeof(f.paramjac),
25322543
typeof(f.observed), typeof(f.colorvec),
25332544
typeof(f.sys), typeof(f.initializeprob),
2534-
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2545+
typeof(f.initializeprobmap),
2546+
typeof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25352547
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25362548
f.Wfact_t, f.W_prototype, f.paramjac,
25372549
f.observed, f.colorvec, f.sys, f.initializeprob,
2538-
f.initializeprobmap)
2550+
f.initializeprobmap,
2551+
f.nlfunc)
25392552
end
25402553
end
25412554

@@ -4336,6 +4349,7 @@ __has_analytic_full(f) = isdefined(f, :analytic_full)
43364349
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
43374350
__has_initializeprob(f) = isdefined(f, :initializeprob)
43384351
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
4352+
__has_nlfunc(f) = isdefined(f, :nl_func)
43394353

43404354
# compatibility
43414355
has_invW(f::AbstractSciMLFunction) = false

0 commit comments

Comments
 (0)