Skip to content
Merged
140 changes: 56 additions & 84 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,109 +5,81 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
algebraic_only = false,
initialization_eqs = [],
check_units = true,
kwargs...)
sts, eqs = unknowns(sys), equations(sys)
guesses = Dict(),
default_dd_guess = 0.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the old name documented anywhere?

algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff
num_alge = sum(idxs_alge)

# Start the equations list with algebraic equations
eqs_ics = eqs[idxs_alge]
u0 = Vector{Pair}(undef, 0)

# prepare map for dummy derivative substitution
eqs_diff = eqs[idxs_diff]
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
full_diffmap = merge(diffmap, observed_diffmap)
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
)

full_states = unique([sts; getfield.((observed(sys)), :lhs)])
set_full_states = Set(full_states)
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
schedule = getfield(sys, :schedule)

if schedule !== nothing
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
filtered_u0 = u0map
else
filtered_u0 = Pair[]
for x in u0map
y = get(schedule.dummy_sub, x[1], x[1])
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)

if y ∈ set_full_states
# defer initialization until defaults are merged below
push!(filtered_u0, y => x[2])
elseif y isa Symbolics.Arr
# scalarize array # TODO: don't scalarize arrays
_y = collect(y)
for i in eachindex(_y)
push!(filtered_u0, _y[i] => x[2][i])
end
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded
# add to the initialization equations
push!(eqs_ics, y ~ x[2])
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
for (y, x) in u0map
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# TODO: don't scalarize arrays
merge!(defs, Dict(scalarize(y .=> x)))
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
filtered_u0 = todict(filtered_u0)
end
else
dd_guess = Dict()
filtered_u0 = todict(u0map)
end

defs = merge(defaults(sys), filtered_u0)

for st in full_states
if st ∈ keys(defs)
def = defs[st]

if def isa Equation
st ∉ keys(guesses) && check_defguess &&
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, st => guesses[st])
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
end
elseif st ∈ keys(guesses)
push!(u0, st => guesses[st])
# 2) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
elseif var ∈ keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")
error("Invalid setup: variable $(var) has no default value or initial guess")
end
end

# 3) process explicitly provided initialization equations
if !algebraic_only
for eq in [get_initialization_eqs(sys); initialization_eqs]
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
push!(eqs_ics, _eq)
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
for eq in initialization_eqs
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
push!(eqs_ics, eq)
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; observed(sys)]

sys_nl = NonlinearSystem(nleqs,
full_states,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
parameter_dependencies = parameter_dependencies(sys),
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
eqs_ics = [eqs_ics; observed(sys)]
return NonlinearSystem(
eqs_ics, vars, pars;
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
checks = check_units,
name,
kwargs...)

return sys_nl
name, kwargs...
)
end
24 changes: 11 additions & 13 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,23 @@ function NonlinearSystem(eqs, unknowns, ps;
throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events"))
discrete_events === nothing || isempty(discrete_events) ||
throw(ArgumentError("NonlinearSystem does not accept `discrete_events`, you provided $discrete_events"))

name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
# Move things over, but do not touch array expressions
#
# # we cannot scalarize in the loop because `eqs` itself might require
# scalarization
eqs = [x.lhs isa Union{Symbolic, Number} ? 0 ~ x.rhs - x.lhs : x
for x in scalarize(eqs)]

if !(isempty(default_u0) && isempty(default_p))
length(unique(nameof.(systems))) == length(systems) ||
throw(ArgumentError("System names must be unique."))
(isempty(default_u0) && isempty(default_p)) ||
Base.depwarn(
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:NonlinearSystem, force = true)

# Accept a single (scalar/vector) equation, but make array for consistent internal handling
if !(eqs isa AbstractArray)
eqs = [eqs]
end
sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
end

# Copy equations to canonical form, but do not touch array expressions
eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs]

jac = RefValue{Any}(EMPTY_JAC)
defaults = todict(defaults)
defaults = Dict{Any, Any}(value(k) => value(v)
Expand Down
9 changes: 9 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,12 @@ oprob_2nd_order_2 = ODEProblem(sys_2nd_order, u0_2nd_order_2, tspan, ps)
sol = solve(oprob_2nd_order_2, Rosenbrock23()) # retcode: Success
@test sol[Y][1] == 2.0
@test sol[D(Y)][1] == 0.5

@testset "Vector in initial conditions" begin
@variables x(t)[1:5] y(t)[1:5]
@named sys = ODESystem([D(x) ~ x, D(y) ~ y], t; initialization_eqs = [y ~ -x])
sys = structural_simplify(sys)
prob = ODEProblem(sys, [sys.x => ones(5)], (0.0, 1.0), [])
sol = solve(prob, Tsit5(), reltol = 1e-4)
@test all(sol(1.0, idxs = sys.x) .≈ +exp(1)) && all(sol(1.0, idxs = sys.y) .≈ -exp(1))
end
14 changes: 14 additions & 0 deletions test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,20 @@ end
@test_nowarn solve(prob)
end

@testset "System of linear equations with vector variable" begin
# 1st example in https://en.wikipedia.org/w/index.php?title=System_of_linear_equations&oldid=1247697953
@variables x[1:3]
A = [3 2 -1
2 -2 4
-1 1/2 -1]
b = [1, -2, 0]
@named sys = NonlinearSystem(A * x ~ b, [x], [])
sys = structural_simplify(sys)
prob = NonlinearProblem(sys, unknowns(sys) .=> 0.0)
sol = solve(prob)
@test all(sol[x] .≈ A \ b)
end

@testset "resid_prototype when system has no unknowns and an equation" begin
@variables x
@parameters p
Expand Down
1 change: 1 addition & 0 deletions test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ A = reshape(1:(N^2), N, N)
eqs = xs ~ A * xs
@named sys′ = NonlinearSystem(eqs, [xs], [])
sys = structural_simplify(sys′)
@test length(equations(sys)) == 3 && length(observed(sys)) == 2

# issue 958
@parameters k₁ k₂ k₋₁ E₀
Expand Down
Loading