Skip to content

Commit ed7212b

Browse files
fix: use SciMLBase.get_initial_values in linearization_function
1 parent 269ef94 commit ed7212b

File tree

2 files changed

+67
-82
lines changed

2 files changed

+67
-82
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ using Compat
4545
using AbstractTrees
4646
using DiffEqBase, SciMLBase, ForwardDiff
4747
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain,
48-
PeriodicClock, Clock, SolverStepClock, Continuous
48+
PeriodicClock, Clock, SolverStepClock, Continuous, OverrideInit, NoInit
4949
using Distributed
5050
import JuliaFormatter
5151
using MLStyle

src/systems/abstractsystem.jl

Lines changed: 66 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,6 +2357,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
23572357
function linearization_function(sys::AbstractSystem, inputs,
23582358
outputs; simplify = false,
23592359
initialize = true,
2360+
initializealg = nothing,
2361+
initialization_abstol = 1e-6,
2362+
initialization_reltol = 1e-3,
23602363
op = Dict(),
23612364
p = DiffEqBase.NullParameters(),
23622365
zero_dummy_der = false,
@@ -2383,88 +2386,32 @@ function linearization_function(sys::AbstractSystem, inputs,
23832386
op = merge(defs, op)
23842387
end
23852388
sys = ssys
2386-
u0map = Dict(k => v for (k, v) in op if is_variable(ssys, k))
2387-
initsys = structural_simplify(
2388-
generate_initializesystem(
2389-
sys, u0map = u0map, guesses = guesses(sys), algebraic_only = true),
2390-
fully_determined = false)
2391-
2392-
# HACK: some unknowns may not be involved in any initialization equations, and are
2393-
# thus removed from the system during `structural_simplify`.
2394-
# This causes `getu(initsys, unknowns(sys))` to fail, so we add them back as parameters
2395-
# for now.
2396-
missing_unknowns = setdiff(unknowns(sys), all_symbols(initsys))
2397-
if !isempty(missing_unknowns)
2398-
if warn_initialize_determined
2399-
@warn "Initialization system is underdetermined. No equations for $(missing_unknowns). Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
2400-
end
2401-
new_parameters = [parameters(initsys); missing_unknowns]
2402-
@set! initsys.ps = new_parameters
2403-
initsys = complete(initsys)
2404-
end
2405-
2406-
if p isa SciMLBase.NullParameters
2407-
p = Dict()
2408-
else
2409-
p = todict(p)
2410-
end
2411-
x0 = merge(defaults_and_guesses(sys), op)
2412-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2413-
sys_ps = MTKParameters(sys, p, x0)
2414-
else
2415-
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
2416-
end
2417-
p[get_iv(sys)] = NaN
2418-
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
2419-
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
2420-
initsys_ps = parameters(initsys)
2421-
p_getter = build_explicit_observed_function(
2422-
sys, initsys_ps; eval_expression, eval_module)
2423-
2424-
u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
2425-
build_explicit_observed_function(
2426-
sys, unknowns(initsys); eval_expression, eval_module)
2427-
get_initprob_u_p = let p_getter = p_getter,
2428-
p_setter! = setp(initsys, initsys_ps),
2429-
u_getter = u_getter
2430-
2431-
function (u, p, t)
2432-
p_setter!(oldps, p_getter(u, p, t))
2433-
newu = u_getter(u, p, t)
2434-
return newu, oldps
2435-
end
2436-
end
2437-
else
2438-
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
2439-
u_getter = build_explicit_observed_function(
2440-
sys, unknowns(initsys); eval_expression, eval_module)
2441-
2442-
function (u, p, t)
2443-
state = ProblemState(; u, p, t)
2444-
return u_getter(
2445-
state_values(state), parameter_values(state), current_time(state)),
2446-
p_getter(state)
2447-
end
2448-
end
2389+
2390+
if initializealg === nothing
2391+
initializealg = initialize ? OverrideInit() : NoInit()
24492392
end
2450-
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
2451-
initprobmap = build_explicit_observed_function(
2452-
initsys, unknowns(sys); eval_expression, eval_module)
2393+
2394+
fun, u0, p = process_SciMLProblem(
2395+
ODEFunction{true, SciMLBase.FullSpecialize}, sys, op, p;
2396+
t = 0.0, build_initializeprob = initializealg isa OverrideInit,
2397+
allow_incomplete = true, algebraic_only = true)
2398+
prob = ODEProblem(fun, u0, (nothing, nothing), p)
2399+
24532400
ps = parameters(sys)
24542401
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
24552402
lin_fun = let diff_idxs = diff_idxs,
24562403
alge_idxs = alge_idxs,
24572404
input_idxs = input_idxs,
24582405
sts = unknowns(sys),
2459-
get_initprob_u_p = get_initprob_u_p,
2460-
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
2461-
sys, unknowns(sys), ps; eval_expression, eval_module),
2462-
initfn = initfn,
2463-
initprobmap = initprobmap,
2406+
fun = fun,
2407+
prob = prob,
2408+
sys_ps = p,
24642409
h = h,
2410+
integ_cache = (similar(u0)),
24652411
chunk = ForwardDiff.Chunk(input_idxs),
2466-
sys_ps = sys_ps,
2467-
initialize = initialize,
2412+
initializealg = initializealg,
2413+
initialization_abstol = initialization_abstol,
2414+
initialization_reltol = initialization_reltol,
24682415
initialization_solver_alg = initialization_solver_alg,
24692416
sys = sys
24702417

@@ -2484,14 +2431,14 @@ function linearization_function(sys::AbstractSystem, inputs,
24842431
if u !== nothing # Handle systems without unknowns
24852432
length(sts) == length(u) ||
24862433
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
2487-
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
2488-
residual = fun(u, p, t)
2489-
if norm(residual[alge_idxs]) > (eps(eltype(residual)))
2490-
initu0, initp = get_initprob_u_p(u, p, t)
2491-
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
2492-
nlsol = solve(initprob, initialization_solver_alg)
2493-
u = initprobmap(state_values(nlsol), parameter_values(nlsol))
2494-
end
2434+
2435+
integ = MockIntegrator{true}(u, p, t, integ_cache)
2436+
u, p, success = SciMLBase.get_initial_values(
2437+
prob, integ, fun, initializealg, Val(true);
2438+
abstol = initialization_abstol, reltol = initialization_reltol,
2439+
nlsolve_alg = initialization_solver_alg)
2440+
if !success
2441+
error("Initialization algorithm $(initializealg) failed with `u = $u` and `p = $p`.")
24952442
end
24962443
uf = SciMLBase.UJacobianWrapper(fun, t, p)
24972444
fg_xz = ForwardDiff.jacobian(uf, u)
@@ -2526,6 +2473,44 @@ function linearization_function(sys::AbstractSystem, inputs,
25262473
return lin_fun, sys
25272474
end
25282475

2476+
"""
2477+
$(TYPEDEF)
2478+
2479+
Mock `DEIntegrator` to allow using `CheckInit` without having to create a new integrator
2480+
(and consequently depend on `OrdinaryDiffEq`).
2481+
2482+
# Fields
2483+
2484+
$(TYPEDFIELDS)
2485+
"""
2486+
struct MockIntegrator{iip, U, P, T, C} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
2487+
"""
2488+
The state vector.
2489+
"""
2490+
u::U
2491+
"""
2492+
The parameter object.
2493+
"""
2494+
p::P
2495+
"""
2496+
The current time.
2497+
"""
2498+
t::T
2499+
"""
2500+
The integrator cache.
2501+
"""
2502+
cache::C
2503+
end
2504+
2505+
function MockIntegrator{iip}(u::U, p::P, t::T, cache::C) where {iip, U, P, T, C}
2506+
return MockIntegrator{iip, U, P, T, C}(u, p, t, cache)
2507+
end
2508+
2509+
SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u
2510+
SymbolicIndexingInterface.parameter_values(integ::MockIntegrator) = integ.p
2511+
SymbolicIndexingInterface.current_time(integ::MockIntegrator) = integ.t
2512+
SciMLBase.get_tmp_cache(integ::MockIntegrator) = integ.cache
2513+
25292514
"""
25302515
(; A, B, C, D), simplified_sys = linearize_symbolic(sys::AbstractSystem, inputs, outputs; simplify = false, allow_input_derivatives = false, kwargs...)
25312516

0 commit comments

Comments
 (0)