diff --git a/Project.toml b/Project.toml index 8e603699ba..2f25e2b3cc 100644 --- a/Project.toml +++ b/Project.toml @@ -148,7 +148,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.106.0" +SciMLBase = "2.108.0" SciMLPublic = "1.0.0" SciMLStructures = "1.7" Serialization = "1" diff --git a/src/clock.jl b/src/clock.jl index 1c9ed89128..df3b6f4b47 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,11 +1,15 @@ @data InferredClock begin Inferred - InferredDiscrete + InferredDiscrete(Int) end const InferredTimeDomain = InferredClock.Type using .InferredClock: Inferred, InferredDiscrete +function InferredClock.InferredDiscrete() + return InferredDiscrete(0) +end + Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x) struct VariableTimeDomain end @@ -50,7 +54,7 @@ has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = ContinuousClock() + @eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),) @eval output_timedomain(::$op, arg = nothing) = ContinuousClock() end @@ -97,6 +101,7 @@ function is_discrete_domain(x) end sampletime(c) = Moshi.Match.@match c begin + x::SciMLBase.AbstractClock => nothing PeriodicClock(dt) => dt _ => nothing end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index da8417de4e..9e57296d9f 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -10,6 +10,20 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl is_transparent_operator(x) = is_transparent_operator(typeof(x)) is_transparent_operator(::Type) = false +""" + $(TYPEDSIGNATURES) + +Trait to be implemented for operators which determines whether the operator is applied to +a time-varying quantity and results in a time-varying quantity. For example, `Initial` and +`Pre` are not time-varying since while they are applied to variables, the application +results in a non-discrete-time parameter. `Differential`, `Shift`, `Sample` and `Hold` are +all time-varying operators. All time-varying operators must implement `input_timedomain` and +`output_timedomain`. +""" +is_timevarying_operator(x) = is_timevarying_operator(typeof(x)) +is_timevarying_operator(::Type{<:Symbolics.Operator}) = true +is_timevarying_operator(::Type) = false + """ function SampleTime() @@ -314,12 +328,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i) input_timedomain(op::Operator) Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on. +Should return a tuple containing the time domain type for each argument to the operator. """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + (InferredDiscrete(),) end """ @@ -334,25 +349,28 @@ function output_timedomain(s::Shift, arg = nothing) InferredDiscrete() end -input_timedomain(::Sample, _ = nothing) = ContinuousClock() +input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),) output_timedomain(s::Sample, _ = nothing) = s.clock function input_timedomain(h::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() # the Hold accepts any discrete + (InferredDiscrete(),) # the Hold accepts any discrete end output_timedomain(::Hold, _ = nothing) = ContinuousClock() sampletime(op::Sample, _ = nothing) = sampletime(op.clock) sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock) -changes_domain(op) = isoperator(op, Union{Sample, Hold}) - function output_timedomain(x) if isoperator(x, Operator) - return output_timedomain(operation(x), arguments(x)[]) + args = arguments(x) + return output_timedomain(operation(x), if length(args) == 1 + args[] + else + args + end) else throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression")) end @@ -360,8 +378,24 @@ end function input_timedomain(x) if isoperator(x, Operator) - return input_timedomain(operation(x), arguments(x)[]) + args = arguments(x) + return input_timedomain(operation(x), if length(args) == 1 + args[] + else + args + end) else throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression")) end end + +function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...) + return SymbolicContinuousCallback( + [expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing; + affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing, + kwargs..., zero_crossing_id = name) +end + +function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback) + return SciMLBase.Clocks.EventClock(cb.zero_crossing_id) +end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ab74a71ffa..3fe04936d3 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -486,13 +486,12 @@ The `Initial` operator. Used by initialization to store constant constraints on of a system. See the documentation section on initialization for more information. """ struct Initial <: Symbolics.Operator end +is_timevarying_operator(::Type{Initial}) = false Initial(x) = Initial()(x) SymbolicUtils.promote_symtype(::Type{Initial}, T) = T SymbolicUtils.isbinop(::Initial) = false Base.nameof(::Initial) = :Initial Base.show(io::IO, x::Initial) = print(io, "Initial") -input_timedomain(::Initial, _ = nothing) = ContinuousClock() -output_timedomain(::Initial, _ = nothing) = ContinuousClock() function (f::Initial)(x) # wrap output if wrapped input @@ -1246,6 +1245,7 @@ function namespace_expr( O end end + _nonum(@nospecialize x) = x isa Num ? x.val : x """ diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 09a259610a..a97ab2d233 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -143,12 +143,11 @@ before the callback is triggered. """ struct Pre <: Symbolics.Operator end Pre(x) = Pre()(x) +is_timevarying_operator(::Type{Pre}) = false SymbolicUtils.promote_symtype(::Type{Pre}, T) = T SymbolicUtils.isbinop(::Pre) = false Base.nameof(::Pre) = :Pre Base.show(io::IO, x::Pre) = print(io, "Pre") -input_timedomain(::Pre, _ = nothing) = ContinuousClock() -output_timedomain(::Pre, _ = nothing) = ContinuousClock() unPre(x::Num) = unPre(unwrap(x)) unPre(x::Symbolics.Arr) = unPre(unwrap(x)) unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x @@ -252,6 +251,7 @@ struct SymbolicContinuousCallback <: AbstractCallback finalize::Union{Affect, SymbolicAffect, Nothing} rootfind::Union{Nothing, SciMLBase.RootfindOpt} reinitializealg::SciMLBase.DAEInitializationAlgorithm + zero_crossing_id::Symbol end function SymbolicContinuousCallback( @@ -262,6 +262,7 @@ function SymbolicContinuousCallback( finalize = nothing, rootfind = SciMLBase.LeftRootFind, reinitializealg = nothing, + zero_crossing_id = gensym(), kwargs...) conditions = (conditions isa AbstractVector) ? conditions : [conditions] @@ -278,7 +279,7 @@ function SymbolicContinuousCallback( SymbolicAffect(affect_neg; kwargs...), SymbolicAffect(initialize; kwargs...), SymbolicAffect( finalize; kwargs...), - rootfind, reinitializealg) + rootfind, reinitializealg, zero_crossing_id) end # Default affect to nothing function SymbolicContinuousCallback(p::Pair, args...; kwargs...) @@ -297,7 +298,7 @@ end function complete(cb::SymbolicContinuousCallback; kwargs...) SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...), make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...), - make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg) + make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg, cb.zero_crossing_id) end make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...) @@ -512,7 +513,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo affect_neg = namespace_affects(affect_negs(cb), s), initialize = namespace_affects(initialize_affects(cb), s), finalize = namespace_affects(finalize_affects(cb), s), - rootfind = cb.rootfind, reinitializealg = cb.reinitializealg) + rootfind = cb.rootfind, reinitializealg = cb.reinitializealg, + zero_crossing_id = cb.zero_crossing_id) end function namespace_conditions(condition, s) @@ -536,6 +538,8 @@ function Base.hash(cb::AbstractCallback, s::UInt) s = hash(finalize_affects(cb), s) !is_discrete(cb) && (s = hash(cb.rootfind, s)) hash(cb.reinitializealg, s) + !is_discrete(cb) && (s = hash(cb.zero_crossing_id, s)) + return s end ########################### @@ -570,13 +574,17 @@ function finalize_affects(cbs::Vector{<:AbstractCallback}) end function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback) - (is_discrete(e1) === is_discrete(e2)) || return false - (isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) && - isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) && - isequal(e1.reinitializealg, e2.reinitializealg) || - return false - is_discrete(e1) || - (isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)) + is_discrete(e1) === is_discrete(e2) || return false + isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false + isequal(e1.initialize, e2.initialize) || return false + isequal(e1.finalize, e2.finalize) || return false + isequal(e1.reinitializealg, e2.reinitializealg) || return false + if !is_discrete(e1) + isequal(e1.affect_neg, e2.affect_neg) || return false + isequal(e1.rootfind, e2.rootfind) || return false + isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false + end + return true end Base.isempty(cb::AbstractCallback) = isempty(cb.conditions) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 97b6be27ab..a88e8c42fe 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -1,10 +1,20 @@ +@data ClockVertex begin + Variable(Int) + Equation(Int) + InitEquation(Int) + Clock(SciMLBase.AbstractClock) +end + struct ClockInference{S} """Tearing state.""" ts::S """The time domain (discrete clock, continuous) of each equation.""" eq_domain::Vector{TimeDomain} + """The time domain (discrete clock, continuous) of each initialization equation.""" + init_eq_domain::Vector{TimeDomain} """The output time domain (discrete clock, continuous) of each variable.""" var_domain::Vector{TimeDomain} + inference_graph::HyperGraph{ClockVertex.Type} """The set of variables with concrete domains.""" inferred::BitSet end @@ -13,6 +23,8 @@ function ClockInference(ts::TransformationState) @unpack structure = ts @unpack graph = structure eq_domain = TimeDomain[ContinuousClock() for _ in 1:nsrcs(graph)] + init_eq_domain = TimeDomain[ContinuousClock() + for _ in 1:length(initialization_equations(ts.sys))] var_domain = TimeDomain[ContinuousClock() for _ in 1:ndsts(graph)] inferred = BitSet() for (i, v) in enumerate(get_fullvars(ts)) @@ -22,7 +34,24 @@ function ClockInference(ts::TransformationState) var_domain[i] = d end end - ClockInference(ts, eq_domain, var_domain, inferred) + inference_graph = HyperGraph{ClockVertex.Type}() + for i in 1:nsrcs(graph) + add_vertex!(inference_graph, ClockVertex.Equation(i)) + end + for i in eachindex(initialization_equations(ts.sys)) + add_vertex!(inference_graph, ClockVertex.InitEquation(i)) + end + for i in 1:ndsts(graph) + varvert = ClockVertex.Variable(i) + add_vertex!(inference_graph, varvert) + v = ts.fullvars[i] + d = get_time_domain(v) + is_concrete_time_domain(d) || continue + dvert = ClockVertex.Clock(d) + add_vertex!(inference_graph, dvert) + add_edge!(inference_graph, (varvert, dvert)) + end + ClockInference(ts, eq_domain, init_eq_domain, var_domain, inference_graph, inferred) end struct NotInferredTimeDomain end @@ -75,47 +104,163 @@ end Update the equation-to-time domain mapping by inferring the time domain from the variables. """ function infer_clocks!(ci::ClockInference) - @unpack ts, eq_domain, var_domain, inferred = ci + @unpack ts, eq_domain, init_eq_domain, var_domain, inferred, inference_graph = ci @unpack var_to_diff, graph = ts.structure fullvars = get_fullvars(ts) isempty(inferred) && return ci - # TODO: add a graph type to do this lazily - var_graph = SimpleGraph(ndsts(graph)) - for eq in 𝑠vertices(graph) - vvs = 𝑠neighbors(graph, eq) - if !isempty(vvs) - fv, vs = Iterators.peel(vvs) - for v in vs - add_edge!(var_graph, fv, v) + + var_to_idx = Dict(fullvars .=> eachindex(fullvars)) + + # all shifted variables have the same clock as the unshifted variant + for (i, v) in enumerate(fullvars) + iscall(v) || continue + operation(v) isa Shift || continue + unshifted = only(arguments(v)) + add_edge!(inference_graph, ( + ClockVertex.Variable(i), ClockVertex.Variable(var_to_idx[unshifted]))) + end + + # preallocated buffers: + # variables in each equation + varsbuf = Set() + # variables in each argument to an operator + arg_varsbuf = Set() + # hyperedge for each equation + hyperedge = Set{ClockVertex.Type}() + # hyperedge for each argument to an operator + arg_hyperedge = Set{ClockVertex.Type}() + # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition + relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}() + + function infer_equation(ieq, eq, is_initialization_equation) + empty!(varsbuf) + empty!(hyperedge) + # get variables in equation + vars!(varsbuf, eq; op = Symbolics.Operator) + # add the equation to the hyperedge + eq_node = if is_initialization_equation + ClockVertex.InitEquation(ieq) + else + ClockVertex.Equation(ieq) + end + push!(hyperedge, eq_node) + for var in varsbuf + idx = get(var_to_idx, var, nothing) + # if this is just a single variable, add it to the hyperedge + if idx isa Int + push!(hyperedge, ClockVertex.Variable(idx)) + # we don't immediately `continue` here because this variable might be a + # `Sample` or similar and we want the clock information from it if it is. + end + # now we only care about synchronous operators + iscall(var) || continue + op = operation(var) + is_timevarying_operator(op) || continue + + # arguments and corresponding time domains + args = arguments(var) + tdomains = input_timedomain(op) + if !(tdomains isa AbstractArray || tdomains isa Tuple) + tdomains = [tdomains] + end + nargs = length(args) + ndoms = length(tdomains) + if nargs != ndoms + throw(ArgumentError(""" + Operator $op applied to $nargs arguments $args but only returns $ndoms \ + domains $tdomains from `input_timedomain`. + """)) + end + + # each relative clock mapping is only valid per operator application + empty!(relative_hyperedges) + for (arg, domain) in zip(args, tdomains) + empty!(arg_varsbuf) + empty!(arg_hyperedge) + # get variables in argument + vars!(arg_varsbuf, arg; op = Union{Differential, Shift}) + # get hyperedge for involved variables + for v in arg_varsbuf + vidx = get(var_to_idx, v, nothing) + vidx === nothing && continue + push!(arg_hyperedge, ClockVertex.Variable(vidx)) + end + + Moshi.Match.@match domain begin + # If the time domain for this argument is a clock, then all variables in this edge have that clock. + x::SciMLBase.AbstractClock => begin + # add the clock to the edge + push!(arg_hyperedge, ClockVertex.Clock(x)) + # add the edge to the graph + add_edge!(inference_graph, arg_hyperedge) + end + # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the + # involved variables have the same clock. + InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge) + # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't + # add the edge, and instead add this to the `relative_hyperedges` mapping. + InferredClock.InferredDiscrete(i) => begin + relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i) + union!(relative_edge, arg_hyperedge) + end + end + end + + outdomain = output_timedomain(op) + Moshi.Match.@match outdomain begin + x::SciMLBase.AbstractClock => begin + push!(hyperedge, ClockVertex.Clock(x)) + end + InferredClock.Inferred() => nothing + InferredClock.InferredDiscrete(i) => begin + buffer = get(relative_hyperedges, i, nothing) + if buffer !== nothing + union!(hyperedge, buffer) + delete!(relative_hyperedges, i) + end + end + end + + for (_, relative_edge) in relative_hyperedges + add_edge!(inference_graph, relative_edge) end end + + add_edge!(inference_graph, hyperedge) end - for v in vertices(var_to_diff) - if (v′ = var_to_diff[v]) !== nothing - add_edge!(var_graph, v, v′) - end + for (ieq, eq) in enumerate(equations(ts)) + infer_equation(ieq, eq, false) + end + for (ieq, eq) in enumerate(initialization_equations(ts.sys)) + infer_equation(ieq, eq, true) end - cc = connected_components(var_graph) - for c′ in cc - c = BitSet(c′) - idxs = intersect(c, inferred) - isempty(idxs) && continue - if !allequal(var_domain[i] for i in idxs) - display(fullvars[c′]) - throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) + + clock_partitions = connectionsets(inference_graph) + for partition in clock_partitions + clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition) + if isempty(clockidxs) + push!(partition, ClockVertex.Clock(ContinuousClock())) + push!(clockidxs, length(partition)) end - vd = var_domain[first(idxs)] - for v in c′ - var_domain[v] = vd + if length(clockidxs) > 1 + vidxs = Int[vert.:1 + for vert in partition + if Moshi.Data.isa_variant(vert, ClockVertex.Variable)] + clks = [vert.:1 for vert in view(partition, clockidxs)] + throw(ArgumentError(""" + Found clock partition with multiple associated clocks. Involved variables: \ + $(fullvars[vidxs]). Involved clocks: $(clks). + """)) end - end - for v in 𝑑vertices(graph) - vd = var_domain[v] - eqs = 𝑑neighbors(graph, v) - isempty(eqs) && continue - for eq in eqs - eq_domain[eq] = vd + clock = partition[only(clockidxs)].:1 + for vert in partition + Moshi.Match.@match vert begin + ClockVertex.Variable(i) => (var_domain[i] = clock) + ClockVertex.Equation(i) => (eq_domain[i] = clock) + ClockVertex.InitEquation(i) => (init_eq_domain[i] = clock) + ClockVertex.Clock(_) => nothing + end end end @@ -135,15 +280,21 @@ function resize_or_push!(v, val, idx) end function is_time_domain_conversion(v) - iscall(v) && (o = operation(v)) isa Operator && - input_timedomain(o) != output_timedomain(o) + iscall(v) || return false + o = operation(v) + o isa Operator || return false + itd = input_timedomain(o) + allequal(itd) || return true + otd = output_timedomain(o) + itd[1] == otd || return true + return false end """ For multi-clock systems, create a separate system for each clock in the system, along with associated equations. Return the updated tearing state, and the sets of clocked variables associated with each time domain. """ function split_system(ci::ClockInference{S}) where {S} - @unpack ts, eq_domain, var_domain, inferred = ci + @unpack ts, eq_domain, init_eq_domain, var_domain, inferred = ci fullvars = get_fullvars(ts) @unpack graph = ts.structure continuous_id = Ref(0) @@ -151,10 +302,15 @@ function split_system(ci::ClockInference{S}) where {S} id_to_clock = TimeDomain[] eq_to_cid = Vector{Int}(undef, nsrcs(graph)) cid_to_eq = Vector{Int}[] + init_eq_to_cid = Vector{Int}(undef, length(initialization_equations(ts.sys))) + cid_to_init_eq = Vector{Int}[] var_to_cid = Vector{Int}(undef, ndsts(graph)) cid_to_var = Vector{Int}[] # cid_counter = number of clocks cid_counter = Ref(0) + + # populates clock_to_id and id_to_clock + # checks if there is a continuous_id (for some reason? clock to id does this too) for (i, d) in enumerate(eq_domain) cid = let cid_counter = cid_counter, id_to_clock = id_to_clock, continuous_id = continuous_id @@ -173,10 +329,23 @@ function split_system(ci::ClockInference{S}) where {S} eq_to_cid[i] = cid resize_or_push!(cid_to_eq, i, cid) end + # NOTE: This assumes that there is at least one equation for each clock + for _ in 1:length(cid_to_eq) + push!(cid_to_init_eq, Int[]) + end + for (i, d) in enumerate(init_eq_domain) + cid = clock_to_id[d] + init_eq_to_cid[i] = cid + push!(cid_to_init_eq[cid], i) + end continuous_id = continuous_id[] + # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) inputs = map(_ -> Any[], 1:cid_counter[]) + # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) + # put variables into the right clock partition + # keep track of inputs to each partition for i in 1:nvv d = var_domain[i] cid = get(clock_to_id, d, 0) @@ -190,15 +359,17 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_var, i, cid) end + # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) - for (id, ieqs) in enumerate(cid_to_eq) - ts_i = system_subset(ts, ieqs) + for (id, (ieqs, iieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_init_eq, cid_to_var)) + ts_i = system_subset(ts, ieqs, iieqs, ivars) if id != continuous_id ts_i = shift_discrete_system(ts_i) @set! ts_i.structure.only_discrete = true end tss[id] = ts_i end + # put the continous system at the back if continuous_id != 0 tss[continuous_id], tss[end] = tss[end], tss[continuous_id] inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl index 5c5e8716c6..99110e37e9 100644 --- a/src/systems/connectiongraph.jl +++ b/src/systems/connectiongraph.jl @@ -119,15 +119,15 @@ connection sets. $(TYPEDFIELDS) """ -struct ConnectionGraph +struct HyperGraph{V} """ Mapping from vertices to their integer ID. """ - labels::Dict{ConnectionVertex, Int} + labels::Dict{V, Int} """ Reverse mapping from integer ID to vertices. """ - invmap::Vector{ConnectionVertex} + invmap::Vector{V} """ Core data structure for storing the hypergraph. Each hyperedge is a source vertex and has bipartite edges to the connection vertices it is incident on. @@ -135,14 +135,16 @@ struct ConnectionGraph graph::BipartiteGraph{Int, Nothing} end +const ConnectionGraph = HyperGraph{ConnectionVertex} + """ $(TYPEDSIGNATURES) Create an empty `ConnectionGraph`. """ -function ConnectionGraph() +function HyperGraph{V}() where {V} graph = BipartiteGraph(0, 0, Val(true)) - return ConnectionGraph(Dict{ConnectionVertex, Int}(), ConnectionVertex[], graph) + return HyperGraph{V}(Dict{V, Int}(), V[], graph) end function Base.show(io::IO, graph::ConnectionGraph) @@ -178,7 +180,7 @@ end Add the given vertex to the connection graph. Return the integer ID of the added vertex. No-op if the vertex already exists. """ -function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) +function Graphs.add_vertex!(graph::HyperGraph{V}, dst::V) where {V} j = get(graph.labels, dst, 0) iszero(j) || return j j = Graphs.add_vertex!(graph.graph, DST) @@ -188,7 +190,8 @@ function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) return j end -const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{ConnectionVertex}}} +const HyperGraphEdge{V} = Union{Vector{V}, Tuple{Vararg{V}}, Set{V}} +const ConnectionGraphEdge = HyperGraphEdge{ConnectionVertex} """ $(TYPEDSIGNATURES) @@ -196,7 +199,7 @@ const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{Connect Add the given hyperedge to the connection graph. Adds all vertices in the given edge if they do not exist. Returns the integer ID of the added edge. """ -function Graphs.add_edge!(graph::ConnectionGraph, src::ConnectionGraphEdge) +function Graphs.add_edge!(graph::HyperGraph{V}, src::HyperGraphEdge{V}) where {V} i = Graphs.add_vertex!(graph.graph, SRC) for vert in src j = Graphs.add_vertex!(graph, vert) @@ -447,7 +450,7 @@ end Return the merged connection sets in `graph` as a `Vector{Vector{ConnectionVertex}}`. These are equivalent to the connected components of `graph`. """ -function connectionsets(graph::ConnectionGraph) +function connectionsets(graph::HyperGraph{V}) where {V} bigraph = graph.graph invmap = graph.invmap @@ -465,11 +468,11 @@ function connectionsets(graph::ConnectionGraph) # maps the root of a vertex in `disjoint_sets` to the index of the corresponding set # in `vertex_sets` root_to_set = Dict{Int, Int}() - vertex_sets = Vector{ConnectionVertex}[] + vertex_sets = Vector{V}[] for (vert_i, vert) in enumerate(invmap) root = find_root!(disjoint_sets, vert_i) set_i = get!(root_to_set, root) do - push!(vertex_sets, ConnectionVertex[]) + push!(vertex_sets, V[]) return length(vertex_sets) end push!(vertex_sets[set_i], vert) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 7b1a9fb286..f3d45e258a 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -262,7 +262,9 @@ function compile_functional_affect( upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + end reset_jumps && reset_aggregated_jumps!(integ) end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index b57412bf2e..19c78413cf 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -1,5 +1,5 @@ struct BufferTemplate - type::Union{DataType, UnionAll} + type::Union{DataType, UnionAll, Union} length::Int end diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 7a9bc2617e..3fda3c450e 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1296,6 +1296,8 @@ function get_p_constructor(p_constructor, pType::Type, floatT::Type) end end +abstract type ProblemConstructionHook end + """ $(TYPEDSIGNATURES) @@ -1348,6 +1350,8 @@ function process_SciMLProblem( check_inputmap_keys(sys, op) + op = getmetadata(sys, ProblemConstructionHook, identity)(op) + defs = add_toterms(recursive_unwrap(defaults(sys)); replace = is_discrete_system(sys)) kwargs = NamedTuple(kwargs) diff --git a/src/systems/state_machines.jl b/src/systems/state_machines.jl index 347f92e6f8..ea65981804 100644 --- a/src/systems/state_machines.jl +++ b/src/systems/state_machines.jl @@ -153,3 +153,36 @@ entry When used in a finite state machine, this operator returns `true` if the queried state is active and false otherwise. """ activeState + +function vars!(vars, O::Transition; op = Differential) + vars!(vars, O.from) + vars!(vars, O.to) + vars!(vars, O.cond; op) + return vars +end +function vars!(vars, O::InitialState; op = Differential) + vars!(vars, O.s; op) + return vars +end +function vars!(vars, O::StateMachineOperator; op = Differential) + error("Unhandled state machine operator") +end + +function namespace_expr( + O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys)) + return Transition( + O.from === nothing ? O.from : renamespace(sys, O.from), + O.to === nothing ? O.to : renamespace(sys, O.to), + O.cond === nothing ? O.cond : namespace_expr(O.cond, sys), + O.immediate, O.reset, O.synchronize, O.priority + ) +end + +function namespace_expr( + O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys)) + return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s)) +end + +function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...) + error("Unhandled state machine operator") +end diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 891714dce6..4c52300239 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -36,7 +36,7 @@ function mtkcompile( isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) newsys′ = __mtkcompile(sys; simplify, allow_symbolic, allow_parameter, conservative, fully_determined, - inputs, outputs, disturbance_inputs, + inputs, outputs, disturbance_inputs, additional_passes, kwargs...) if newsys′ isa Tuple @assert length(newsys′) == 2 @@ -75,12 +75,13 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify) end + sys, statemachines = extract_top_level_statemachines(sys) sys = expand_connections(sys) - state = TearingState(sys; sort_eqs) + state = TearingState(sys) + append!(state.statemachines, statemachines) @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure - eqs = equations(state) brown_vars = Int[] new_idxs = zeros(Int, length(var_types)) idx = 0 @@ -98,7 +99,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, Is = Int[] Js = Int[] vals = Num[] - new_eqs = copy(eqs) + make_eqs_zero_equals!(state) + new_eqs = copy(equations(state)) dvar2eq = Dict{Any, Int}() for (v, dv) in enumerate(var_to_diff) dv === nothing && continue @@ -293,3 +295,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative return mapping end + +""" +Mark whether an extra pass `p` can support compiling discrete systems. +""" +discrete_compile_pass(p) = false diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 4fdb96c789..2841acf711 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -215,30 +215,62 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} are not used in the rest of the system. """ additional_observed::Vector{Equation} + statemachines::Vector{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) -function system_subset(ts::TearingState, ieqs::Vector{Int}) +function system_subset(ts::TearingState, ieqs::Vector{Int}, iieqs::Vector{Int}, ivars::Vector{Int}) eqs = equations(ts) + initeqs = initialization_equations(ts.sys) @set! ts.sys.eqs = eqs[ieqs] + @set! ts.sys.initialization_eqs = initeqs[iieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] - @set! ts.structure = system_subset(ts.structure, ieqs) + @set! ts.structure = system_subset(ts.structure, ieqs, ivars) + if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) + names = Symbol[] + for eq in get_eqs(ts.sys) + if eq.lhs isa Transition + push!(names, first(namespace_hierarchy(nameof(eq.rhs.from)))) + push!(names, first(namespace_hierarchy(nameof(eq.rhs.to)))) + elseif eq.lhs isa InitialState + push!(names, first(namespace_hierarchy(nameof(eq.rhs.s)))) + else + error("Unhandled state machine operator") + end + end + @set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines) + else + @set! ts.statemachines = eltype(ts.statemachines)[] + end + @set! ts.fullvars = ts.fullvars[ivars] ts end -function system_subset(structure::SystemStructure, ieqs::Vector{Int}) - @unpack graph, eq_to_diff = structure +function system_subset(structure::SystemStructure, ieqs::Vector{Int}, ivars::Vector{Int}) + @unpack graph = structure fadj = Vector{Int}[] eq_to_diff = DiffGraph(length(ieqs)) + var_to_diff = DiffGraph(length(ivars)) + ne = 0 + old_to_new_var = zeros(Int, ndsts(graph)) + for (i, iv) in enumerate(ivars) + old_to_new_var[iv] = i + end + for (i, iv) in enumerate(ivars) + structure.var_to_diff[iv] === nothing && continue + var_to_diff[i] = old_to_new_var[structure.var_to_diff[iv]] + end for (j, eq_i) in enumerate(ieqs) - ivars = copy(graph.fadjlist[eq_i]) - ne += length(ivars) - push!(fadj, ivars) + var_adj = [old_to_new_var[i] for i in graph.fadjlist[eq_i]] + @assert all(!iszero, var_adj) + ne += length(var_adj) + push!(fadj, var_adj) eq_to_diff[j] = structure.eq_to_diff[eq_i] end - @set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph))) + @set! structure.graph = complete(BipartiteGraph(ne, fadj, length(ivars))) @set! structure.eq_to_diff = eq_to_diff + @set! structure.var_to_diff = complete(var_to_diff) structure end @@ -276,6 +308,49 @@ function symbolic_contains(var, set) all(x -> x in set, Symbolics.scalarize(var)) end +""" + $(TYPEDSIGNATURES) + +Descend through the system hierarchy and look for statemachines. Remove equations from +the inner statemachine systems. Return the new `sys` and an array of top-level +statemachines. +""" +function extract_top_level_statemachines(sys::AbstractSystem) + eqs = get_eqs(sys) + + if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + # top-level statemachine + with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) + return with_removed, [sys] + elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + # error: can't mix + error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") + else + # descend + subsystems = get_systems(sys) + newsubsystems = eltype(subsystems)[] + statemachines = eltype(subsystems)[] + for subsys in subsystems + newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) + push!(newsubsystems, newsubsys) + append!(statemachines, sub_statemachines) + end + @set! sys.systems = newsubsystems + return sys, statemachines + end +end + +""" + $(TYPEDSIGNATURES) + +Return `sys` with all equations (including those in subsystems) removed. +""" +function remove_child_equations(sys::AbstractSystem) + @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.systems = map(remove_child_equations, get_systems(sys)) + return sys +end + function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # flatten system sys = flatten(sys) @@ -341,9 +416,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # change the equation if the RHS is `missing` so the rest of this loop works eq = 0.0 ~ coalesce(eq.rhs, 0.0) end - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - if !_iszero(eq.lhs) + is_statemachine_equation = false + if eq.lhs isa StateMachineOperator + is_statemachine_equation = true + eq = eq + rhs = eq.rhs + elseif _iszero(eq.lhs) + rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs + else lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs + rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs eq = 0 ~ rhs - lhs end empty!(varsbuf) @@ -390,8 +472,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) addvar!(v, VARIABLE) if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && (it = input_timedomain(v)) !== nothing - v′ = only(arguments(v)) - addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE) + for v′ in arguments(v) + addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE) + end end end @@ -408,8 +491,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) addvar!(v, VARIABLE) end end - - if isalgeq + if isalgeq || is_statemachine_equation eqs[i] = eq else eqs[i] = eqs[i].lhs ~ rhs @@ -530,8 +612,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, false), - Any[], param_derivative_map, original_eqs, Equation[]) - + Any[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) return ts end @@ -813,29 +894,78 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) printstyled(io, " SelectedState") end +function make_eqs_zero_equals!(ts::TearingState) + neweqs = map(enumerate(get_eqs(ts.sys))) do kvp + i, eq = kvp + isalgeq = true + for j in 𝑠neighbors(ts.structure.graph, i) + isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing + end + if isalgeq + return 0 ~ eq.rhs - eq.lhs + else + return eq + end + end + copyto!(get_eqs(ts.sys), neweqs) +end + function mtkcompile!(state::TearingState; simplify = false, check_consistency = true, fully_determined = true, warn_initialize_determined = true, inputs = Any[], outputs = Any[], disturbance_inputs = Any[], kwargs...) + if !is_time_dependent(state.sys) + return _mtkcompile!(state; simplify, check_consistency, + inputs, outputs, disturbance_inputs, + fully_determined, kwargs...) + end + # split_system returns one or two systems and the inputs for each + # mod clock inference to be binary + # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) time_domains = merge(Dict(state.fullvars .=> ci.var_domain), Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) + if continuous_id == 0 + # do a trait check here - handle fully discrete system + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + # take the first discrete compilation pass given for now + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + return discrete_compile(tss, clocked_inputs, ci) + end + throw(HybridSystemNotSupportedException(""" + Discrete systems with multiple clocks are not supported with the standard \ + MTK compiler. + """)) + end if length(tss) > 1 - if continuous_id == 0 - throw(HybridSystemNotSupportedException(""" - Discrete systems with multiple clocks are not supported with the standard \ - MTK compiler. - """)) - else - throw(HybridSystemNotSupportedException(""" - Hybrid continuous-discrete systems are currently not supported with \ - the standard MTK compiler. This system requires JuliaSimCompiler.jl, \ - see https://help.juliahub.com/juliasimcompiler/stable/ - """)) + make_eqs_zero_equals!(tss[continuous_id]) + # simplify as normal + sys = _mtkcompile!(tss[continuous_id]; simplify, + inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, + check_consistency, fully_determined, + kwargs...) + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems + # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed + return discrete_compile( + sys, tss[[i for i in eachindex(tss) if i != continuous_id]], + clocked_inputs, ci, id_to_clock) end + throw(HybridSystemNotSupportedException(""" + Hybrid continuous-discrete systems are currently not supported with \ + the standard MTK compiler. This system requires JuliaSimCompiler.jl, \ + see https://help.juliahub.com/juliasimcompiler/stable/ + """)) end if get_is_discrete(state.sys) || continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars) diff --git a/src/utils.jl b/src/utils.jl index e96f31f533..d028d4ed18 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -391,6 +391,12 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end +function vars!(vars, O::AbstractSystem; op = Differential) + for eq in equations(O) + vars!(vars, eq; op) + end + return vars +end function vars!(vars, O; op = Differential) if isvariable(O) if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O))) diff --git a/test/clock.jl b/test/clock.jl index 446b0ffb0d..bb16884fe7 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -118,6 +118,34 @@ eqs = [yd ~ Sample(dt)(y) @named sys = System(eqs, t) @test_throws ModelingToolkit.HybridSystemNotSupportedException ss=mtkcompile(sys); +@testset "Clock inference uses and splits initialization equations" begin + @variables x(t) y(t) z(t) + k = ShiftIndex() + clk = Clock(0.1) + eqs = [D(x) ~ x, y ~ Sample(clk)(x), z ~ z(k-1) + 1] + initialization_eqs = [y ~ z, x ~ 1] + @named sys = System(eqs, t; initialization_eqs) + ts = TearingState(sys) + ci = ModelingToolkit.ClockInference(ts) + @test length(ci.init_eq_domain) == 2 + ModelingToolkit.infer_clocks!(ci) + canonical_eqs = map(eqs) do eq + if iscall(eq.lhs) && operation(eq.lhs) isa Differential + return eq + else + return 0 ~ eq.rhs - eq.lhs + end + end + eqs_idxs = findfirst.(isequal.(canonical_eqs), (equations(ci.ts),)) + @test ci.eq_domain[eqs_idxs[1]] == ContinuousClock() + @test ci.eq_domain[eqs_idxs[2]] == clk + @test ci.eq_domain[eqs_idxs[3]] == clk + varmap = Dict(ci.ts.fullvars .=> ci.var_domain) + @test varmap[x] == ContinuousClock() + @test varmap[y] == clk + @test varmap[z] == clk +end + @test_skip begin Tf = 1.0 prob = ODEProblem( diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 54ac59bdf4..cf90ccd283 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -206,18 +206,20 @@ end @testset "Condition Compilation" begin @named sys = System(eqs, t, continuous_events = [x ~ 1]) + cevt1 = getfield(sys, :continuous_events)[] @test getfield(sys, :continuous_events)[] == - SymbolicContinuousCallback(Equation[x ~ 1], nothing) + SymbolicContinuousCallback(Equation[x ~ 1], nothing; zero_crossing_id = cevt1.zero_crossing_id) @test isequal(equations(getfield(sys, :continuous_events))[], x ~ 1) fsys = flatten(sys) @test isequal(equations(getfield(fsys, :continuous_events))[], x ~ 1) @named sys2 = System([D(x) ~ 1], t, continuous_events = [x ~ 2], systems = [sys]) + cevt2 = getfield(sys2, :continuous_events)[] @test getfield(sys2, :continuous_events)[] == - SymbolicContinuousCallback(Equation[x ~ 2], nothing) + SymbolicContinuousCallback(Equation[x ~ 2], nothing; zero_crossing_id = cevt2.zero_crossing_id) @test all(ModelingToolkit.continuous_events(sys2) .== [ - SymbolicContinuousCallback(Equation[x ~ 2], nothing), - SymbolicContinuousCallback(Equation[sys.x ~ 1], nothing) + SymbolicContinuousCallback(Equation[x ~ 2], nothing; zero_crossing_id = cevt2.zero_crossing_id), + SymbolicContinuousCallback(Equation[sys.x ~ 1], nothing; zero_crossing_id = cevt1.zero_crossing_id) ]) @test isequal(equations(getfield(sys2, :continuous_events))[1], x ~ 2)