Skip to content

Commit c443ac7

Browse files
Fix guess and initial condition length checking
1 parent 4702426 commit c443ac7

File tree

5 files changed

+127
-35
lines changed

5 files changed

+127
-35
lines changed

src/systems/abstractsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,14 +2268,14 @@ function UnPack.unpack(sys::ModelingToolkit.AbstractSystem, ::Val{p}) where {p}
22682268
end
22692269

22702270
"""
2271-
missing_variable_defaults(sys::AbstractSystem, default = 0.0)
2271+
missing_variable_defaults(sys::AbstractSystem, default = 0.0; subset = unknowns(sys))
22722272
22732273
returns a `Vector{Pair}` of variables set to `default` which are missing from `get_defaults(sys)`. The `default` argument can be a single value or vector to set the missing defaults respectively.
22742274
"""
2275-
function missing_variable_defaults(sys::AbstractSystem, default = 0.0)
2275+
function missing_variable_defaults(sys::AbstractSystem, default = 0.0; subset = unknowns(sys))
22762276
varmap = get_defaults(sys)
22772277
varmap = Dict(Symbolics.diff2term(value(k)) => value(varmap[k]) for k in keys(varmap))
2278-
missingvars = setdiff(unknowns(sys), keys(varmap))
2278+
missingvars = setdiff(subset, keys(varmap))
22792279
ds = Pair[]
22802280

22812281
n = length(missingvars)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -848,18 +848,34 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
848848
tofloat = true,
849849
symbolic_u0 = false,
850850
u0_constructor = identity,
851+
guesses = Dict(),
852+
warn_initialize_determined = true,
851853
kwargs...)
852854
eqs = equations(sys)
853855
dvs = unknowns(sys)
854856
ps = full_parameters(sys)
855857
iv = get_iv(sys)
856858

859+
initializeprob = ModelingToolkit.InitializationProblem(sys, u0map, parammap; guesses, warn_initialize_determined)
860+
initializeprobmap = getu(initializeprob, unknowns(sys))
861+
862+
# Append zeros to the variables which are determined by the initialization system
863+
# This essentially bypasses the check for if initial conditions are defined for DAEs
864+
# since they will be checked in the initialization problem's construction
865+
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
866+
if implicit_dae || calculate_massmatrix(sys) !== I
867+
zerovars = setdiff(unknowns(sys),defaults(sys)) .=> 0.0
868+
trueinit = identity.([zerovars;u0map])
869+
else
870+
trueinit = u0map
871+
end
872+
857873
if has_index_cache(sys) && get_index_cache(sys) !== nothing
858-
u0, defs = get_u0(sys, u0map, parammap; symbolic_u0)
874+
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
859875
p = MTKParameters(sys, parammap)
860876
else
861877
u0, p, defs = get_u0_p(sys,
862-
u0map,
878+
trueinit,
863879
parammap;
864880
tofloat,
865881
use_union,
@@ -886,9 +902,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
886902

887903
check_eqs_u0(eqs, dvs, u0; kwargs...)
888904

889-
initializeprob = ModelingToolkit.InitializationProblem(sys, u0map, parammap)
890-
initializeprobmap = getu(initializeprob, unknowns(sys))
891-
892905
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
893906
checkbounds = checkbounds, p = p,
894907
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
@@ -998,13 +1011,14 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
9981011
parammap = DiffEqBase.NullParameters();
9991012
callback = nothing,
10001013
check_length = true,
1014+
warn_initialize_determined = true,
10011015
kwargs...) where {iip, specialize}
10021016
if !iscomplete(sys)
10031017
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
10041018
end
10051019
f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
10061020
t = tspan !== nothing ? tspan[1] : tspan,
1007-
check_length, kwargs...)
1021+
check_length, warn_initialize_determined, kwargs...)
10081022
cbs = process_events(sys; callback, kwargs...)
10091023
inits = []
10101024
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
@@ -1069,13 +1083,14 @@ end
10691083

10701084
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
10711085
parammap = DiffEqBase.NullParameters();
1086+
warn_initialize_determined = true,
10721087
check_length = true, kwargs...) where {iip}
10731088
if !iscomplete(sys)
10741089
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
10751090
end
10761091
f, du0, u0, p = process_DEProblem(DAEFunction{iip}, sys, u0map, parammap;
10771092
implicit_dae = true, du0map = du0map, check_length,
1078-
kwargs...)
1093+
warn_initialize_determined, kwargs...)
10791094
diffvars = collect_differential_variables(sys)
10801095
sts = unknowns(sys)
10811096
differential_vars = map(Base.Fix2(in, diffvars), sts)
@@ -1496,18 +1511,33 @@ end
14961511

14971512
function InitializationProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
14981513
parammap = DiffEqBase.NullParameters();
1514+
guesses = [],
14991515
check_length = true,
1516+
warn_initialize_determined = true,
15001517
kwargs...) where {iip, specialize}
15011518
if !iscomplete(sys)
15021519
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
15031520
end
15041521

1505-
isys = get_initializesystem(sys)
1522+
if isempty(u0map)
1523+
isys = get_initializesystem(sys)
1524+
else
1525+
isys = structural_simplify(generate_initializesystem(sys; u0map); fully_determined = false)
1526+
end
1527+
15061528
neqs = length(equations(isys))
15071529
nunknown = length(unknowns(isys))
1530+
1531+
if warn_initialize_determined && neqs > nunknown
1532+
@warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
1533+
end
1534+
if warn_initialize_determined && neqs < nunknown
1535+
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
1536+
end
1537+
15081538
if neqs == nunknown
1509-
NonlinearProblem(isys, u0map, parammap)
1539+
NonlinearProblem(isys, guesses, parammap)
15101540
else
1511-
NonlinearLeastSquaresProblem(isys, u0map, parammap)
1541+
NonlinearLeastSquaresProblem(isys, guesses, parammap)
15121542
end
15131543
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ $(TYPEDSIGNATURES)
33
44
Generate `NonlinearSystem` which initializes an ODE problem from specified initial conditions of an `ODESystem`.
55
"""
6-
function generate_initializesystem(sys::ODESystem; name = nameof(sys),
6+
function generate_initializesystem(sys::ODESystem;
7+
u0map = Dict(),
8+
name = nameof(sys),
79
guesses = Dict(), check_defguess = false, kwargs...)
810
sts, eqs = unknowns(sys), equations(sys)
911
idxs_diff = isdiffeq.(eqs)
@@ -13,7 +15,7 @@ function generate_initializesystem(sys::ODESystem; name = nameof(sys),
1315
# Start the equations list with algebraic equations
1416
eqs_ics = eqs[idxs_alge]
1517
u0 = Vector{Pair}(undef, 0)
16-
defs = defaults(sys)
18+
defs = merge(defaults(sys),todict(u0map))
1719

1820
full_states = [sts; getfield.((observed(sys)), :lhs)]
1921

src/systems/systemstructure.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
615615
end
616616

617617
function _structural_simplify!(state::TearingState, io; simplify = false,
618-
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
618+
check_consistency = true, fully_determined = true, warn_initialize_determined = false,
619619
kwargs...)
620620
check_consistency &= fully_determined
621621
has_io = io !== nothing
@@ -644,7 +644,7 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
644644
neqs = length(equations(isys))
645645
nunknown = length(unknowns(isys))
646646
if warn_initialize_determined && neqs > nunknown
647-
@warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares"
647+
@warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares if $(nunknown - neqs) defaults are not supplied at construction time."
648648
end
649649
if warn_initialize_determined && neqs < nunknown
650650
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares"

test/initializationsystem.jl

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,49 @@
11
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
22
using ModelingToolkit: t_nounits as t, D_nounits as D
33

4+
@parameters g
5+
@variables x(t) y(t) [state_priority = 10] λ(t)
6+
eqs = [
7+
D(D(x)) ~ λ * x
8+
D(D(y)) ~ λ * y - g
9+
x^2 + y^2 ~ 1
10+
]
11+
@mtkbuild pend = ODESystem(eqs,t)
12+
13+
initprob = ModelingToolkit.InitializationProblem(pend, [], [g => 1]; guesses = [ModelingToolkit.missing_variable_defaults(pend); x => 1; y => 0.2])
14+
conditions = getfield.(equations(initprob.f.sys),:rhs)
15+
16+
@test initprob isa NonlinearLeastSquaresProblem
17+
sol = solve(initprob)
18+
@test SciMLBase.successful_retcode(sol)
19+
@test maximum(abs.(sol[conditions])) < 1e-14
20+
21+
initprob = ModelingToolkit.InitializationProblem(pend, [x => 1, y => 0], [g => 1]; guesses = ModelingToolkit.missing_variable_defaults(pend))
22+
@test initprob isa NonlinearProblem
23+
sol = solve(initprob)
24+
@test SciMLBase.successful_retcode(sol)
25+
@test sol.u == [1.0,0.0,0.0,0.0]
26+
@test maximum(abs.(sol[conditions])) < 1e-14
27+
28+
initprob = ModelingToolkit.InitializationProblem(pend, [], [g => 1]; guesses = ModelingToolkit.missing_variable_defaults(pend))
29+
@test initprob isa NonlinearLeastSquaresProblem
30+
sol = solve(initprob)
31+
@test !SciMLBase.successful_retcode(sol)
32+
33+
prob = ODEProblem(pend, [x => 1, y => 0], (0.0, 1.5), [g => 1], guesses = ModelingToolkit.missing_variable_defaults(pend))
34+
prob.f.initializeprob isa NonlinearProblem
35+
sol = solve(prob.f.initializeprob)
36+
@test maximum(abs.(sol[conditions])) < 1e-14
37+
sol = solve(prob, Rodas5P())
38+
@test maximum(abs.(sol[conditions][1])) < 1e-14
39+
40+
prob = ODEProblem(pend, [x => 1], (0.0, 1.5), [g => 1], guesses = ModelingToolkit.missing_variable_defaults(pend))
41+
prob.f.initializeprob isa NonlinearLeastSquaresProblem
42+
sol = solve(prob.f.initializeprob)
43+
@test maximum(abs.(sol[conditions])) < 1e-14
44+
sol = solve(prob, Rodas5P())
45+
@test maximum(abs.(sol[conditions][1])) < 1e-14
46+
447
@connector Port begin
548
p(t)
649
dm(t) = 0, [connect = Flow]
@@ -171,34 +214,26 @@ end
171214

172215
@mtkbuild sys = System()
173216
initprob = ModelingToolkit.InitializationProblem(sys)
217+
conditions = getfield.(equations(initprob.f.sys),:rhs)
218+
174219
@test initprob isa NonlinearLeastSquaresProblem
175220
@test length(initprob.u0) == 2
176221
initsol = solve(initprob, reltol = 1e-12, abstol = 1e-12)
177222
@test SciMLBase.successful_retcode(initsol)
223+
@test maximum(abs.(initsol[conditions])) < 1e-14
178224

179225
allinit = unknowns(sys) .=> initsol[unknowns(sys)]
180226
prob = ODEProblem(sys, allinit, (0, 0.1))
181-
sol = solve(prob, Rodas5P())
227+
sol = solve(prob, Rodas5P(), initializealg = BrownFullBasicInit())
182228
# If initialized incorrectly, then it would be InitialFailure
183229
@test sol.retcode == SciMLBase.ReturnCode.Unstable
184-
SciMLBase.has_initializeprob(prob.f)
185-
186-
isys = ModelingToolkit.get_initializesystem(sys)
187-
unknowns(isys)
188-
189-
initprob = ModelingToolkit.InitializationProblem(sys)
190-
sol = solve(initprob)
191-
192-
unknowns(sys)
193-
194-
[sys.act.vol₁.dr]
195-
196-
getter = ModelingToolkit.getu(initprob, unknowns(sys)[end-1:end])
197-
getter(sol)
230+
@test maximum(abs.(initsol[conditions][1])) < 1e-14
198231

199-
prob.f.initializeprobmap(initsol)
200-
201-
initsol[unknowns(isys)]
232+
prob = ODEProblem(sys, [], (0, 0.1), check=false)
233+
sol = solve(prob, Rodas5P())
234+
# If initialized incorrectly, then it would be InitialFailure
235+
@test sol.retcode == SciMLBase.ReturnCode.Unstable
236+
@test maximum(abs.(initsol[conditions][1])) < 1e-14
202237

203238
@connector Flange begin
204239
dx(t), [guess = 0]
@@ -269,3 +304,28 @@ sol = solve(prob, Rodas5P())
269304
# If initialized incorrectly, then it would be InitialFailure
270305
@test sol.retcode == SciMLBase.ReturnCode.Success
271306

307+
prob = ODEProblem(sys, [], (0, 0.1))
308+
sol = solve(prob, Rodas5P())
309+
@test sol.retcode == SciMLBase.ReturnCode.Success
310+
311+
### Ensure that non-DAEs still throw for missing variables without the initialize system
312+
313+
@parameters σ ρ β
314+
@variables x(t) y(t) z(t)
315+
316+
eqs = [D(D(x)) ~ σ * (y - x),
317+
D(y) ~ x *- z) - y,
318+
D(z) ~ x * y - β * z]
319+
320+
@mtkbuild sys = ODESystem(eqs, t)
321+
322+
u0 = [D(x) => 2.0,
323+
y => 0.0,
324+
z => 0.0]
325+
326+
p ==> 28.0,
327+
ρ => 10.0,
328+
β => 8 / 3]
329+
330+
tspan = (0.0, 100.0)
331+
@test_throws ArgumentError prob = ODEProblem(sys, u0, tspan, p, jac = true)

0 commit comments

Comments
 (0)