Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ function var_derivative_graph! end
include("bipartite_graph.jl")
using .BipartiteGraphs

export At
include("variables.jl")
include("parameters.jl")
include("independent_variables.jl")
Expand Down
52 changes: 24 additions & 28 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ end
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Equation[],
constraintsystem = nothing,
constraints = Any[],
costs = Num[],
consolidate = nothing,
systems = ODESystem[],
Expand Down Expand Up @@ -276,11 +276,29 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

constraintsystem = nothing
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, dvs, ps, iv)
for p in parameters(constraintsystem)
!in(p, Set(ps)) && push!(ps, p)
end
end

if !isempty(costs)
coststs, costps = process_costs(costs, dvs, ps, iv)
for p in costps
!in(p, Set(ps)) && push!(ps, p)
end
end
costs = wrap.(costs)

iv′ = value(iv)
ps′ = value.(ps)
ctrl′ = value.(controls)
dvs′ = value.(dvs)
dvs′ = filter(x -> !isdelay(x, iv), dvs′)

parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
if !(isempty(default_u0) && isempty(default_p))
Expand Down Expand Up @@ -350,7 +368,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
metadata, gui_metadata, is_dde, tstops, checks = checks)
end

function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
function ODESystem(eqs, iv; kwargs...)
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)

for eq in get(kwargs, :parameter_dependencies, Equation[])
Expand Down Expand Up @@ -382,30 +400,8 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
end
algevars = setdiff(allunknowns, diffvars)

consvars = OrderedSet()
constraintsystem = nothing
if !isempty(constraints)
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
for st in get_unknowns(constraintsystem)
iscall(st) ?
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
!in(st, allunknowns) && push!(consvars, st)
end
for p in parameters(constraintsystem)
!in(p, new_ps) && push!(new_ps, p)
end
end

if !isempty(costs)
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
for p in costps
!in(p, new_ps) && push!(new_ps, p)
end
end
costs = wrap.(costs)

return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
collect(new_ps); constraintsystem, costs, kwargs...)
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
collect(new_ps); kwargs...)
end

# NOTE: equality does not check cached Jacobian
Expand Down Expand Up @@ -760,7 +756,7 @@ end
Build the constraint system for the ODESystem.
"""
function process_constraint_system(
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
constraints::Vector, sts, ps, iv; consname = :cons)
isempty(constraints) && return nothing

constraintsts = OrderedSet()
Expand Down Expand Up @@ -800,7 +796,7 @@ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 th
parameter of the system.
"""
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
sts = sysvars
sts = Set(sysvars)

for var in auxvars
if !iscall(var)
Expand Down
4 changes: 2 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct DiscreteSystem <: AbstractDiscreteSystem
tearing_state = nothing, substitutions = nothing, namespacing = true,
complete = false, index_cache = nothing, parent = nothing,
isscheduled = false;
checks::Union{Bool, Int} = true)
checks::Union{Bool, Int} = true, kwargs...)
if checks == true || (checks & CheckComponents) > 0
check_independent_variables([iv])
check_variables(dvs, iv)
Expand Down Expand Up @@ -199,7 +199,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
parameter_dependencies, metadata, gui_metadata, kwargs...)
parameter_dependencies, metadata, gui_metadata)
end

function DiscreteSystem(eqs, iv; kwargs...)
Expand Down
42 changes: 39 additions & 3 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
ps, sps, vs, = [], [], []
c_evts = []
d_evts = []
cons = []
costs = []
kwargs = OrderedCollections.OrderedSet()
where_types = Union{Symbol, Expr}[]

Expand All @@ -80,7 +82,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
for arg in expr.args
if arg.head == :macrocall
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
sps, c_evts, d_evts, dict, mod, arg, kwargs, where_types)
sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types)
elseif arg.head == :block
push!(exprs.args, arg)
elseif arg.head == :if
Expand Down Expand Up @@ -120,13 +122,15 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
GUIMetadata(GlobalRef(mod, name))

consolidate = get(dict, :consolidate, nothing)
description = get(dict, :description, "")

@inline pop_structure_dict!.(
Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters])

sys = :($type($(flatten_equations)(equations), $iv, variables, parameters;
name, description = $description, systems, gui_metadata = $gui_metadata, defaults))
name, description = $description, systems, gui_metadata = $gui_metadata, defaults,
costs = [$(costs...)], constraints = [$(cons...)], consolidate = $consolidate))

if length(ext) == 0
push!(exprs.args, :(var"#___sys___" = $sys))
Expand Down Expand Up @@ -610,7 +614,7 @@ function get_var(mod::Module, b)
end

function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
dict, mod, arg, kwargs, where_types)
cons, costs, dict, mod, arg, kwargs, where_types)
mname = arg.args[1]
body = arg.args[end]
if mname == Symbol("@description")
Expand Down Expand Up @@ -638,6 +642,12 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
parse_icon!(body, dict, icon, mod)
elseif mname == Symbol("@defaults")
parse_system_defaults!(exprs, arg, dict)
elseif mname == Symbol("@constraints")
parse_costs!(cons, dict, body)
elseif mname == Symbol("@costs")
parse_constraints!(costs, dict, body)
elseif mname == Symbol("@consolidate")
parse_consolidate!(body, dict)
else
error("$mname is not handled.")
end
Expand Down Expand Up @@ -1149,6 +1159,32 @@ function parse_discrete_events!(d_evts, dict, body)
end
end

function parse_constraints!(cons, dict, body)
dict[:constraints] = []
Base.remove_linenums!(body)
for arg in body.args
push!(cons, arg)
push!(dict[:constraints], readable_code.(cons)...)
end
end

function parse_costs!(costs, dict, body)
dict[:costs] = []
Base.remove_linenums!(body)
for arg in body.args
push!(costs, arg)
push!(dict[:costs], readable_code.(costs)...)
end
end

function parse_consolidate!(body, dict)
if !(occursin("->", string(body)) || occursin("=", string(body)))
error("Consolidate must be a function definition.")
else
dict[:consolidate] = body
end
end

function parse_icon!(body::String, dict, icon, mod)
icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons"))
dict[:icon] = icon[] = if isfile(body)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/unit_check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy
all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3])
end

function validate(eq::Equation; info::String = "")
function validate(eq::Union{Inequality, Equation}; info::String = "")
if typeof(eq.lhs) == Connection
_validate(eq.rhs; info)
else
Expand Down
40 changes: 40 additions & 0 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,43 @@ getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)

getshift(x::Num) = getshift(unwrap(x))
getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0)

###################
### Evaluate at ###
###################
struct At <: Symbolics.Operator
t::Union{Symbolic, Number}
end

function (A::At)(x::Symbolic)
if symbolic_type(x) == NotSymbolic() || !iscall(x)
if x isa Symbolics.CallWithMetadata
return x(A.t)
else
return x
end
end

if iscall(x) && operation(x) == getindex
arr = arguments(x)[1]
term(getindex, A(arr), arguments(x)[2:end]...)
elseif operation(x) isa Differential
x = default_toterm(x)
A(x)
else
length(arguments(x)) !== 1 &&
error("Variable $x has too many arguments. At can only be applied to one-argument variables.")
(symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x
return operation(x)(A.t)
end
end

function (A::At)(x::Union{Num, Symbolics.Arr})
wrap(A(unwrap(x)))
end
SymbolicUtils.isbinop(::At) = false

Base.nameof(::At) = :At
Base.show(io::IO, A::At) = print(io, "At(", A.t, ")")
Base.:(==)(A1::At, A2::At) = isequal(A1.t, A2.t)
Base.hash(A::At, u::UInt) = hash(A.t, u)
32 changes: 32 additions & 0 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1026,3 +1026,35 @@ end
@named sys = Float2Bool()
@test typeof(sys) == DiscreteSystem
end

@testset "Constraints, costs, consolidate" begin
@mtkmodel Example begin
@variables begin
x(t)
y(t)
end
@equations begin
x ~ y
end
@constraints begin
At(0.3)(x) ~ 3
y ≲ 4
end
@costs begin
x + y
At(1)(y)^2
end
@consolidate f(u) = u[1]^2 + log(u[2])
end

@named ex = Example()
ex = complete(ex)

costs = ModelingToolkit.get_costs(ex)
constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex))
@test isequal(costs[1], ex.x + ex.y)
@test isequal(costs[2], At(1)(ex.y)^2)
@test isequal(constrs[1], -3 + At(0.3)(ex.x) ~ 0)
@test isequal(constrs[2], -4 + ex.y ≲ 0)
@test ModelingToolkit.get_consolidate(ex)([1, 2]) ≈ 1 + log(2)
end
26 changes: 26 additions & 0 deletions test/variable_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,29 @@ end
@test !isinitial(c)
@test !isinitial(x)
end

@testset "At" begin
@independent_variables u
@variables x(t) v(..) w(t)[1:3]
@parameters y z(u, t) r[1:3]

@test At(1)(x) isa Num
@test isequal(At(1)(y), y)
@test_throws ErrorException At(1)(z)
@test isequal(At(1)(v), v(1))
@test isequal(At(1)(v(t)), v(1))
@test isequal(At(1)(v(2)), v(2))

arr = At(1)(w)
var = At(1)(w[1])
@test arr isa Symbolics.Arr
@test var isa Num

@test isequal(At(1)(r), r)
@test isequal(At(1)(r[2]), r[2])

_x = ModelingToolkit.unwrap(x)
@test At(1)(_x) isa Symbolics.BasicSymbolic
@test only(arguments(At(1)(_x))) == 1
@test At(1)(D(x)) isa Num
end
Loading