From cb046ffe581c8320cb1ce8c14dcd3679d0f154fb Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 25 Apr 2025 17:16:08 -0400 Subject: [PATCH 1/7] feat: At and costs in @mtkmodel --- src/systems/diffeqs/odesystem.jl | 52 +++++++++++++++----------------- src/systems/model_parsing.jl | 47 ++++++++++++++++++++++++++--- src/variables.jl | 39 ++++++++++++++++++++++++ test/model_parsing.jl | 32 ++++++++++++++++++++ test/variable_utils.jl | 26 ++++++++++++++++ 5 files changed, 165 insertions(+), 31 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 01b0ca5fbb..de59f3ae55 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -247,7 +247,7 @@ end function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; controls = Num[], observed = Equation[], - constraintsystem = nothing, + constraints = Any[], costs = Num[], consolidate = nothing, systems = ODESystem[], @@ -276,11 +276,30 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." + + constraintsystem = nothing + if !isempty(constraints) + @show constraints + constraintsystem = process_constraint_system(constraints, dvs, ps, iv) + for p in parameters(constraintsystem) + !in(p, Set(ps)) && push!(ps, p) + end + end + + if !isempty(costs) + coststs, costps = process_costs(costs, dvs, ps, iv) + for p in costps + !in(p, Set(ps)) && push!(ps, p) + end + end + costs = wrap.(costs) + iv′ = value(iv) ps′ = value.(ps) ctrl′ = value.(controls) dvs′ = value.(dvs) dvs′ = filter(x -> !isdelay(x, iv), dvs′) + parameter_dependencies, ps′ = process_parameter_dependencies( parameter_dependencies, ps′) if !(isempty(default_u0) && isempty(default_p)) @@ -350,7 +369,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; metadata, gui_metadata, is_dde, tstops, checks = checks) end -function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...) +function ODESystem(eqs, iv; kwargs...) diffvars, allunknowns, ps, eqs = process_equations(eqs, iv) for eq in get(kwargs, :parameter_dependencies, Equation[]) @@ -382,29 +401,7 @@ function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...) end algevars = setdiff(allunknowns, diffvars) - consvars = OrderedSet() - constraintsystem = nothing - if !isempty(constraints) - constraintsystem = process_constraint_system(constraints, allunknowns, new_ps, iv) - for st in get_unknowns(constraintsystem) - iscall(st) ? - !in(operation(st)(iv), allunknowns) && push!(consvars, st) : - !in(st, allunknowns) && push!(consvars, st) - end - for p in parameters(constraintsystem) - !in(p, new_ps) && push!(new_ps, p) - end - end - - if !isempty(costs) - coststs, costps = process_costs(costs, allunknowns, new_ps, iv) - for p in costps - !in(p, new_ps) && push!(new_ps, p) - end - end - costs = wrap.(costs) - - return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))), + return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), collect(new_ps); constraintsystem, costs, kwargs...) end @@ -760,7 +757,7 @@ end Build the constraint system for the ODESystem. """ function process_constraint_system( - constraints::Vector{Equation}, sts, ps, iv; consname = :cons) + constraints::Vector, sts, ps, iv; consname = :cons) isempty(constraints) && return nothing constraintsts = OrderedSet() @@ -800,7 +797,7 @@ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 th parameter of the system. """ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) - sts = sysvars + sts = Set(sysvars) for var in auxvars if !iscall(var) @@ -810,6 +807,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) throw(ArgumentError("Too many arguments for variable $var.")) elseif length(arguments(var)) == 1 arg = only(arguments(var)) + @show sts operation(var)(iv) ∈ sts || throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem.")) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 195b02118e..e0afa098a7 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -65,6 +65,8 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) ps, sps, vs, = [], [], [] c_evts = [] d_evts = [] + cons = [] + costs = [] kwargs = OrderedCollections.OrderedSet() where_types = Union{Symbol, Expr}[] @@ -80,7 +82,7 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) for arg in expr.args if arg.head == :macrocall parse_model!(exprs.args, comps, ext, eqs, icon, vs, ps, - sps, c_evts, d_evts, dict, mod, arg, kwargs, where_types) + sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types) elseif arg.head == :block push!(exprs.args, arg) elseif arg.head == :if @@ -117,16 +119,19 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) push!(exprs.args, :(push!(systems, $(comps...)))) push!(exprs.args, :(push!(variables, $(vs...)))) + gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) : GUIMetadata(GlobalRef(mod, name)) + consolidate = get(dict, :consolidate, nothing) description = get(dict, :description, "") @inline pop_structure_dict!.( Ref(dict), [:constants, :defaults, :kwargs, :structural_parameters]) sys = :($type($(flatten_equations)(equations), $iv, variables, parameters; - name, description = $description, systems, gui_metadata = $gui_metadata, defaults)) + name, description = $description, systems, gui_metadata = $gui_metadata, defaults, + costs = [$(costs...)], constraints = [$(cons...)], consolidate = $consolidate)) if length(ext) == 0 push!(exprs.args, :(var"#___sys___" = $sys)) @@ -610,9 +615,10 @@ function get_var(mod::Module, b) end function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, - dict, mod, arg, kwargs, where_types) + cons, costs, dict, mod, arg, kwargs, where_types) mname = arg.args[1] body = arg.args[end] + @show dict if mname == Symbol("@description") parse_description!(body, dict) elseif mname == Symbol("@components") @@ -637,7 +643,13 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, isassigned(icon) && error("This model has more than one icon.") parse_icon!(body, dict, icon, mod) elseif mname == Symbol("@defaults") - parse_system_defaults!(exprs, arg, dict) + parse_system_defaults!(exprs, dict, body) + elseif mname == Symbol("@constraints") + parse_costs!(cons, dict, body) + elseif mname == Symbol("@costs") + parse_constraints!(costs, dict, body) + elseif mname == Symbol("@consolidate") + parse_consolidate!(body, dict) else error("$mname is not handled.") end @@ -1149,6 +1161,33 @@ function parse_discrete_events!(d_evts, dict, body) end end +function parse_constraints!(cons, dict, body) + dict[:constraints] = [] + Base.remove_linenums!(body) + for arg in body.args + push!(cons, arg) + push!(dict[:constraints], readable_code.(cons)...) + end +end + +function parse_costs!(costs, dict, body) + @show dict + dict[:costs] = [] + Base.remove_linenums!(body) + for arg in body.args + push!(costs, arg) + push!(dict[:costs], readable_code.(costs)...) + end +end + +function parse_consolidate!(body, dict) + if !(occursin("->", string(body)) || occursin("=", string(body))) + error("Consolidate must be a function definition.") + else + dict[:consolidate] = body + end +end + function parse_icon!(body::String, dict, icon, mod) icon_dir = get(ENV, "MTK_ICONS_DIR", joinpath(DEPOT_PATH[1], "mtk_icons")) dict[:icon] = icon[] = if isfile(body) diff --git a/src/variables.jl b/src/variables.jl index f3dd16819d..90afb5ee08 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -612,3 +612,42 @@ getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing) getshift(x::Num) = getshift(unwrap(x)) getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) + +################### +### Evaluate at ### +################### +struct At <: Symbolics.Operator + t::Union{Symbolic, Number} +end + +function (A::At)(x::Symbolic) + if symbolic_type(x) == NotSymbolic() || !iscall(x) + if x isa Symbolics.CallWithMetadata + return x(A.t) + else + return x + end + end + + if iscall(x) && operation(x) == getindex + arr = arguments(x)[1] + term(getindex, A(arr), arguments(x)[2:end]...) + elseif operation(x) isa Differential + x = default_toterm(x) + A(x) + else + length(arguments(x)) !== 1 && error("Variable $x has too many arguments. At can only be applied to one-argument variables.") + (symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x + return operation(x)(A.t) + end +end + +function (A::At)(x::Union{Num, Symbolics.Arr}) + wrap(A(unwrap(x))) +end +SymbolicUtils.isbinop(::At) = false + +Base.nameof(::At) = :At +Base.show(io::IO, A::At) = print(io, "At(", A.t, ")") +Base.:(==)(A1::At, A2::At) = isequal(A1.t, A2.t) +Base.hash(A::At, u::UInt) = hash(A.t, u) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index e8464707de..0d49554257 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -1026,3 +1026,35 @@ end @named sys = Float2Bool() @test typeof(sys) == DiscreteSystem end + +@testset "Constraints, costs, consolidate" begin + @mtkmodel Example begin + @variables begin + x(t) + y(t) + end + @equations begin + x ~ y + end + @constraints begin + At(0.3)(x) ~ 3 + y ≲ 4 + end + @costs begin + x + y + At(1)(y)^2 + end + @consolidate f(u) = u[1]^2 + log(u[2]) + end + + @named ex = Example() + ex = complete(ex) + + costs = ModelingToolkit.get_costs(ex) + constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex)) + @test isequal(costs[1], ex.x + ex.y) + @test isequal(costs[2], At(1)(ex.y)^2) + @test isequal(constrs[1], -3 + At(0.3)(ex.x) ~ 0) + @test isequal(constrs[2], -4 + ex.y ≲ 0) + @test ModelingToolkit.get_consolidate(ex)([1, 2]) ≈ 1 + log(2) +end diff --git a/test/variable_utils.jl b/test/variable_utils.jl index 3204d28836..ce7f7d26eb 100644 --- a/test/variable_utils.jl +++ b/test/variable_utils.jl @@ -158,3 +158,29 @@ end @test !isinitial(c) @test !isinitial(x) end + +@testset "At" begin + @independent_variables u + @variables x(t) v(..) w(t)[1:3] + @parameters y z(u, t) r[1:3] + + @test At(1)(x) isa Num + @test isequal(At(1)(y), y) + @test_throws ErrorException At(1)(z) + @test isequal(At(1)(v), v(1)) + @test isequal(At(1)(v(t)), v(1)) + @test isequal(At(1)(v(2)), v(2)) + + arr = At(1)(w) + var = At(1)(w[1]) + @test arr isa Symbolics.Arr + @test var isa Num + + @test isequal(At(1)(r), r) + @test isequal(At(1)(r[2]), r[2]) + + _x = ModelingToolkit.unwrap(x) + @test At(1)(_x) isa Symbolics.BasicSymbolic + @test only(arguments(At(1)(_x))) == 1 + @test At(1)(D(x)) isa Num +end From 028ba8914fe4e8ac6d65263ed3b734a2718082f8 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 25 Apr 2025 17:16:28 -0400 Subject: [PATCH 2/7] format --- src/systems/model_parsing.jl | 9 ++++----- src/variables.jl | 5 +++-- test/model_parsing.jl | 2 +- test/variable_utils.jl | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index e0afa098a7..6bea4b2b80 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -119,7 +119,6 @@ function _model_macro(mod, fullname::Union{Expr, Symbol}, expr, isconnector) push!(exprs.args, :(push!(systems, $(comps...)))) push!(exprs.args, :(push!(variables, $(vs...)))) - gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) : GUIMetadata(GlobalRef(mod, name)) @@ -1161,8 +1160,8 @@ function parse_discrete_events!(d_evts, dict, body) end end -function parse_constraints!(cons, dict, body) - dict[:constraints] = [] +function parse_constraints!(cons, dict, body) + dict[:constraints] = [] Base.remove_linenums!(body) for arg in body.args push!(cons, arg) @@ -1170,9 +1169,9 @@ function parse_constraints!(cons, dict, body) end end -function parse_costs!(costs, dict, body) +function parse_costs!(costs, dict, body) @show dict - dict[:costs] = [] + dict[:costs] = [] Base.remove_linenums!(body) for arg in body.args push!(costs, arg) diff --git a/src/variables.jl b/src/variables.jl index 90afb5ee08..e5d8961225 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -616,7 +616,7 @@ getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) ################### ### Evaluate at ### ################### -struct At <: Symbolics.Operator +struct At <: Symbolics.Operator t::Union{Symbolic, Number} end @@ -636,7 +636,8 @@ function (A::At)(x::Symbolic) x = default_toterm(x) A(x) else - length(arguments(x)) !== 1 && error("Variable $x has too many arguments. At can only be applied to one-argument variables.") + length(arguments(x)) !== 1 && + error("Variable $x has too many arguments. At can only be applied to one-argument variables.") (symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x return operation(x)(A.t) end diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 0d49554257..05380269e5 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -1050,7 +1050,7 @@ end @named ex = Example() ex = complete(ex) - costs = ModelingToolkit.get_costs(ex) + costs = ModelingToolkit.get_costs(ex) constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex)) @test isequal(costs[1], ex.x + ex.y) @test isequal(costs[2], At(1)(ex.y)^2) diff --git a/test/variable_utils.jl b/test/variable_utils.jl index ce7f7d26eb..f447cfeb8d 100644 --- a/test/variable_utils.jl +++ b/test/variable_utils.jl @@ -173,9 +173,9 @@ end arr = At(1)(w) var = At(1)(w[1]) - @test arr isa Symbolics.Arr + @test arr isa Symbolics.Arr @test var isa Num - + @test isequal(At(1)(r), r) @test isequal(At(1)(r[2]), r[2]) From 2085862dcb98b821ca30911ddc6d003e15e4cb20 Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 25 Apr 2025 17:19:05 -0400 Subject: [PATCH 3/7] cleanup: remove @show staements --- src/systems/diffeqs/odesystem.jl | 2 -- src/systems/model_parsing.jl | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index de59f3ae55..61d437a11b 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -279,7 +279,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; constraintsystem = nothing if !isempty(constraints) - @show constraints constraintsystem = process_constraint_system(constraints, dvs, ps, iv) for p in parameters(constraintsystem) !in(p, Set(ps)) && push!(ps, p) @@ -807,7 +806,6 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) throw(ArgumentError("Too many arguments for variable $var.")) elseif length(arguments(var)) == 1 arg = only(arguments(var)) - @show sts operation(var)(iv) ∈ sts || throw(ArgumentError("Variable $var is not a variable of the ODESystem. Called variables must be variables of the ODESystem.")) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 6bea4b2b80..3f37775dc7 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -617,7 +617,6 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, cons, costs, dict, mod, arg, kwargs, where_types) mname = arg.args[1] body = arg.args[end] - @show dict if mname == Symbol("@description") parse_description!(body, dict) elseif mname == Symbol("@components") @@ -1170,7 +1169,6 @@ function parse_constraints!(cons, dict, body) end function parse_costs!(costs, dict, body) - @show dict dict[:costs] = [] Base.remove_linenums!(body) for arg in body.args From 1026042855392012c73cfd0a5dbcb8e5b7e26b80 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 28 Apr 2025 15:35:45 -0400 Subject: [PATCH 4/7] fix: fix constructor and --- src/systems/diffeqs/odesystem.jl | 2 +- src/systems/model_parsing.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 61d437a11b..50655d0074 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -401,7 +401,7 @@ function ODESystem(eqs, iv; kwargs...) algevars = setdiff(allunknowns, diffvars) return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars))), - collect(new_ps); constraintsystem, costs, kwargs...) + collect(new_ps); kwargs...) end # NOTE: equality does not check cached Jacobian diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 3f37775dc7..024c249363 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -641,7 +641,7 @@ function parse_model!(exprs, comps, ext, eqs, icon, vs, ps, sps, c_evts, d_evts, isassigned(icon) && error("This model has more than one icon.") parse_icon!(body, dict, icon, mod) elseif mname == Symbol("@defaults") - parse_system_defaults!(exprs, dict, body) + parse_system_defaults!(exprs, arg, dict) elseif mname == Symbol("@constraints") parse_costs!(cons, dict, body) elseif mname == Symbol("@costs") From a0ab00301fd9f5ced4566cf7d69b6c163c822457 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 28 Apr 2025 16:47:26 -0400 Subject: [PATCH 5/7] test: fix model parsing tests --- src/ModelingToolkit.jl | 1 + src/systems/discrete_system/discrete_system.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index d0f427bd14..bfe7b489b6 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -142,6 +142,7 @@ function var_derivative_graph! end include("bipartite_graph.jl") using .BipartiteGraphs +export At include("variables.jl") include("parameters.jl") include("independent_variables.jl") diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 5f7c986659..075aa27e4d 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -121,7 +121,7 @@ struct DiscreteSystem <: AbstractDiscreteSystem tearing_state = nothing, substitutions = nothing, namespacing = true, complete = false, index_cache = nothing, parent = nothing, isscheduled = false; - checks::Union{Bool, Int} = true) + checks::Union{Bool, Int} = true, kwargs...) if checks == true || (checks & CheckComponents) > 0 check_independent_variables([iv]) check_variables(dvs, iv) @@ -199,7 +199,7 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems, defaults, guesses, initializesystem, initialization_eqs, preface, connector_type, - parameter_dependencies, metadata, gui_metadata, kwargs...) + parameter_dependencies, metadata, gui_metadata) end function DiscreteSystem(eqs, iv; kwargs...) From f50dc2653b91753ecab7812aa0c6b1071c8093a5 Mon Sep 17 00:00:00 2001 From: vyudu Date: Mon, 28 Apr 2025 20:09:34 -0400 Subject: [PATCH 6/7] fix: add validate overload for inequalities --- src/systems/unit_check.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 5035a22b5e..83a5ac5483 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -272,7 +272,7 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3]) end -function validate(eq::Equation; info::String = "") +function validate(eq::Union{Inequality, Equation}; info::String = "") if typeof(eq.lhs) == Connection _validate(eq.rhs; info) else From a30df651129ec974972b4698b3b06092c0f98a7e Mon Sep 17 00:00:00 2001 From: vyudu Date: Fri, 2 May 2025 17:31:13 -0400 Subject: [PATCH 7/7] rename At to EvalAt --- src/ModelingToolkit.jl | 2 +- src/variables.jl | 18 +++++++++--------- test/model_parsing.jl | 8 ++++---- test/variable_utils.jl | 26 +++++++++++++------------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index bfe7b489b6..992690cfe0 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -142,7 +142,7 @@ function var_derivative_graph! end include("bipartite_graph.jl") using .BipartiteGraphs -export At +export EvalAt include("variables.jl") include("parameters.jl") include("independent_variables.jl") diff --git a/src/variables.jl b/src/variables.jl index e5d8961225..83e72cea35 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -616,11 +616,11 @@ getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) ################### ### Evaluate at ### ################### -struct At <: Symbolics.Operator +struct EvalAt <: Symbolics.Operator t::Union{Symbolic, Number} end -function (A::At)(x::Symbolic) +function (A::EvalAt)(x::Symbolic) if symbolic_type(x) == NotSymbolic() || !iscall(x) if x isa Symbolics.CallWithMetadata return x(A.t) @@ -637,18 +637,18 @@ function (A::At)(x::Symbolic) A(x) else length(arguments(x)) !== 1 && - error("Variable $x has too many arguments. At can only be applied to one-argument variables.") + error("Variable $x has too many arguments. EvalAt can only be applied to one-argument variables.") (symbolic_type(only(arguments(x))) !== ScalarSymbolic()) && return x return operation(x)(A.t) end end -function (A::At)(x::Union{Num, Symbolics.Arr}) +function (A::EvalAt)(x::Union{Num, Symbolics.Arr}) wrap(A(unwrap(x))) end -SymbolicUtils.isbinop(::At) = false +SymbolicUtils.isbinop(::EvalAt) = false -Base.nameof(::At) = :At -Base.show(io::IO, A::At) = print(io, "At(", A.t, ")") -Base.:(==)(A1::At, A2::At) = isequal(A1.t, A2.t) -Base.hash(A::At, u::UInt) = hash(A.t, u) +Base.nameof(::EvalAt) = :EvalAt +Base.show(io::IO, A::EvalAt) = print(io, "EvalAt(", A.t, ")") +Base.:(==)(A1::EvalAt, A2::EvalAt) = isequal(A1.t, A2.t) +Base.hash(A::EvalAt, u::UInt) = hash(A.t, u) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 05380269e5..fe2bcbfca6 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -1037,12 +1037,12 @@ end x ~ y end @constraints begin - At(0.3)(x) ~ 3 + EvalAt(0.3)(x) ~ 3 y ≲ 4 end @costs begin x + y - At(1)(y)^2 + EvalAt(1)(y)^2 end @consolidate f(u) = u[1]^2 + log(u[2]) end @@ -1053,8 +1053,8 @@ end costs = ModelingToolkit.get_costs(ex) constrs = ModelingToolkit.get_constraints(ModelingToolkit.get_constraintsystem(ex)) @test isequal(costs[1], ex.x + ex.y) - @test isequal(costs[2], At(1)(ex.y)^2) - @test isequal(constrs[1], -3 + At(0.3)(ex.x) ~ 0) + @test isequal(costs[2], EvalAt(1)(ex.y)^2) + @test isequal(constrs[1], -3 + EvalAt(0.3)(ex.x) ~ 0) @test isequal(constrs[2], -4 + ex.y ≲ 0) @test ModelingToolkit.get_consolidate(ex)([1, 2]) ≈ 1 + log(2) end diff --git a/test/variable_utils.jl b/test/variable_utils.jl index f447cfeb8d..1dc45e11ef 100644 --- a/test/variable_utils.jl +++ b/test/variable_utils.jl @@ -164,23 +164,23 @@ end @variables x(t) v(..) w(t)[1:3] @parameters y z(u, t) r[1:3] - @test At(1)(x) isa Num - @test isequal(At(1)(y), y) - @test_throws ErrorException At(1)(z) - @test isequal(At(1)(v), v(1)) - @test isequal(At(1)(v(t)), v(1)) - @test isequal(At(1)(v(2)), v(2)) + @test EvalAt(1)(x) isa Num + @test isequal(EvalAt(1)(y), y) + @test_throws ErrorException EvalAt(1)(z) + @test isequal(EvalAt(1)(v), v(1)) + @test isequal(EvalAt(1)(v(t)), v(1)) + @test isequal(EvalAt(1)(v(2)), v(2)) - arr = At(1)(w) - var = At(1)(w[1]) + arr = EvalAt(1)(w) + var = EvalAt(1)(w[1]) @test arr isa Symbolics.Arr @test var isa Num - @test isequal(At(1)(r), r) - @test isequal(At(1)(r[2]), r[2]) + @test isequal(EvalAt(1)(r), r) + @test isequal(EvalAt(1)(r[2]), r[2]) _x = ModelingToolkit.unwrap(x) - @test At(1)(_x) isa Symbolics.BasicSymbolic - @test only(arguments(At(1)(_x))) == 1 - @test At(1)(D(x)) isa Num + @test EvalAt(1)(_x) isa Symbolics.BasicSymbolic + @test only(arguments(EvalAt(1)(_x))) == 1 + @test EvalAt(1)(D(x)) isa Num end