Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,18 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
end

function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
return let _fn = build_explicit_observed_function(sys, sym)
fn(u, p, t) = _fn(u, p, t)
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
fn
if is_time_dependent(sys)
return let _fn = build_explicit_observed_function(sys, sym)
fn(u, p, t) = _fn(u, p, t)
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
fn
end
else
return let _fn = build_explicit_observed_function(sys, sym)
fn2(u, p) = _fn(u, p)
fn2(u, p::MTKParameters) = _fn(u, p...)
fn2
end
end
end

Expand Down Expand Up @@ -1849,14 +1857,17 @@ function linearization_function(sys::AbstractSystem, inputs,
end
initfn = NonlinearFunction(initsys)
initprobmap = getu(initsys, unknowns(sys))
initprob_init! = generate_initializeprob_init(sys, initsys)
initprob_update! = generate_initializeprob_update(sys, initsys)
ps = full_parameters(sys)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
sys, unknowns(sys), ps; initializeprob_init! = initprob_init!,
initializeprob_update! = initprob_update!),
initfn = initfn,
h = build_explicit_observed_function(sys, outputs),
chunk = ForwardDiff.Chunk(input_idxs),
Expand Down
82 changes: 62 additions & 20 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@ function isautonomous(sys::AbstractODESystem)
all(iszero, tgrad)
end

struct GetAndSetFunctor{G, S}
getter::G
setter::S
end

function (gs::GetAndSetFunctor)(dest, source)
gs.setter(dest, gs.getter(source))
end

function generate_initializeprob_init(sys::AbstractSystem, initsys::AbstractSystem)
syms = vcat(variable_symbols(initsys), parameter_symbols(initsys))
return GetAndSetFunctor(getu(sys, syms), setu(initsys, syms))
end

function generate_initializeprob_update(sys::AbstractSystem, initsys::AbstractSystem)
syms = vcat(variable_symbols(sys), parameter_symbols(sys))
return GetAndSetFunctor(getu(initsys, syms), setu(sys, syms))
end

"""
```julia
DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
Expand Down Expand Up @@ -323,7 +342,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
analytic = nothing,
split_idxs = nothing,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprob_init! = nothing,
initializeprob_update! = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
Expand Down Expand Up @@ -506,7 +526,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
initializeprob_init! = initializeprob_init!,
initializeprob_update! = initializeprob_update!)
end

"""
Expand Down Expand Up @@ -537,7 +558,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
eval_module = @__MODULE__,
checkbounds = false,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprob_init! = nothing,
initializeprob_update! = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
Expand Down Expand Up @@ -611,7 +633,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
jac_prototype = jac_prototype,
observed = observedfun,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
initializeprob_init! = initializeprob_init!,
initializeprob_update! = initializeprob_update!)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand Down Expand Up @@ -862,7 +885,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
varmap = canonicalize_varmap(varmap)
varlist = collect(map(unwrap, dvs))
missingvars = setdiff(varlist, collect(keys(varmap)))

# Append zeros to the variables which are determined by the initialization system
# This essentially bypasses the check for if initial conditions are defined for DAEs
# since they will be checked in the initialization problem's construction
Expand All @@ -873,11 +895,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
elseif parammap isa AbstractArray
if isempty(parammap)
parammap = SciMLBase.NullParameters()
parammap = Dict()
else
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
end
elseif parammap === nothing || parammap isa SciMLBase.NullParameters
parammap = Dict()
end
missingpars = setdiff(parameters(sys), keys(parammap))

if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
clockedparammap = Dict()
Expand All @@ -886,7 +911,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
v = unwrap(v)
is_discrete_domain(v) || continue
op = operation(v)
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
if !isa(op, Symbolics.Operator) && !isempty(parammap) &&
haskey(parammap, v)
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
end
Expand All @@ -909,7 +934,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
# TODO: make it work with clocks
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if sys isa ODESystem && build_initializeprob &&
(implicit_dae || !isempty(missingvars)) &&
(implicit_dae || !isempty(missingvars) || !isempty(missingpars)) &&
all(isequal(Continuous()), ci.var_domain) &&
ModelingToolkit.get_tearing_state(sys) !== nothing &&
t !== nothing
Expand All @@ -921,15 +946,28 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
end
initializeprob = ModelingToolkit.InitializationProblem(
sys, t, u0map, parammap; guesses, warn_initialize_determined)
initializeprobmap = getu(initializeprob, unknowns(sys))

punknowns = [p
for p in parameters(sys)
if is_variable(initializeprob, p) || is_observed(initializeprob, p)]
initializeprob_init! = generate_initializeprob_init(sys, initializeprob.f.sys)
initializeprob_update! = generate_initializeprob_update(sys, initializeprob.f.sys)
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
zeropars = Dict()
for p in punknowns
zeropars[p] = if Symbolics.isarraysymbolic(p)
collect(unwrap.(zero(p)))
else
unwrap(zero(p))
end
end
trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map))
u0map isa StaticArraysCore.StaticArray &&
(trueinit = SVector{length(trueinit)}(trueinit))
else
initializeprob = nothing
initializeprobmap = nothing
zeropars = Dict()
initializeprob_init! = nothing
initializeprob_update! = nothing
trueinit = u0map
end

Expand All @@ -940,7 +978,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
parammap == SciMLBase.NullParameters() && isempty(defs)
nothing
else
MTKParameters(sys, parammap, trueinit)
if parammap === nothing || parammap == SciMLBase.NullParameters()
parammap = Dict()
else
parammap = todict(parammap)
end
MTKParameters(sys, merge(parammap, zeropars), trueinit)
end
else
u0, p, defs = get_u0_p(sys,
Expand Down Expand Up @@ -973,8 +1016,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap,
initializeprob = initializeprob, initializeprob_init! = initializeprob_init!,
initializeprob_update! = initializeprob_update!,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end
Expand Down Expand Up @@ -1602,13 +1645,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end
parammap = parammap isa SciMLBase.NullParameters ? Dict() : todict(parammap)
if isempty(u0map) && get_initializesystem(sys) !== nothing
isys = get_initializesystem(sys)
elseif isempty(u0map) && get_initializesystem(sys) === nothing
isys = structural_simplify(generate_initializesystem(sys); fully_determined = false)
isys = structural_simplify(
generate_initializesystem(sys; pmap = parammap); fully_determined = false)
else
isys = structural_simplify(
generate_initializesystem(sys; u0map); fully_determined = false)
generate_initializesystem(sys; u0map, pmap = parammap); fully_determined = false)
end

uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
Expand All @@ -1628,10 +1673,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
if warn_initialize_determined && neqs < nunknown
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
end

parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
[get_iv(sys) => t] :
merge(todict(parammap), Dict(get_iv(sys) => t))
parammap[get_iv(sys)] = t
if isempty(u0map)
u0map = Dict()
end
Expand Down
29 changes: 28 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
pmap = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
Expand Down Expand Up @@ -69,6 +70,32 @@ function generate_initializesystem(sys::ODESystem;
defs = merge(defaults(sys), filtered_u0)
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)

all_params = parameters(sys)
pars = [parameters(sys); get_iv(sys)]
paramsubs = Dict()
for p in all_params
haskey(pmap, p) && continue
paramsubs[p] = tovar(p)
push!(full_states, tovar(p))
deleteat!(pars, findfirst(isequal(p), pars))
if haskey(defs, p)
def = defs[p]
if def isa Equation
p ∉ keys(guesses) && check_defguess &&
error("Invalid setup: parameter $(p) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, p => guesses[p])
else
push!(eqs_ics, p ~ def)
push!(u0, p => def)
end
elseif haskey(guesses, p)
push!(u0, p => guesses[p])
elseif check_defguess
error("Invalid setup: parameter $(p) has no default value or initial guess")
end
end

if !algebraic_only
for st in full_states
if st ∈ keys(defs)
Expand All @@ -91,12 +118,12 @@ function generate_initializesystem(sys::ODESystem;
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = if algebraic_only
[eqs_ics; observed(sys)]
else
[eqs_ics; get_initialization_eqs(sys); observed(sys)]
end
nleqs = fast_substitute(nleqs, paramsubs)

sys_nl = NonlinearSystem(nleqs,
full_states,
Expand Down