Skip to content

Commit 78b2061

Browse files
author
oscarddssmith
committed
add nlfunc to ODEFunction
1 parent 3c1211a commit 78b2061

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

src/scimlfunctions.jl

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292+
- `nlfunc`: a `NonlinearFunction`
292293
293294
## iip: In-Place vs Out-Of-Place
294295
@@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the
401402
numerically-defined functions.
402403
"""
403404
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404-
O, TCV,
405-
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
405+
O, TCV, SYS,
406+
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
406407
f::F
407408
mass_matrix::TMM
408409
analytic::Ta
@@ -423,6 +424,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
423424
update_initializeprob!::UIProb
424425
initializeprobmap::IProbMap
425426
initializeprobpmap::IProbPmap
427+
nlfunc::NLF
426428
end
427429

428430
@doc doc"""
@@ -525,8 +527,8 @@ information on generating the SplitFunction from this symbolic engine.
525527
"""
526528
struct SplitFunction{
527529
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
528-
TPJ, O,
529-
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
530+
TPJ, O, TCV, SYS,
531+
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
530532
f1::F1
531533
f2::F2
532534
mass_matrix::TMM
@@ -549,6 +551,7 @@ struct SplitFunction{
549551
update_initializeprob!::UIProb
550552
initializeprobmap::IProbMap
551553
initializeprobpmap::IProbPmap
554+
nlfunc::NLF
552555
end
553556

554557
@doc doc"""
@@ -2426,7 +2429,8 @@ function ODEFunction{iip, specialize}(f;
24262429
update_initializeprob! = __has_update_initializeprob!(f) ?
24272430
f.update_initializeprob! : nothing,
24282431
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
2429-
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
2432+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
2433+
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing,
24302434
) where {iip,
24312435
specialize
24322436
}
@@ -2484,11 +2488,11 @@ function ODEFunction{iip, specialize}(f;
24842488
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
24852489
Any,
24862490
typeof(_colorvec),
2487-
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
2491+
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
24882492
jvp, vjp, jac_prototype, sparsity, Wfact,
24892493
Wfact_t, W_prototype, paramjac,
24902494
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2491-
initializeprobpmap)
2495+
initializeprobpmap, nlfunc)
24922496
elseif specialize === false
24932497
ODEFunction{iip, FunctionWrapperSpecialize,
24942498
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2497,13 +2501,15 @@ function ODEFunction{iip, specialize}(f;
24972501
typeof(paramjac),
24982502
typeof(observed),
24992503
typeof(_colorvec),
2500-
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
2501-
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
2504+
typeof(sys), typeof(initializeprob),
2505+
typeof(update_initializeprob!),
2506+
typeof(initializeprobmap),
2507+
typeof(initializeprobpmap),
2508+
typeof(nlfunc)}(_f, mass_matrix,
25022509
analytic, tgrad, jac,
25032510
jvp, vjp, jac_prototype, sparsity, Wfact,
25042511
Wfact_t, W_prototype, paramjac,
2505-
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2506-
initializeprobpmap)
2512+
observed, _colorvec, sys, initializeprob, initializeprobmap, nlfunc)
25072513
else
25082514
ODEFunction{iip, specialize,
25092515
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2514,11 +2520,12 @@ function ODEFunction{iip, specialize}(f;
25142520
typeof(_colorvec),
25152521
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
25162522
typeof(initializeprobmap),
2517-
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
2523+
typeof(initializeprobpmap)
2524+
typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac,
25182525
jvp, vjp, jac_prototype, sparsity, Wfact,
25192526
Wfact_t, W_prototype, paramjac,
25202527
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
2521-
initializeprobpmap)
2528+
initializeprobpmap, nlfunc)
25222529
end
25232530
end
25242531

@@ -2535,13 +2542,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25352542
Any, Any, Any, Any, typeof(f.jac_prototype),
25362543
typeof(f.sparsity), Any, Any, Any,
25372544
Any, typeof(f.colorvec),
2538-
typeof(f.sys), Any, Any, Any, Any}(
2545+
typeof(f.sys), Any, Any, Any, Any, Any}(
25392546
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25402547
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25412548
f.Wfact_t, f.W_prototype, f.paramjac,
25422549
f.observed, f.colorvec, f.sys, f.initializeprob,
25432550
f.update_initializeprob!, f.initializeprobmap,
2544-
f.initializeprobpmap)
2551+
f.initializeprobpmap, f.nlfunc)
25452552
else
25462553
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25472554
typeof(f.analytic), typeof(f.tgrad),
@@ -2551,11 +2558,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25512558
typeof(f.observed), typeof(f.colorvec),
25522559
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
25532560
typeof(f.initializeprobmap),
2554-
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2561+
typeof(f.initializeprobpmap),
2562+
typof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25552563
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25562564
f.Wfact_t, f.W_prototype, f.paramjac,
25572565
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
2558-
f.initializeprobmap, f.initializeprobpmap)
2566+
f.initializeprobmap, f.initializeprobpmap, f.nlfunc)
25592567
end
25602568
end
25612569

@@ -4376,6 +4384,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
43764384
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
43774385
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
43784386
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
4387+
__has_nlfunc(f) = isdefined(f, :nl_func)
43794388

43804389
# compatibility
43814390
has_invW(f::AbstractSciMLFunction) = false

0 commit comments

Comments
 (0)