Skip to content

Commit c9cb8f8

Browse files
feat: add implementations of CheckInit and OverrideInit
1 parent 2c4dfc0 commit c9cb8f8

File tree

2 files changed

+166
-3
lines changed

2 files changed

+166
-3
lines changed

src/SciMLBase.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ $(TYPEDEF)
348348
"""
349349
struct CheckInit <: DAEInitializationAlgorithm end
350350

351+
"""
352+
$(TYPEDEF)
353+
"""
354+
struct OverrideInit <: DAEInitializationAlgorithm end
355+
351356
# PDE Discretizations
352357

353358
"""
@@ -654,7 +659,6 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
654659
struct TrackerOriginator <: ADOriginator end
655660

656661
include("utils.jl")
657-
include("initialization.jl")
658662
include("function_wrappers.jl")
659663
include("scimlfunctions.jl")
660664
include("alg_traits.jl")
@@ -740,6 +744,7 @@ include("ensemble/ensemble_problems.jl")
740744
include("ensemble/basic_ensemble_solve.jl")
741745
include("ensemble/ensemble_analysis.jl")
742746

747+
include("initialization.jl")
743748
include("solve.jl")
744749
include("interpolation.jl")
745750
include("integrator_interface.jl")

src/initialization.jl

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
99
"""
1010
initializeprob::IProb
1111
"""
12-
A function which takes `(initializeprob, prob)` and updates
12+
A function which takes `(initializeprob, value_provider)` and updates
1313
the parameters of the former with their values in the latter.
14+
If absent (`nothing`) this will not be called, and the parameters
15+
in `initializeprob` will be used without modification. `value_provider`
16+
refers to a value provider as defined by SymbolicIndexingInterface.jl.
17+
Usually this will refer to a problem or integrator.
1418
"""
1519
update_initializeprob!::UIProb
1620
"""
@@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
2024
initializeprobmap::IProbMap
2125
"""
2226
A function which takes the solution of `initializeprob` and returns
23-
the parameter object of the original problem.
27+
the parameter object of the original problem. If absent (`nothing`),
28+
this will not be called and the parameters of the problem being
29+
initialized will be returned as-is.
2430
"""
2531
initializeprobpmap::IProbPmap
2632

@@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
3036
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
3137
end
3238
end
39+
40+
"""
41+
get_initial_values(prob, valp, f, alg, isinplace; kwargs...)
42+
43+
Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
44+
and a boolean indicating whether the initialization process was successful. Keyword
45+
arguments to this function are dependent on the initialization algorithm. `prob` is only
46+
required for dispatching. `valp` refers the appropriate data structure from which the
47+
current state and parameter values should be obtained. `valp` is a non-timeseries value
48+
provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
49+
problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
50+
if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
51+
"""
52+
function get_initial_values end
53+
54+
struct CheckInitFailureError <: Exception
55+
normresid::Any
56+
abstol::Any
57+
end
58+
59+
function Base.showerror(io::IO, e::CheckInitFailureError)
60+
print(io,
61+
"CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)")
62+
end
63+
64+
struct OverrideInitMissingAlgorithm <: Exception end
65+
66+
function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
67+
print(io,
68+
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
69+
end
70+
71+
"""
72+
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
73+
it is in-place or simply calling the function if not.
74+
"""
75+
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
76+
tmp = first(get_tmp_cache(integrator))
77+
f(tmp, args...)
78+
return tmp
79+
end
80+
81+
function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
82+
return f(args...)
83+
end
84+
85+
"""
86+
$(TYPEDSIGNATURES)
87+
88+
A utility function equivalent to `Base.vec` but also handles `Number` and
89+
`AbstractSciMLScalarOperator`.
90+
"""
91+
_vec(v) = vec(v)
92+
_vec(v::Number) = v
93+
_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v
94+
_vec(v::AbstractVector) = v
95+
96+
"""
97+
$(TYPEDSIGNATURES)
98+
99+
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100+
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101+
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
102+
"""
103+
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
104+
isinplace::Union{Val{true}, Val{false}}; kwargs...)
105+
u0 = state_values(integrator)
106+
p = parameter_values(integrator)
107+
t = current_time(integrator)
108+
M = f.mass_matrix
109+
110+
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
111+
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
112+
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
113+
update_coefficients!(M, u0, p, t)
114+
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
115+
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
116+
117+
normresid = integrator.opts.internalnorm(tmp, t)
118+
if normresid > integrator.opts.abstol
119+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
120+
end
121+
return u0, p, true
122+
end
123+
124+
"""
125+
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126+
it is in-place or simply calling the function if not.
127+
"""
128+
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
129+
tmp = get_tmp_cache(integrator)[2]
130+
f(tmp, args...)
131+
return tmp
132+
end
133+
134+
function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
135+
return f(args...)
136+
end
137+
138+
function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
139+
isinplace::Union{Val{true}, Val{false}}; kwargs...)
140+
u0 = state_values(integrator)
141+
p = parameter_values(integrator)
142+
t = current_time(integrator)
143+
144+
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
145+
normresid = integrator.opts.internalnorm(resid, t)
146+
if normresid > integrator.opts.abstol
147+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
148+
end
149+
return u0, p, true
150+
end
151+
152+
"""
153+
$(TYPEDSIGNATURES)
154+
155+
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156+
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157+
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158+
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159+
argument, failing which this function will throw an error. The success value returned
160+
depends on the success of the nonlinear solve.
161+
"""
162+
function get_initial_values(prob, valp, f, alg::OverrideInit,
163+
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
164+
u0 = state_values(valp)
165+
p = parameter_values(valp)
166+
167+
if !has_initialization_data(f)
168+
return u0, p, true
169+
end
170+
171+
initdata::OverrideInitData = f.initialization_data
172+
initprob = initdata.initializeprob
173+
174+
if nlsolve_alg === nothing
175+
throw(OverrideInitMissingAlgorithm())
176+
end
177+
178+
if initdata.update_initializeprob! !== nothing
179+
initdata.update_initializeprob!(initprob, valp)
180+
end
181+
182+
nlsol = solve(initprob, nlsolve_alg)
183+
184+
u0 = initdata.initializeprobmap(nlsol)
185+
if initdata.initializeprobpmap !== nothing
186+
p = initdata.initializeprobpmap(nlsol)
187+
end
188+
189+
return u0, p, SciMLBase.successful_retcode(nlsol)
190+
end

0 commit comments

Comments
 (0)