Skip to content

Commit 7894397

Browse files
fix: use SciMLBase.get_initial_values in linearization_function
1 parent e2263e7 commit 7894397

File tree

2 files changed

+61
-82
lines changed

2 files changed

+61
-82
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ using Compat
4545
using AbstractTrees
4646
using DiffEqBase, SciMLBase, ForwardDiff
4747
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain,
48-
PeriodicClock, Clock, SolverStepClock, Continuous
48+
PeriodicClock, Clock, SolverStepClock, Continuous, OverrideInit, NoInit
4949
using Distributed
5050
import JuliaFormatter
5151
using MLStyle

src/systems/abstractsystem.jl

Lines changed: 60 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2377,6 +2377,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
23772377
function linearization_function(sys::AbstractSystem, inputs,
23782378
outputs; simplify = false,
23792379
initialize = true,
2380+
initializealg = nothing,
2381+
initialization_abstol = 1e-6,
2382+
initialization_reltol = 1e-3,
23802383
op = Dict(),
23812384
p = DiffEqBase.NullParameters(),
23822385
zero_dummy_der = false,
@@ -2403,88 +2406,29 @@ function linearization_function(sys::AbstractSystem, inputs,
24032406
op = merge(defs, op)
24042407
end
24052408
sys = ssys
2406-
u0map = Dict(k => v for (k, v) in op if is_variable(ssys, k))
2407-
initsys = structural_simplify(
2408-
generate_initializesystem(
2409-
sys, u0map = u0map, guesses = guesses(sys), algebraic_only = true),
2410-
fully_determined = false)
2411-
2412-
# HACK: some unknowns may not be involved in any initialization equations, and are
2413-
# thus removed from the system during `structural_simplify`.
2414-
# This causes `getu(initsys, unknowns(sys))` to fail, so we add them back as parameters
2415-
# for now.
2416-
missing_unknowns = setdiff(unknowns(sys), all_symbols(initsys))
2417-
if !isempty(missing_unknowns)
2418-
if warn_initialize_determined
2419-
@warn "Initialization system is underdetermined. No equations for $(missing_unknowns). Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
2420-
end
2421-
new_parameters = [parameters(initsys); missing_unknowns]
2422-
@set! initsys.ps = new_parameters
2423-
initsys = complete(initsys)
2424-
end
2425-
2426-
if p isa SciMLBase.NullParameters
2427-
p = Dict()
2428-
else
2429-
p = todict(p)
2430-
end
2431-
x0 = merge(defaults_and_guesses(sys), op)
2432-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2433-
sys_ps = MTKParameters(sys, p, x0)
2434-
else
2435-
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
2436-
end
2437-
p[get_iv(sys)] = NaN
2438-
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
2439-
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
2440-
initsys_ps = parameters(initsys)
2441-
p_getter = build_explicit_observed_function(
2442-
sys, initsys_ps; eval_expression, eval_module)
2443-
2444-
u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
2445-
build_explicit_observed_function(
2446-
sys, unknowns(initsys); eval_expression, eval_module)
2447-
get_initprob_u_p = let p_getter = p_getter,
2448-
p_setter! = setp(initsys, initsys_ps),
2449-
u_getter = u_getter
2450-
2451-
function (u, p, t)
2452-
p_setter!(oldps, p_getter(u, p, t))
2453-
newu = u_getter(u, p, t)
2454-
return newu, oldps
2455-
end
2456-
end
2457-
else
2458-
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
2459-
u_getter = build_explicit_observed_function(
2460-
sys, unknowns(initsys); eval_expression, eval_module)
2461-
2462-
function (u, p, t)
2463-
state = ProblemState(; u, p, t)
2464-
return u_getter(
2465-
state_values(state), parameter_values(state), current_time(state)),
2466-
p_getter(state)
2467-
end
2468-
end
2409+
2410+
if initializealg === nothing
2411+
initializealg = initialize ? OverrideInit() : NoInit()
24692412
end
2470-
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
2471-
initprobmap = build_explicit_observed_function(
2472-
initsys, unknowns(sys); eval_expression, eval_module)
2413+
2414+
fun, u0, p = process_SciMLProblem(ODEFunction{true, SciMLBase.FullSpecialize}, sys, op, p; t = 0.0, build_initializeprob = initializealg isa OverrideInit, allow_incomplete = true, algebraic_only = true)
2415+
prob = ODEProblem(fun, u0, (nothing, nothing), p)
2416+
24732417
ps = parameters(sys)
24742418
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
24752419
lin_fun = let diff_idxs = diff_idxs,
24762420
alge_idxs = alge_idxs,
24772421
input_idxs = input_idxs,
24782422
sts = unknowns(sys),
2479-
get_initprob_u_p = get_initprob_u_p,
2480-
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
2481-
sys, unknowns(sys), ps; eval_expression, eval_module),
2482-
initfn = initfn,
2483-
initprobmap = initprobmap,
2423+
fun = fun,
2424+
prob = prob,
2425+
sys_ps = p,
24842426
h = h,
2427+
integ_cache = (similar(u0)),
24852428
chunk = ForwardDiff.Chunk(input_idxs),
2486-
sys_ps = sys_ps,
2487-
initialize = initialize,
2429+
initializealg = initializealg,
2430+
initialization_abstol = initialization_abstol,
2431+
initialization_reltol = initialization_reltol,
24882432
initialization_solver_alg = initialization_solver_alg,
24892433
sys = sys
24902434

@@ -2504,14 +2448,11 @@ function linearization_function(sys::AbstractSystem, inputs,
25042448
if u !== nothing # Handle systems without unknowns
25052449
length(sts) == length(u) ||
25062450
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
2507-
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
2508-
residual = fun(u, p, t)
2509-
if norm(residual[alge_idxs]) > (eps(eltype(residual)))
2510-
initu0, initp = get_initprob_u_p(u, p, t)
2511-
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
2512-
nlsol = solve(initprob, initialization_solver_alg)
2513-
u = initprobmap(state_values(nlsol), parameter_values(nlsol))
2514-
end
2451+
2452+
integ = MockIntegrator{true}(u, p, t, integ_cache)
2453+
u, p, success = SciMLBase.get_initial_values(prob, integ, fun, initializealg, Val(true); abstol = initialization_abstol, reltol = initialization_reltol, nlsolve_alg = initialization_solver_alg)
2454+
if !success
2455+
error("Initialization algorithm $(initializealg) failed with `u = $u` and `p = $p`.")
25152456
end
25162457
uf = SciMLBase.UJacobianWrapper(fun, t, p)
25172458
fg_xz = ForwardDiff.jacobian(uf, u)
@@ -2546,6 +2487,44 @@ function linearization_function(sys::AbstractSystem, inputs,
25462487
return lin_fun, sys
25472488
end
25482489

2490+
"""
2491+
$(TYPEDEF)
2492+
2493+
Mock `DEIntegrator` to allow using `CheckInit` without having to create a new integrator
2494+
(and consequently depend on `OrdinaryDiffEq`).
2495+
2496+
# Fields
2497+
2498+
$(TYPEDFIELDS)
2499+
"""
2500+
struct MockIntegrator{iip, U, P, T, C} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
2501+
"""
2502+
The state vector.
2503+
"""
2504+
u::U
2505+
"""
2506+
The parameter object.
2507+
"""
2508+
p::P
2509+
"""
2510+
The current time.
2511+
"""
2512+
t::T
2513+
"""
2514+
The integrator cache.
2515+
"""
2516+
cache::C
2517+
end
2518+
2519+
function MockIntegrator{iip}(u::U, p::P, t::T, cache::C) where {iip, U, P, T, C}
2520+
return MockIntegrator{iip, U, P, T, C}(u, p, t, cache)
2521+
end
2522+
2523+
SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u
2524+
SymbolicIndexingInterface.parameter_values(integ::MockIntegrator) = integ.p
2525+
SymbolicIndexingInterface.current_time(integ::MockIntegrator) = integ.t
2526+
SciMLBase.get_tmp_cache(integ::MockIntegrator) = integ.cache
2527+
25492528
"""
25502529
(; A, B, C, D), simplified_sys = linearize_symbolic(sys::AbstractSystem, inputs, outputs; simplify = false, allow_input_derivatives = false, kwargs...)
25512530

0 commit comments

Comments
 (0)