Skip to content

Commit cf142a8

Browse files
Merge pull request #845 from AayushSabharwal/as/move-initalgs
feat: add implementations of `CheckInit` and `OverrideInit`
2 parents 2c4dfc0 + b3105cc commit cf142a8

File tree

6 files changed

+330
-5
lines changed

6 files changed

+330
-5
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ LinearAlgebra = "1.10"
7070
Logging = "1.10"
7171
Makie = "0.20, 0.21"
7272
Markdown = "1.10"
73+
NonlinearSolve = "3, 4"
7374
PartialFunctions = "1.1"
7475
PrecompileTools = "1.2"
7576
Preferences = "1.3"
@@ -98,6 +99,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9899
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
99100
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
100101
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
102+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
101103
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
102104
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
103105
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -114,4 +116,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
114116
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
115117

116118
[targets]
117-
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "OrdinaryDiffEq", "ForwardDiff", "Tables"]
119+
test = ["Pkg", "Plots", "UnicodePlots", "SafeTestsets", "Serialization", "Test", "StableRNGs", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "NonlinearSolve", "OrdinaryDiffEq", "ForwardDiff", "Tables"]

src/SciMLBase.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ $(TYPEDEF)
348348
"""
349349
struct CheckInit <: DAEInitializationAlgorithm end
350350

351+
"""
352+
$(TYPEDEF)
353+
"""
354+
struct OverrideInit <: DAEInitializationAlgorithm end
355+
351356
# PDE Discretizations
352357

353358
"""
@@ -654,7 +659,6 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
654659
struct TrackerOriginator <: ADOriginator end
655660

656661
include("utils.jl")
657-
include("initialization.jl")
658662
include("function_wrappers.jl")
659663
include("scimlfunctions.jl")
660664
include("alg_traits.jl")
@@ -740,6 +744,7 @@ include("ensemble/ensemble_problems.jl")
740744
include("ensemble/basic_ensemble_solve.jl")
741745
include("ensemble/ensemble_analysis.jl")
742746

747+
include("initialization.jl")
743748
include("solve.jl")
744749
include("interpolation.jl")
745750
include("integrator_interface.jl")

src/initialization.jl

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
99
"""
1010
initializeprob::IProb
1111
"""
12-
A function which takes `(initializeprob, prob)` and updates
12+
A function which takes `(initializeprob, value_provider)` and updates
1313
the parameters of the former with their values in the latter.
14+
If absent (`nothing`) this will not be called, and the parameters
15+
in `initializeprob` will be used without modification. `value_provider`
16+
refers to a value provider as defined by SymbolicIndexingInterface.jl.
17+
Usually this will refer to a problem or integrator.
1418
"""
1519
update_initializeprob!::UIProb
1620
"""
@@ -20,7 +24,9 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
2024
initializeprobmap::IProbMap
2125
"""
2226
A function which takes the solution of `initializeprob` and returns
23-
the parameter object of the original problem.
27+
the parameter object of the original problem. If absent (`nothing`),
28+
this will not be called and the parameters of the problem being
29+
initialized will be returned as-is.
2430
"""
2531
initializeprobpmap::IProbPmap
2632

@@ -30,3 +36,155 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
3036
return new{I, J, K, L}(initprob, update_initprob!, initprobmap, initprobpmap)
3137
end
3238
end
39+
40+
"""
41+
get_initial_values(prob, valp, f, alg, isinplace; kwargs...)
42+
43+
Return the initial `u0` and `p` for the given SciMLProblem and initialization algorithm,
44+
and a boolean indicating whether the initialization process was successful. Keyword
45+
arguments to this function are dependent on the initialization algorithm. `prob` is only
46+
required for dispatching. `valp` refers the appropriate data structure from which the
47+
current state and parameter values should be obtained. `valp` is a non-timeseries value
48+
provider as defined by SymbolicIndexingInterface.jl. `f` is the SciMLFunction for the
49+
problem. `alg` is the initialization algorithm to use. `isinplace` is either `Val{true}`
50+
if `valp` and the SciMLFunction are inplace, and `Val{false}` otherwise.
51+
"""
52+
function get_initial_values end
53+
54+
struct CheckInitFailureError <: Exception
55+
normresid::Any
56+
abstol::Any
57+
end
58+
59+
function Base.showerror(io::IO, e::CheckInitFailureError)
60+
print(io,
61+
"CheckInit specified but initialization not satisfied. normresid = $(e.normresid) > abstol = $(e.abstol)")
62+
end
63+
64+
struct OverrideInitMissingAlgorithm <: Exception end
65+
66+
function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
67+
print(io,
68+
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
69+
end
70+
71+
"""
72+
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
73+
it is in-place or simply calling the function if not.
74+
"""
75+
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
76+
tmp = first(get_tmp_cache(integrator))
77+
f(tmp, args...)
78+
return tmp
79+
end
80+
81+
function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
82+
return f(args...)
83+
end
84+
85+
"""
86+
$(TYPEDSIGNATURES)
87+
88+
A utility function equivalent to `Base.vec` but also handles `Number` and
89+
`AbstractSciMLScalarOperator`.
90+
"""
91+
_vec(v) = vec(v)
92+
_vec(v::Number) = v
93+
_vec(v::SciMLOperators.AbstractSciMLScalarOperator) = v
94+
_vec(v::AbstractVector) = v
95+
96+
"""
97+
$(TYPEDSIGNATURES)
98+
99+
Check if the algebraic constraints are satisfied, and error if they aren't. Returns
100+
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
101+
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
102+
"""
103+
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
104+
isinplace::Union{Val{true}, Val{false}}; kwargs...)
105+
u0 = state_values(integrator)
106+
p = parameter_values(integrator)
107+
t = current_time(integrator)
108+
M = f.mass_matrix
109+
110+
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
111+
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
112+
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
113+
update_coefficients!(M, u0, p, t)
114+
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
115+
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
116+
117+
normresid = integrator.opts.internalnorm(tmp, t)
118+
if normresid > integrator.opts.abstol
119+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
120+
end
121+
return u0, p, true
122+
end
123+
124+
"""
125+
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
126+
it is in-place or simply calling the function if not.
127+
"""
128+
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
129+
tmp = get_tmp_cache(integrator)[2]
130+
f(tmp, args...)
131+
return tmp
132+
end
133+
134+
function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
135+
return f(args...)
136+
end
137+
138+
function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
139+
isinplace::Union{Val{true}, Val{false}}; kwargs...)
140+
u0 = state_values(integrator)
141+
p = parameter_values(integrator)
142+
t = current_time(integrator)
143+
144+
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
145+
normresid = integrator.opts.internalnorm(resid, t)
146+
if normresid > integrator.opts.abstol
147+
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
148+
end
149+
return u0, p, true
150+
end
151+
152+
"""
153+
$(TYPEDSIGNATURES)
154+
155+
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
156+
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
157+
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
158+
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
159+
argument, failing which this function will throw an error. The success value returned
160+
depends on the success of the nonlinear solve.
161+
"""
162+
function get_initial_values(prob, valp, f, alg::OverrideInit,
163+
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
164+
u0 = state_values(valp)
165+
p = parameter_values(valp)
166+
167+
if !has_initialization_data(f)
168+
return u0, p, true
169+
end
170+
171+
initdata::OverrideInitData = f.initialization_data
172+
initprob = initdata.initializeprob
173+
174+
if nlsolve_alg === nothing
175+
throw(OverrideInitMissingAlgorithm())
176+
end
177+
178+
if initdata.update_initializeprob! !== nothing
179+
initdata.update_initializeprob!(initprob, valp)
180+
end
181+
182+
nlsol = solve(initprob, nlsolve_alg)
183+
184+
u0 = initdata.initializeprobmap(nlsol)
185+
if initdata.initializeprobpmap !== nothing
186+
p = initdata.initializeprobpmap(nlsol)
187+
end
188+
189+
return u0, p, SciMLBase.successful_retcode(nlsol)
190+
end

src/solutions/save_idxs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ function get_save_idxs_and_saved_subsystem(prob, save_idxs)
372372
if isempty(_save_idxs)
373373
# no states to save
374374
save_idxs = Int[]
375-
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
375+
elseif !(save_idxs isa AbstractArray) ||
376+
symbolic_type(save_idxs) != NotSymbolic()
376377
# only a single state to save, and save it as a scalar timeseries instead of
377378
# single-element array
378379
save_idxs = only(_save_idxs)

test/initialization.jl

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
using OrdinaryDiffEq, NonlinearSolve, SymbolicIndexingInterface, Test
2+
3+
@testset "CheckInit" begin
4+
@testset "ODEProblem" begin
5+
function rhs(u, p, t)
6+
return [u[1] * t, u[1]^2 - u[2]^2]
7+
end
8+
function rhs!(du, u, p, t)
9+
du[1] = u[1] * t
10+
du[2] = u[1]^2 - u[2]^2
11+
end
12+
13+
oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
14+
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])
15+
16+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
17+
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
18+
integ = init(prob)
19+
u0, _, success = SciMLBase.get_initial_values(
20+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
21+
@test success
22+
@test u0 == prob.u0
23+
24+
integ.u[2] = 2.0
25+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
26+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
27+
end
28+
end
29+
30+
@testset "DAEProblem" begin
31+
function daerhs(du, u, p, t)
32+
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
33+
end
34+
function daerhs!(resid, du, u, p, t)
35+
resid[1] = du[1] - u[1] * t - p
36+
resid[2] = u[1]^2 - u[2]^2
37+
end
38+
39+
oopfn = DAEFunction{false}(daerhs)
40+
iipfn = DAEFunction{true}(daerhs!)
41+
42+
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
43+
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
44+
integ = init(prob, DImplicitEuler())
45+
u0, _, success = SciMLBase.get_initial_values(
46+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
47+
@test success
48+
@test u0 == prob.u0
49+
50+
integ.u[2] = 2.0
51+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
52+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
53+
54+
integ.u[2] = 1.0
55+
integ.du[1] = 2.0
56+
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
57+
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)))
58+
end
59+
end
60+
end
61+
62+
@testset "OverrideInit" begin
63+
function rhs2(u, p, t)
64+
return [u[1] * t + p, u[1]^2 - u[2]^2]
65+
end
66+
67+
@testset "No-op without `initialization_data`" begin
68+
prob = ODEProblem(rhs2, [1.0, 2.0], (0.0, 1.0), 1.0)
69+
integ = init(prob)
70+
integ.u[2] = 3.0
71+
u0, p, success = SciMLBase.get_initial_values(
72+
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(false))
73+
@test u0 [1.0, 3.0]
74+
@test success
75+
end
76+
77+
# unknowns are u[2], p. Parameter is u[1]
78+
initprob = NonlinearProblem([1.0, 1.0], [1.0]) do x, _u1
79+
u2, p = x
80+
u1 = _u1[1]
81+
return [u1^2 - u2^2, p^2 - 2p + 1]
82+
end
83+
update_initializeprob! = function (iprob, integ)
84+
iprob.p[1] = integ.u[1]
85+
end
86+
initprobmap = function (nlsol)
87+
return [parameter_values(nlsol)[1], nlsol.u[1]]
88+
end
89+
initprobpmap = function (nlsol)
90+
return nlsol.u[2]
91+
end
92+
initialization_data = SciMLBase.OverrideInitData(
93+
initprob, update_initializeprob!, initprobmap, initprobpmap)
94+
fn = ODEFunction(rhs2; initialization_data)
95+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
96+
integ = init(prob; initializealg = NoInit())
97+
98+
@testset "Errors without `nlsolve_alg`" begin
99+
@test_throws SciMLBase.OverrideInitMissingAlgorithm SciMLBase.get_initial_values(
100+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false))
101+
end
102+
103+
@testset "Solves" begin
104+
u0, p, success = SciMLBase.get_initial_values(
105+
prob, integ, fn, SciMLBase.OverrideInit(),
106+
Val(false); nlsolve_alg = NewtonRaphson())
107+
108+
@test u0 [2.0, 2.0]
109+
@test p 1.0
110+
@test success
111+
112+
initprob.p[1] = 1.0
113+
end
114+
115+
@testset "Solves with non-integrator value provider" begin
116+
_integ = ProblemState(; u = integ.u, p = parameter_values(integ), t = integ.t)
117+
u0, p, success = SciMLBase.get_initial_values(
118+
prob, _integ, fn, SciMLBase.OverrideInit(),
119+
Val(false); nlsolve_alg = NewtonRaphson())
120+
121+
@test u0 [2.0, 2.0]
122+
@test p 1.0
123+
@test success
124+
125+
initprob.p[1] = 1.0
126+
end
127+
128+
@testset "Solves without `update_initializeprob!`" begin
129+
initdata = SciMLBase.@set initialization_data.update_initializeprob! = nothing
130+
fn = ODEFunction(rhs2; initialization_data = initdata)
131+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
132+
integ = init(prob; initializealg = NoInit())
133+
134+
u0, p, success = SciMLBase.get_initial_values(
135+
prob, integ, fn, SciMLBase.OverrideInit(),
136+
Val(false); nlsolve_alg = NewtonRaphson())
137+
@test u0 [1.0, 1.0]
138+
@test p 1.0
139+
@test success
140+
end
141+
142+
@testset "Solves without `initializeprobpmap`" begin
143+
initdata = SciMLBase.@set initialization_data.initializeprobpmap = nothing
144+
fn = ODEFunction(rhs2; initialization_data = initdata)
145+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
146+
integ = init(prob; initializealg = NoInit())
147+
148+
u0, p, success = SciMLBase.get_initial_values(
149+
prob, integ, fn, SciMLBase.OverrideInit(),
150+
Val(false); nlsolve_alg = NewtonRaphson())
151+
152+
@test u0 [2.0, 2.0]
153+
@test p 0.0
154+
@test success
155+
end
156+
end

0 commit comments

Comments
 (0)