diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index d0f427bd14..992690cfe0 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -142,6 +142,7 @@ function var_derivative_graph! end include("bipartite_graph.jl") using .BipartiteGraphs +export EvalAt include("variables.jl") include("parameters.jl") include("independent_variables.jl") diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 01b0ca5fbb..50655d0074 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -247,7 +247,7 @@ end function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; controls = Num[], observed = Equation[], - constraintsystem = nothing, + constraints = Any[], costs = Num[], consolidate = nothing, systems = ODESystem[], @@ -276,11 +276,29 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." + + constraintsystem = nothing + if !isempty(constraints) + constraintsystem = process_constraint_system(constraints, dvs, ps, iv) + for p in parameters(constraintsystem) + !in(p, Set(ps)) && push!(ps, p) + end + end + + if !isempty(costs) + coststs, costps = process_costs(costs, dvs, ps, iv) + for p in costps + !in(p, Set(ps)) && push!(ps, p) + end + end + costs = wrap.(costs) + iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) dvs′ = value.(dvs) dvs′ = filter(x -> !isdelay(x, iv), dvs′) + parameter_dependencies, ps′ = process_parameter_dependencies( parameter_dependencies, ps′) if !(isempty(default_u0) && isempty(default_p)) @@ -350,7 +368,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; metadata, gui_metadata, is_dde, tstops, checks = checks) end -function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...) +function ODESystem(eqs, iv; kwargs...) diffvars, allunknowns, ps, eqs = process_equations(eqs, iv) for eq in get(kwargs, :parameter_dependencies, Equation[]) @@ -382,30 +400,8 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...) end algevars = setdiff(allunknowns, diffvars) - consvars = OrderedSet() - constraintsystem = nothing - if !isempty(constraints) - constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv) - for st in get_unknowns(constraintsystem) - iscall(st) ? - !in(operation(st)(iv), allunknowns) && push!(consvars, st) : - !in(st, allunknowns) && push!(consvars, st) - end - for p in parameters(constraintsystem) - !in(p, new_ps) && push!(new_ps, p) - end - end - - if !isempty(costs) - coststs, costps = process_costs(costs, allunknowns, new_ps, iv) - for p in costps - !in(p, new_ps) && push!(new_ps, p) - end - end - costs = wrap.(costs) - - return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))), - collect(new_ps); constraintsystem, costs, kwargs...) + return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), + collect(new_ps); kwargs...) end # NOTE: equality does not check cached Jacobian @@ -760,7 +756,7 @@ end Build the constraint system for the ODESystem. """ function process_constraint_system( - constraints::Vector{Equation}, sts, ps, iv; consname = :cons) + constraints::Vector, sts, ps, iv; consname = :cons) isempty(constraints) && return nothing constraintsts = OrderedSet() @@ -800,7 +796,7 @@ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 th parameter of the system. """ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) - sts = sysvars + sts = Set(sysvars) for var in auxvars if !iscall(var) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 5f7c986659..075aa27e4d 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -121,7 +121,7 @@ struct DiscreteSystem <: AbstractDiscreteSystem tearing_state = nothing, substitutions = nothing, namespacing = true, complete = false, index_cache = nothing, parent = nothing, isscheduled = false; - checks::Union{Bool, Int} = true) + checks::Union{Bool, Int} = true, kwargs...) if checks == true || (checks & CheckComponents) > 0 check_independent_variables([iv]) check_variables(dvs, iv) @@ -199,7 +199,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, preface, connector_type, - parameter_dependencies, metadata, gui_metadata, kwargs...) + parameter_dependencies, metadata, gui_metadata) end function DiscreteSystem(eqs, iv; kwargs...) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 195b02118e..024c249363 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -65,6 +65,8 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) ps, sps, vs, = [], [], [] c_evts = [] d_evts = [] + cons = [] + costs = [] kwargs = OrderedCollections.OrderedSet() where_types = Union{Symbol, Expr}[] @@ -80,7 +82,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) for arg in expr.args if arg.head == :macrocall parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps, - sps, c_evts, d_evts, dict, mod, arg, kwargs, where_types) + sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types) elseif arg.head == :block push!(exprs.args, arg) elseif arg.head == :if @@ -120,13 +122,15 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) : GUIMetadata(GlobalRef(mod, name)) + consolidate = get(dict, :consolidate, nothing) description = get(dict, :description, "") @inline pop_structure_dict!.( Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters]) sys = :($type($(flatten_equations)(equations), $iv, variables, parameters; - name, description = $description, systems, gui_metadata = $gui_metadata, defaults)) + name, description = $description, systems, gui_metadata = $gui_metadata, defaults, + costs = [$(costs...)], constraints = [$(cons...)], consolidate = $consolidate)) if length(ext) == 0 push!(exprs.args, :(var"#___sys___" = $sys)) @@ -610,7 +614,7 @@ function get_var(mod::Module, b) end function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, - dict, mod, arg, kwargs, where_types) + cons, costs, dict, mod, arg, kwargs, where_types) mname = arg.args[1] body = arg.args[end] if mname == Symbol("@description") @@ -638,6 +642,12 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, parse_icon!(body, dict, icon, mod) elseif mname == Symbol("@defaults") parse_system_defaults!(exprs, arg, dict) + elseif mname == Symbol("@constraints") + parse_costs!(cons, dict, body) + elseif mname == Symbol("@costs") + parse_constraints!(costs, dict, body) + elseif mname == Symbol("@consolidate") + parse_consolidate!(body, dict) else error("$mname is not handled.") end @@ -1149,6 +1159,32 @@ function parse_discrete_events!(d_evts, dict, body) end end +function parse_constraints!(cons, dict, body) + dict[:constraints] = [] + Base.remove_linenums!(body) + for arg in body.args + push!(cons, arg) + push!(dict[:constraints], readable_code.(cons)...) + end +end + +function parse_costs!(costs, dict, body) + dict[:costs] = [] + Base.remove_linenums!(body) + for arg in body.args + push!(costs, arg) + push!(dict[:costs], readable_code.(costs)...) + end +end + +function parse_consolidate!(body, dict) + if !(occursin("->", string(body)) || occursin("=", string(body))) + error("Consolidate must be a function definition.") + else + dict[:consolidate] = body + end +end + function parse_icon!(body::String, dict, icon, mod) icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons")) dict[:icon] = icon[] = if isfile(body) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 5035a22b5e..83a5ac5483 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -272,7 +272,7 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3]) end -function validate(eq::Equation; info::String = "") +function validate(eq::Union{Inequality, Equation}; info::String = "") if typeof(eq.lhs) == Connection _validate(eq.rhs; info) else diff --git a/src/variables.jl b/src/variables.jl index f3dd16819d..83e72cea35 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -612,3 +612,43 @@ getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing) getshift(x::Num) = getshift(unwrap(x)) getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) + +################### +### Evaluate at ### +################### +struct EvalAt <: Symbolics.Operator + t::Union{Symbolic, Number} +end + +function (A::EvalAt)(x::Symbolic) + if symbolic_type(x) == NotSymbolic() || !iscall(x) + if x isa Symbolics.CallWithMetadata + return x(A.t) + else + return x + end + end + + if iscall(x) && operation(x) == getindex + arr = arguments(x)[1] + term(getindex, A(arr), arguments(x)[2:end]...) + elseif operation(x) isa Differential + x = default_toterm(x) + A(x) + else + length(arguments(x)) !== 1 && + error("Variable $x has too many arguments. EvalAt can only be applied to one-argument variables.") + (symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x + return operation(x)(A.t) + end +end + +function (A::EvalAt)(x::Union{Num, Symbolics.Arr}) + wrap(A(unwrap(x))) +end +SymbolicUtils.isbinop(::EvalAt) = false + +Base.nameof(::EvalAt) = :EvalAt +Base.show(io::IO, A::EvalAt) = print(io, "EvalAt(", A.t, ")") +Base.:(==)(A1::EvalAt, A2::EvalAt) = isequal(A1.t, A2.t) +Base.hash(A::EvalAt, u::UInt) = hash(A.t, u) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index e8464707de..fe2bcbfca6 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -1026,3 +1026,35 @@ end @named sys = Float2Bool() @test typeof(sys) == DiscreteSystem end + +@testset "Constraints, costs, consolidate" begin + @mtkmodel Example begin + @variables begin + x(t) + y(t) + end + @equations begin + x ~ y + end + @constraints begin + EvalAt(0.3)(x) ~ 3 + y ≲ 4 + end + @costs begin + x + y + EvalAt(1)(y)^2 + end + @consolidate f(u) = u[1]^2 + log(u[2]) + end + + @named ex = Example() + ex = complete(ex) + + costs = ModelingToolkit.get_costs(ex) + constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex)) + @test isequal(costs[1], ex.x + ex.y) + @test isequal(costs[2], EvalAt(1)(ex.y)^2) + @test isequal(constrs[1], -3 + EvalAt(0.3)(ex.x) ~ 0) + @test isequal(constrs[2], -4 + ex.y ≲ 0) + @test ModelingToolkit.get_consolidate(ex)([1, 2]) ≈ 1 + log(2) +end diff --git a/test/variable_utils.jl b/test/variable_utils.jl index 3204d28836..1dc45e11ef 100644 --- a/test/variable_utils.jl +++ b/test/variable_utils.jl @@ -158,3 +158,29 @@ end @test !isinitial(c) @test !isinitial(x) end + +@testset "At" begin + @independent_variables u + @variables x(t) v(..) w(t)[1:3] + @parameters y z(u, t) r[1:3] + + @test EvalAt(1)(x) isa Num + @test isequal(EvalAt(1)(y), y) + @test_throws ErrorException EvalAt(1)(z) + @test isequal(EvalAt(1)(v), v(1)) + @test isequal(EvalAt(1)(v(t)), v(1)) + @test isequal(EvalAt(1)(v(2)), v(2)) + + arr = EvalAt(1)(w) + var = EvalAt(1)(w[1]) + @test arr isa Symbolics.Arr + @test var isa Num + + @test isequal(EvalAt(1)(r), r) + @test isequal(EvalAt(1)(r[2]), r[2]) + + _x = ModelingToolkit.unwrap(x) + @test EvalAt(1)(_x) isa Symbolics.BasicSymbolic + @test only(arguments(EvalAt(1)(_x))) == 1 + @test EvalAt(1)(D(x)) isa Num +end