Skip to content

Commit d9d00b3

Browse files
Merge pull request #1067 from AayushSabharwal/as/ode-nlprob
refactor: update `ODE_NLProbData`
2 parents e5675bd + d65c25b commit d9d00b3

File tree

4 files changed

+55
-71
lines changed

4 files changed

+55
-71
lines changed

src/ODE_nlsolve.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/SciMLBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ Internal. Used for signifying the AD context comes from a Mooncake.jl context.
674674
struct MooncakeOriginator <: ADOriginator end
675675

676676
include("initialization.jl")
677-
include("ODE_nlsolve.jl")
677+
include("odenlstep.jl")
678678
include("utils.jl")
679679
include("function_wrappers.jl")
680680
include("scimlfunctions.jl")

src/odenlstep.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
$(TYPEDEF)
3+
4+
A collection of all the data required for custom ODE Nonlinear problem solving
5+
"""
6+
struct ODENLStepData{NLProb, SetU0, SetGammaC, SetOuterTmp, SetInnerTmp, NLProbMap}
7+
"""
8+
The `AbstractNonlinearProblem` to define custom nonlinear problems to be used for
9+
implicit time discretizations. This allows to use extra structure of the ODE function (e.g.
10+
multi-level structure). The nonlinear function must match that form of the function implicit
11+
ODE integration algorithms need do solve the a nonlinear problems,
12+
specifically of the form `M*z = outer_tmp + γ₁⋅f(γ₂⋅z+inner_tmp,p,t_c)`.
13+
Here `z` is the stage solution vector, `p` is the parameter of the ODE problem, `t_c` is
14+
the time of evaluation (`t_c = t + c*dt`), `γ₁` and `γ₂` are some scaling factors determined
15+
by the solver algorithm and the temporary variables are some compatible vectors set by the specific solver.
16+
The inner nonlinear function of the nonlinear problem is in general of the form `g(z,p') = 0` such that
17+
`g(z,p') = γ₁⋅f(γ₂⋅z+inner_tmp,p,t_c) + outer_tmp - M*z = 0`.
18+
"""
19+
nlprob::NLProb
20+
u0perm::SetU0
21+
set_γ_c::SetGammaC
22+
set_outer_tmp::SetOuterTmp
23+
set_inner_tmp::SetInnerTmp
24+
"""
25+
A function which takes the solution of `nlprob` and returns
26+
the state vector of the original problem.
27+
"""
28+
nlprobmap::NLProbMap
29+
end

src/scimlfunctions.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ numerically-defined functions.
409409
"""
410410
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
411411
O, TCV,
412-
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODE_NLProbData}} <:
412+
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODENLStepData}} <:
413413
AbstractODEFunction{iip}
414414
f::F
415415
mass_matrix::TMM
@@ -428,7 +428,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
428428
colorvec::TCV
429429
sys::SYS
430430
initialization_data::ID
431-
nlprob_data::NLP
431+
nlstep_data::NLP
432432
end
433433

434434
@doc doc"""
@@ -532,7 +532,7 @@ information on generating the SplitFunction from this symbolic engine.
532532
struct SplitFunction{
533533
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
534534
TPJ, O, TCV, SYS, ID <: Union{Nothing, OverrideInitData},
535-
NLP <: Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
535+
NLP <: Union{Nothing, ODENLStepData}} <: AbstractODEFunction{iip}
536536
f1::F1
537537
f2::F2
538538
mass_matrix::TMM
@@ -552,7 +552,7 @@ struct SplitFunction{
552552
colorvec::TCV
553553
sys::SYS
554554
initialization_data::ID
555-
nlprob_data::NLP
555+
nlstep_data::NLP
556556
end
557557

558558
@doc doc"""
@@ -2691,7 +2691,7 @@ function ODEFunction{iip, specialize}(f;
26912691
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
26922692
initialization_data = __has_initialization_data(f) ? f.initialization_data :
26932693
nothing,
2694-
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
2694+
nlstep_data = __has_nlstep_data(f) ? f.nlstep_data : nothing
26952695
) where {iip,
26962696
specialize
26972697
}
@@ -2749,11 +2749,11 @@ function ODEFunction{iip, specialize}(f;
27492749
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
27502750
Any,
27512751
typeof(_colorvec),
2752-
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
2752+
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODENLStepData}}(
27532753
_f, mass_matrix, analytic, tgrad, jac,
27542754
jvp, vjp, jac_prototype, sparsity, Wfact,
27552755
Wfact_t, W_prototype, paramjac,
2756-
observed, _colorvec, sys, initdata, nlprob_data)
2756+
observed, _colorvec, sys, initdata, nlstep_data)
27572757
elseif specialize === false
27582758
ODEFunction{iip, FunctionWrapperSpecialize,
27592759
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2762,11 +2762,11 @@ function ODEFunction{iip, specialize}(f;
27622762
typeof(paramjac),
27632763
typeof(observed),
27642764
typeof(_colorvec),
2765-
typeof(sys), typeof(initdata), typeof(nlprob_data)}(_f, mass_matrix,
2765+
typeof(sys), typeof(initdata), typeof(nlstep_data)}(_f, mass_matrix,
27662766
analytic, tgrad, jac,
27672767
jvp, vjp, jac_prototype, sparsity, Wfact,
27682768
Wfact_t, W_prototype, paramjac,
2769-
observed, _colorvec, sys, initdata, nlprob_data)
2769+
observed, _colorvec, sys, initdata, nlstep_data)
27702770
else
27712771
ODEFunction{iip, specialize,
27722772
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2775,11 +2775,11 @@ function ODEFunction{iip, specialize}(f;
27752775
typeof(paramjac),
27762776
typeof(observed),
27772777
typeof(_colorvec),
2778-
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
2778+
typeof(sys), typeof(initdata), typeof(nlstep_data)}(
27792779
_f, mass_matrix, analytic, tgrad,
27802780
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
27812781
Wfact_t, W_prototype, paramjac,
2782-
observed, _colorvec, sys, initdata, nlprob_data)
2782+
observed, _colorvec, sys, initdata, nlstep_data)
27832783
end
27842784
end
27852785

@@ -2796,23 +2796,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
27962796
Any, Any, Any, Any, typeof(f.jac_prototype),
27972797
typeof(f.sparsity), Any, Any, Any, Any,
27982798
Any, typeof(f.colorvec),
2799-
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
2799+
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODENLStepData}}(
28002800
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
28012801
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
28022802
f.Wfact_t, f.W_prototype, f.paramjac,
2803-
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
2803+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlstep_data)
28042804
else
28052805
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
28062806
typeof(f.analytic), typeof(f.tgrad),
28072807
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
28082808
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
28092809
typeof(f.paramjac),
28102810
typeof(f.observed), typeof(f.colorvec),
2811-
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob_data)}(
2811+
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlstep_data)}(
28122812
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
28132813
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
28142814
f.Wfact_t, f.W_prototype, f.paramjac,
2815-
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
2815+
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlstep_data)
28162816
end
28172817
end
28182818

@@ -2948,7 +2948,7 @@ end
29482948
f1, f2, mass_matrix, _func_cache, analytic, tgrad, jac, jvp,
29492949
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
29502950
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
2951-
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob_data = nothing)
2951+
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlstep_data = nothing)
29522952
f1 = ODEFunction(f1)
29532953
f2 = ODEFunction(f2)
29542954

@@ -2966,11 +2966,11 @@ end
29662966
typeof(_func_cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
29672967
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
29682968
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
2969-
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
2969+
typeof(sys), typeof(initdata), typeof(nlstep_data)}(
29702970
f1, f2, mass_matrix,
29712971
_func_cache, analytic, tgrad, jac, jvp, vjp,
29722972
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
2973-
initdata, nlprob_data)
2973+
initdata, nlstep_data)
29742974
end
29752975
function SplitFunction{iip, specialize}(f1, f2;
29762976
mass_matrix = __has_mass_matrix(f1) ?
@@ -3007,7 +3007,7 @@ function SplitFunction{iip, specialize}(f1, f2;
30073007
f1.update_initializeprob! : nothing,
30083008
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
30093009
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
3010-
nlprob_data = __has_nlprob_data(f1) ? f1.nlprob_data : nothing,
3010+
nlstep_data = __has_nlstep_data(f1) ? f1.nlstep_data : nothing,
30113011
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
30123012
nothing
30133013
) where {iip,
@@ -3021,24 +3021,24 @@ function SplitFunction{iip, specialize}(f1, f2;
30213021
if specialize === NoSpecialize
30223022
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
30233023
Any, Any, Any, Any, Any, Any, Any,
3024-
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
3024+
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODENLStepData}}(
30253025
f1, f2, mass_matrix, _func_cache,
30263026
analytic,
30273027
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
30283028
sparsity, Wfact, Wfact_t, paramjac,
3029-
observed, colorvec, sys, initdata, nlprob_data)
3029+
observed, colorvec, sys, initdata, nlstep_data)
30303030
else
30313031
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
30323032
typeof(_func_cache), typeof(analytic),
30333033
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
30343034
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
30353035
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
30363036
typeof(colorvec),
3037-
typeof(sys), typeof(initdata), typeof(nlprob_data)}(f1, f2,
3037+
typeof(sys), typeof(initdata), typeof(nlstep_data)}(f1, f2,
30383038
mass_matrix, _func_cache, analytic, tgrad, jac,
30393039
jvp, vjp, jac_prototype, W_prototype,
30403040
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
3041-
initdata, nlprob_data)
3041+
initdata, nlstep_data)
30423042
end
30433043
end
30443044

@@ -4779,7 +4779,7 @@ function ODEInputFunction{iip, specialize}(f;
47794779
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
47804780
initialization_data = __has_initialization_data(f) ? f.initialization_data :
47814781
nothing,
4782-
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
4782+
nlstep_data = __has_nlstep_data(f) ? f.nlstep_data : nothing
47834783
) where {iip,
47844784
specialize
47854785
}
@@ -4938,7 +4938,7 @@ __has_colorvec(f) = isdefined(f, :colorvec)
49384938
__has_sys(f) = isdefined(f, :sys)
49394939
__has_analytic_full(f) = isdefined(f, :analytic_full)
49404940
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
4941-
__has_nlprob_data(f) = isdefined(f, :nlprob_data)
4941+
__has_nlstep_data(f) = isdefined(f, :nlstep_data)
49424942
function __has_initializeprob(f)
49434943
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
49444944
end

0 commit comments

Comments
 (0)