Skip to content

Commit cb046ff

Browse files
committed
feat: At and costs in @mtkmodel
1 parent 30bf372 commit cb046ff

File tree

5 files changed

+165
-31
lines changed

5 files changed

+165
-31
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ end
247247
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
248248
controls = Num[],
249249
observed = Equation[],
250-
constraintsystem = nothing,
250+
constraints = Any[],
251251
costs = Num[],
252252
consolidate = nothing,
253253
systems = ODESystem[],
@@ -276,11 +276,30 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
276276
name === nothing &&
277277
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
278278
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
279+
280+
constraintsystem = nothing
281+
if !isempty(constraints)
282+
@show constraints
283+
constraintsystem = process_constraint_system(constraints, dvs, ps, iv)
284+
for p in parameters(constraintsystem)
285+
!in(p, Set(ps)) && push!(ps, p)
286+
end
287+
end
288+
289+
if !isempty(costs)
290+
coststs, costps = process_costs(costs, dvs, ps, iv)
291+
for p in costps
292+
!in(p, Set(ps)) && push!(ps, p)
293+
end
294+
end
295+
costs = wrap.(costs)
296+
279297
iv′ = value(iv)
280298
ps′ = value.(ps)
281299
ctrl′ = value.(controls)
282300
dvs′ = value.(dvs)
283301
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
302+
284303
parameter_dependencies, ps′ = process_parameter_dependencies(
285304
parameter_dependencies, ps′)
286305
if !(isempty(default_u0) && isempty(default_p))
@@ -350,7 +369,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
350369
metadata, gui_metadata, is_dde, tstops, checks = checks)
351370
end
352371

353-
function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
372+
function ODESystem(eqs, iv; kwargs...)
354373
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
355374

356375
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -382,29 +401,7 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
382401
end
383402
algevars = setdiff(allunknowns, diffvars)
384403

385-
consvars = OrderedSet()
386-
constraintsystem = nothing
387-
if !isempty(constraints)
388-
constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv)
389-
for st in get_unknowns(constraintsystem)
390-
iscall(st) ?
391-
!in(operation(st)(iv), allunknowns) && push!(consvars, st) :
392-
!in(st, allunknowns) && push!(consvars, st)
393-
end
394-
for p in parameters(constraintsystem)
395-
!in(p, new_ps) && push!(new_ps, p)
396-
end
397-
end
398-
399-
if !isempty(costs)
400-
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
401-
for p in costps
402-
!in(p, new_ps) && push!(new_ps, p)
403-
end
404-
end
405-
costs = wrap.(costs)
406-
407-
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
404+
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))),
408405
collect(new_ps); constraintsystem, costs, kwargs...)
409406
end
410407

@@ -760,7 +757,7 @@ end
760757
Build the constraint system for the ODESystem.
761758
"""
762759
function process_constraint_system(
763-
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
760+
constraints::Vector, sts, ps, iv; consname = :cons)
764761
isempty(constraints) && return nothing
765762

766763
constraintsts = OrderedSet()
@@ -800,7 +797,7 @@ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 th
800797
parameter of the system.
801798
"""
802799
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
803-
sts = sysvars
800+
sts = Set(sysvars)
804801

805802
for var in auxvars
806803
if !iscall(var)
@@ -810,6 +807,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
810807
throw(ArgumentError("Too many arguments for variable $var."))
811808
elseif length(arguments(var)) == 1
812809
arg = only(arguments(var))
810+
@show sts
813811
operation(var)(iv) sts ||
814812
throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem."))
815813

src/systems/model_parsing.jl

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
6565
ps, sps, vs, = [], [], []
6666
c_evts = []
6767
d_evts = []
68+
cons = []
69+
costs = []
6870
kwargs = OrderedCollections.OrderedSet()
6971
where_types = Union{Symbol, Expr}[]
7072

@@ -80,7 +82,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
8082
for arg in expr.args
8183
if arg.head == :macrocall
8284
parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps,
83-
sps, c_evts, d_evts, dict, mod, arg, kwargs, where_types)
85+
sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types)
8486
elseif arg.head == :block
8587
push!(exprs.args, arg)
8688
elseif arg.head == :if
@@ -117,16 +119,19 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector)
117119
push!(exprs.args, :(push!(systems, $(comps...))))
118120
push!(exprs.args, :(push!(variables, $(vs...))))
119121

122+
120123
gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
121124
GUIMetadata(GlobalRef(mod, name))
122125

126+
consolidate = get(dict, :consolidate, nothing)
123127
description = get(dict, :description, "")
124128

125129
@inline pop_structure_dict!.(
126130
Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters])
127131

128132
sys = :($type($(flatten_equations)(equations), $iv, variables, parameters;
129-
name, description = $description, systems, gui_metadata = $gui_metadata, defaults))
133+
name, description = $description, systems, gui_metadata = $gui_metadata, defaults,
134+
costs = [$(costs...)], constraints = [$(cons...)], consolidate = $consolidate))
130135

131136
if length(ext) == 0
132137
push!(exprs.args, :(var"#___sys___" = $sys))
@@ -610,9 +615,10 @@ function get_var(mod::Module, b)
610615
end
611616

612617
function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
613-
dict, mod, arg, kwargs, where_types)
618+
cons, costs, dict, mod, arg, kwargs, where_types)
614619
mname = arg.args[1]
615620
body = arg.args[end]
621+
@show dict
616622
if mname == Symbol("@description")
617623
parse_description!(body, dict)
618624
elseif mname == Symbol("@components")
@@ -637,7 +643,13 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts,
637643
isassigned(icon) && error("This model has more than one icon.")
638644
parse_icon!(body, dict, icon, mod)
639645
elseif mname == Symbol("@defaults")
640-
parse_system_defaults!(exprs, arg, dict)
646+
parse_system_defaults!(exprs, dict, body)
647+
elseif mname == Symbol("@constraints")
648+
parse_costs!(cons, dict, body)
649+
elseif mname == Symbol("@costs")
650+
parse_constraints!(costs, dict, body)
651+
elseif mname == Symbol("@consolidate")
652+
parse_consolidate!(body, dict)
641653
else
642654
error("$mname is not handled.")
643655
end
@@ -1149,6 +1161,33 @@ function parse_discrete_events!(d_evts, dict, body)
11491161
end
11501162
end
11511163

1164+
function parse_constraints!(cons, dict, body)
1165+
dict[:constraints] = []
1166+
Base.remove_linenums!(body)
1167+
for arg in body.args
1168+
push!(cons, arg)
1169+
push!(dict[:constraints], readable_code.(cons)...)
1170+
end
1171+
end
1172+
1173+
function parse_costs!(costs, dict, body)
1174+
@show dict
1175+
dict[:costs] = []
1176+
Base.remove_linenums!(body)
1177+
for arg in body.args
1178+
push!(costs, arg)
1179+
push!(dict[:costs], readable_code.(costs)...)
1180+
end
1181+
end
1182+
1183+
function parse_consolidate!(body, dict)
1184+
if !(occursin("->", string(body)) || occursin("=", string(body)))
1185+
error("Consolidate must be a function definition.")
1186+
else
1187+
dict[:consolidate] = body
1188+
end
1189+
end
1190+
11521191
function parse_icon!(body::String, dict, icon, mod)
11531192
icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons"))
11541193
dict[:icon] = icon[] = if isfile(body)

src/variables.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,42 @@ getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing)
612612

613613
getshift(x::Num) = getshift(unwrap(x))
614614
getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0)
615+
616+
###################
617+
### Evaluate at ###
618+
###################
619+
struct At <: Symbolics.Operator
620+
t::Union{Symbolic, Number}
621+
end
622+
623+
function (A::At)(x::Symbolic)
624+
if symbolic_type(x) == NotSymbolic() || !iscall(x)
625+
if x isa Symbolics.CallWithMetadata
626+
return x(A.t)
627+
else
628+
return x
629+
end
630+
end
631+
632+
if iscall(x) && operation(x) == getindex
633+
arr = arguments(x)[1]
634+
term(getindex, A(arr), arguments(x)[2:end]...)
635+
elseif operation(x) isa Differential
636+
x = default_toterm(x)
637+
A(x)
638+
else
639+
length(arguments(x)) !== 1 && error("Variable $x has too many arguments. At can only be applied to one-argument variables.")
640+
(symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x
641+
return operation(x)(A.t)
642+
end
643+
end
644+
645+
function (A::At)(x::Union{Num, Symbolics.Arr})
646+
wrap(A(unwrap(x)))
647+
end
648+
SymbolicUtils.isbinop(::At) = false
649+
650+
Base.nameof(::At) = :At
651+
Base.show(io::IO, A::At) = print(io, "At(", A.t, ")")
652+
Base.:(==)(A1::At, A2::At) = isequal(A1.t, A2.t)
653+
Base.hash(A::At, u::UInt) = hash(A.t, u)

test/model_parsing.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,3 +1026,35 @@ end
10261026
@named sys = Float2Bool()
10271027
@test typeof(sys) == DiscreteSystem
10281028
end
1029+
1030+
@testset "Constraints, costs, consolidate" begin
1031+
@mtkmodel Example begin
1032+
@variables begin
1033+
x(t)
1034+
y(t)
1035+
end
1036+
@equations begin
1037+
x ~ y
1038+
end
1039+
@constraints begin
1040+
At(0.3)(x) ~ 3
1041+
y 4
1042+
end
1043+
@costs begin
1044+
x + y
1045+
At(1)(y)^2
1046+
end
1047+
@consolidate f(u) = u[1]^2 + log(u[2])
1048+
end
1049+
1050+
@named ex = Example()
1051+
ex = complete(ex)
1052+
1053+
costs = ModelingToolkit.get_costs(ex)
1054+
constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex))
1055+
@test isequal(costs[1], ex.x + ex.y)
1056+
@test isequal(costs[2], At(1)(ex.y)^2)
1057+
@test isequal(constrs[1], -3 + At(0.3)(ex.x) ~ 0)
1058+
@test isequal(constrs[2], -4 + ex.y 0)
1059+
@test ModelingToolkit.get_consolidate(ex)([1, 2]) 1 + log(2)
1060+
end

test/variable_utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,29 @@ end
158158
@test !isinitial(c)
159159
@test !isinitial(x)
160160
end
161+
162+
@testset "At" begin
163+
@independent_variables u
164+
@variables x(t) v(..) w(t)[1:3]
165+
@parameters y z(u, t) r[1:3]
166+
167+
@test At(1)(x) isa Num
168+
@test isequal(At(1)(y), y)
169+
@test_throws ErrorException At(1)(z)
170+
@test isequal(At(1)(v), v(1))
171+
@test isequal(At(1)(v(t)), v(1))
172+
@test isequal(At(1)(v(2)), v(2))
173+
174+
arr = At(1)(w)
175+
var = At(1)(w[1])
176+
@test arr isa Symbolics.Arr
177+
@test var isa Num
178+
179+
@test isequal(At(1)(r), r)
180+
@test isequal(At(1)(r[2]), r[2])
181+
182+
_x = ModelingToolkit.unwrap(x)
183+
@test At(1)(_x) isa Symbolics.BasicSymbolic
184+
@test only(arguments(At(1)(_x))) == 1
185+
@test At(1)(D(x)) isa Num
186+
end

0 commit comments

Comments
 (0)