Skip to content

Commit cec0c19

Browse files
author
oscarddssmith
committed
fix nlprob to match the initialization system
1 parent c61b13d commit cec0c19

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

src/scimlfunctions.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -289,11 +289,6 @@ the usage of `f`. These include:
289289
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
290290
internally computed on demand when required. The cost of this operation is highly dependent
291291
on the sparsity pattern.
292-
- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp`
293-
where the nonlinear parameters are the tuple `(t, u_tmp, p)`.
294-
This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t`
295-
such that solving this function produces a solution to the implicit step of your solver.
296-
297292
## iip: In-Place vs Out-Of-Place
298293
299294
`iip` is the optional boolean for determining whether a given function is written to
@@ -424,7 +419,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
424419
colorvec::TCV
425420
sys::SYS
426421
initialization_data::ID
427-
nlprob::NLP
422+
nlprob_data::NLP
428423
end
429424

430425
@doc doc"""
@@ -547,8 +542,8 @@ struct SplitFunction{
547542
observed::O
548543
colorvec::TCV
549544
sys::SYS
550-
nlprob::NLP
551545
initialization_data::ID
546+
nlprob_data::NLP
552547
end
553548

554549
@doc doc"""
@@ -2446,9 +2441,9 @@ function ODEFunction{iip, specialize}(f;
24462441
f.update_initializeprob! : nothing,
24472442
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
24482443
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
2449-
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
24502444
initialization_data = __has_initialization_data(f) ? f.initialization_data :
24512445
nothing
2446+
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing,
24522447
) where {iip,
24532448
specialize
24542449
}
@@ -2509,7 +2504,7 @@ function ODEFunction{iip, specialize}(f;
25092504
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
25102505
jvp, vjp, jac_prototype, sparsity, Wfact,
25112506
Wfact_t, W_prototype, paramjac,
2512-
observed, _colorvec, sys, initdata, nlprob)
2507+
observed, _colorvec, sys, initdata, nlprob_data)
25132508
elseif specialize === false
25142509
ODEFunction{iip, FunctionWrapperSpecialize,
25152510
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2518,11 +2513,11 @@ function ODEFunction{iip, specialize}(f;
25182513
typeof(paramjac),
25192514
typeof(observed),
25202515
typeof(_colorvec),
2521-
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix,
2516+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(_f, mass_matrix,
25222517
analytic, tgrad, jac,
25232518
jvp, vjp, jac_prototype, sparsity, Wfact,
25242519
Wfact_t, W_prototype, paramjac,
2525-
observed, _colorvec, sys, initdata, nlprob)
2520+
observed, _colorvec, sys, initdata, nlprob_data)
25262521
else
25272522
ODEFunction{iip, specialize,
25282523
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2531,11 +2526,11 @@ function ODEFunction{iip, specialize}(f;
25312526
typeof(paramjac),
25322527
typeof(observed),
25332528
typeof(_colorvec),
2534-
typeof(sys), typeof(initdata), typeof(nlprob)}(
2529+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
25352530
_f, mass_matrix, analytic, tgrad,
25362531
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
25372532
Wfact_t, W_prototype, paramjac,
2538-
observed, _colorvec, sys, initdata, nlprob)
2533+
observed, _colorvec, sys, initdata, nlprob_data)
25392534
end
25402535
end
25412536

@@ -2556,19 +2551,19 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
25562551
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25572552
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25582553
f.Wfact_t, f.W_prototype, f.paramjac,
2559-
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
2554+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
25602555
else
25612556
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
25622557
typeof(f.analytic), typeof(f.tgrad),
25632558
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
25642559
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
25652560
typeof(f.paramjac),
25662561
typeof(f.observed), typeof(f.colorvec),
2567-
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}(
2562+
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob_data)}(
25682563
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
25692564
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
25702565
f.Wfact_t, f.W_prototype, f.paramjac,
2571-
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
2566+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
25722567
end
25732568
end
25742569

@@ -2703,7 +2698,7 @@ end
27032698
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
27042699
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
27052700
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
2706-
initializeprobmap = nothing, initializeprobpmap = nothing, nlprob = nothing, initialization_data = nothing)
2701+
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob_data = nothing)
27072702
f1 = ODEFunction(f1)
27082703
f2 = ODEFunction(f2)
27092704

@@ -2721,11 +2716,11 @@ end
27212716
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
27222717
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27232718
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2724-
typeof(sys), typeof(initdata), typeof(nlprob)}(
2719+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
27252720
f1, f2, mass_matrix,
27262721
cache, analytic, tgrad, jac, jvp, vjp,
27272722
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2728-
initdata, nlprob)
2723+
initdata, nlprob_data)
27292724
end
27302725
function SplitFunction{iip, specialize}(f1, f2;
27312726
mass_matrix = __has_mass_matrix(f1) ?
@@ -2762,7 +2757,7 @@ function SplitFunction{iip, specialize}(f1, f2;
27622757
f1.update_initializeprob! : nothing,
27632758
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
27642759
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
2765-
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing,
2760+
nlprob_data = __has_nlprob_data(f1) ? f1.nlprob_data : nothing,
27662761
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
27672762
nothing
27682763
) where {iip,
@@ -2780,19 +2775,19 @@ function SplitFunction{iip, specialize}(f1, f2;
27802775
analytic,
27812776
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
27822777
sparsity, Wfact, Wfact_t, paramjac,
2783-
observed, colorvec, sys, initdata, nlprob)
2778+
observed, colorvec, sys, initdata, nlprob_data)
27842779
else
27852780
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
27862781
typeof(_func_cache), typeof(analytic),
27872782
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
27882783
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
27892784
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
27902785
typeof(colorvec),
2791-
typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2,
2786+
typeof(sys), typeof(initdata), typeof(nlprob_data)}(f1, f2,
27922787
mass_matrix, _func_cache, analytic, tgrad, jac,
27932788
jvp, vjp, jac_prototype, W_prototype,
27942789
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2795-
initdata, nlprob)
2790+
initdata, nlprob_data)
27962791
end
27972792
end
27982793

@@ -4488,7 +4483,7 @@ __has_colorvec(f) = isdefined(f, :colorvec)
44884483
__has_sys(f) = isdefined(f, :sys)
44894484
__has_analytic_full(f) = isdefined(f, :analytic_full)
44904485
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
4491-
__has_nlprob(f) = isdefined(f, :nlprob)
4486+
__has_nlprob_data(f) = isdefined(f, :nlprob_data)
44924487
function __has_initializeprob(f)
44934488
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
44944489
end

0 commit comments

Comments
 (0)