Skip to content

Commit 3cf7743

Browse files
Merge pull request #881 from AayushSabharwal/as/lazy-remake
feat: add lazy initialization to `remake`
2 parents 598c7cd + 723f561 commit 3cf7743

File tree

4 files changed

+112
-33
lines changed

4 files changed

+112
-33
lines changed

src/initialization.jl

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,24 @@ end
6060

6161
function Base.showerror(io::IO, e::CheckInitFailureError)
6262
print(io,
63-
"DAE initialization failed: your u0 did not satisfy the initialization requirements,
64-
normresid = $(e.normresid) > abstol = $(e.abstol)."
65-
)
63+
"""
64+
DAE initialization failed: your u0 did not satisfy the initialization requirements, \
65+
normresid = $(e.normresid) > abstol = $(e.abstol).
66+
""")
6667

6768
if e.isdae
68-
print(io, " If you wish for the system to
69-
automatically change the algebraic variables to satisfy the algebraic constraints,
70-
please pass `initializealg = BrownBasicInit()` to solve (this option will require
71-
`using OrdinaryDiffEqNonlinearSolve`). If you wish to perform an initialization on the
72-
complete u0, please pass `initializealg = ShampineCollocationInit()` to solve. Note that
73-
initialization can be a very difficult process for DAEs and in many cases can be
74-
numerically intractable without symbolic manipulation of the system. For an automated
75-
system that will generate numerically stable initializations, see ModelingToolkit.jl
76-
structural simplification for more details."
77-
)
69+
print(io,
70+
"""
71+
If you wish for the system to automatically change the algebraic variables to \
72+
satisfy the algebraic constraints, please pass `initializealg = BrownBasicInit()` \
73+
to solve (this option will require `using OrdinaryDiffEqNonlinearSolve`). If you \
74+
wish to perform an initialization on the complete u0, please pass \
75+
`initializealg = ShampineCollocationInit()` to `solve`. Note that initialization \
76+
can be a very difficult process for DAEs and in many cases can be numerically \
77+
intractable without symbolic manipulation of the system. For an automated \
78+
system that will generate numerically stable initializations, see \
79+
ModelingToolkit.jl structural simplification for more details.
80+
""")
7881
end
7982
end
8083

@@ -188,6 +191,9 @@ Keyword arguments:
188191
provided to the `OverrideInit` constructor takes priority over this keyword argument.
189192
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
190193
an error will be thrown.
194+
195+
In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
196+
not required.
191197
"""
192198
function get_initial_values(prob, valp, f, alg::OverrideInit,
193199
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
@@ -201,35 +207,55 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
201207
initdata::OverrideInitData = f.initialization_data
202208
initprob = initdata.initializeprob
203209

204-
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
205-
if nlsolve_alg === nothing && state_values(initprob) !== nothing
206-
throw(OverrideInitMissingAlgorithm())
207-
end
208-
209210
if initdata.update_initializeprob! !== nothing
210211
initdata.update_initializeprob!(initprob, valp)
211212
end
212213

213-
if alg.abstol !== nothing
214-
_abstol = alg.abstol
215-
elseif abstol !== nothing
216-
_abstol = abstol
214+
if is_trivial_initialization(initdata)
215+
nlsol = initprob
216+
success = true
217217
else
218-
throw(OverrideInitNoTolerance(:abstol))
218+
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
219+
if nlsolve_alg === nothing && state_values(initprob) !== nothing
220+
throw(OverrideInitMissingAlgorithm())
221+
end
222+
if alg.abstol !== nothing
223+
_abstol = alg.abstol
224+
elseif abstol !== nothing
225+
_abstol = abstol
226+
else
227+
throw(OverrideInitNoTolerance(:abstol))
228+
end
229+
if alg.reltol !== nothing
230+
_reltol = alg.reltol
231+
elseif reltol !== nothing
232+
_reltol = reltol
233+
else
234+
throw(OverrideInitNoTolerance(:reltol))
235+
end
236+
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
237+
success = SciMLBase.successful_retcode(nlsol)
219238
end
220-
if alg.reltol !== nothing
221-
_reltol = alg.reltol
222-
elseif reltol !== nothing
223-
_reltol = reltol
224-
else
225-
throw(OverrideInitNoTolerance(:reltol))
226-
end
227-
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
228239

229240
u0 = initdata.initializeprobmap(nlsol)
230241
if initdata.initializeprobpmap !== nothing
231242
p = initdata.initializeprobpmap(valp, nlsol)
232243
end
233244

234-
return u0, p, SciMLBase.successful_retcode(nlsol)
245+
return u0, p, success
246+
end
247+
248+
is_trivial_initialization(::Nothing) = true
249+
250+
function is_trivial_initialization(initdata::OverrideInitData)
251+
!(initdata.initializeprob isa NonlinearLeastSquaresProblem) &&
252+
state_values(initdata.initializeprob) === nothing
253+
end
254+
255+
function is_trivial_initialization(f::AbstractSciMLFunction)
256+
has_initialization_data(f) && is_trivial_initialization(f.initialization_data)
257+
end
258+
259+
function is_trivial_initialization(prob::AbstractSciMLProblem)
260+
is_trivial_initialization(prob.f)
235261
end

src/remake.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing,
114114
interpret_symbolicmap = true,
115115
build_initializeprob = true,
116116
use_defaults = false,
117+
lazy_initialization = nothing,
117118
_kwargs...)
118119
if tspan === missing
119120
tspan = prob.tspan
@@ -123,6 +124,8 @@ function remake(prob::ODEProblem; f = missing,
123124

124125
iip = isinplace(prob)
125126

127+
initialization_data = prob.f.initialization_data
128+
126129
if f === missing
127130
if build_initializeprob
128131
initialization_data = remake_initialization_data_compat_wrapper(
@@ -170,13 +173,28 @@ function remake(prob::ODEProblem; f = missing,
170173
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
171174
end
172175

173-
if kwargs === missing
176+
prob = if kwargs === missing
174177
ODEProblem{isinplace(prob)}(
175178
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
176179
_kwargs...)
177180
else
178181
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
179182
end
183+
184+
if lazy_initialization === nothing
185+
lazy_initialization = !is_trivial_initialization(initialization_data)
186+
end
187+
if !lazy_initialization
188+
u0, p, _ = get_initial_values(
189+
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
190+
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
191+
u0 = nothing
192+
end
193+
@reset prob.u0 = u0
194+
@reset prob.p = p
195+
end
196+
197+
return prob
180198
end
181199

182200
"""

test/downstream/modelingtoolkit_remake.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,12 @@ end
336336
@test sccprob4.p !== sccprob4.probs[1].p
337337
@test sccprob4.p !== sccprob4.probs[2].p
338338
end
339+
340+
@testset "Lazy initialization" begin
341+
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
342+
@parameters p=missing [guess = 1.0]
343+
@mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t)
344+
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
345+
prob2 = remake(prob; u0 = [x => 2.0])
346+
@test prob2.ps[p] 3.0
347+
end

test/initialization.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,30 @@ end
244244
@test p 0.0
245245
@test success
246246
end
247+
248+
@testset "Trivial initialization" begin
249+
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
250+
update_initializeprob! = function (iprob, integ)
251+
iprob.p[1] = integ.u[1]
252+
end
253+
initprobmap = function (nlsol)
254+
u1 = parameter_values(nlsol)[1]
255+
return [u1, u1]
256+
end
257+
initprobpmap = function (_, nlsol)
258+
return 0.0
259+
end
260+
initialization_data = SciMLBase.OverrideInitData(
261+
initprob, update_initializeprob!, initprobmap, initprobpmap)
262+
fn = ODEFunction(rhs2; initialization_data)
263+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
264+
integ = init(prob; initializealg = NoInit())
265+
266+
u0, p, success = SciMLBase.get_initial_values(
267+
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
268+
)
269+
@test u0 [2.0, 2.0]
270+
@test p 0.0
271+
@test success
272+
end
247273
end

0 commit comments

Comments
 (0)