diff --git a/src/clock.jl b/src/clock.jl index 1c9ed89128..a230c54210 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,11 +1,15 @@ @data InferredClock begin Inferred - InferredDiscrete + InferredDiscrete(Int) end const InferredTimeDomain = InferredClock.Type using .InferredClock: Inferred, InferredDiscrete +function InferredClock.InferredDiscrete() + return InferredDiscrete(0) +end + Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x) struct VariableTimeDomain end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 370d93d894..e33cf1e429 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -10,6 +10,15 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl is_transparent_operator(x) = is_transparent_operator(typeof(x)) is_transparent_operator(::Type) = false +""" + $(TYPEDSIGNATURES) + +Trait to be implemented for operators which determines whether they are synchronous operators. +Synchronous operators must implement `input_timedomain` and `output_timedomain`. +""" +is_synchronous_operator(x) = is_synchronous_operator(typeof(x)) +is_synchronous_operator(::Type) = false + """ function SampleTime() @@ -52,6 +61,7 @@ struct Shift <: Operator end Shift(steps::Int) = new(nothing, steps) normalize_to_differential(s::Shift) = Differential(s.t)^s.steps +is_synchronous_operator(::Type{Shift}) = true Base.nameof(::Shift) = :Shift SymbolicUtils.isbinop(::Shift) = false @@ -138,6 +148,7 @@ struct Sample <: Operator Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete()) = new(clock) end +is_synchronous_operator(::Type{Sample}) = true is_transparent_operator(::Type{Sample}) = true function Sample(arg::Real) @@ -193,6 +204,7 @@ struct Hold <: Operator end is_transparent_operator(::Type{Hold}) = true +is_synchronous_operator(::Type{Hold}) = true (D::Hold)(x) = Term{symtype(x)}(D, Any[x]) (D::Hold)(x::Num) = Num(D(value(x))) @@ -314,12 +326,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i) input_timedomain(op::Operator) Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on. +Should return a tuple containing the time domain type for each argument to the operator. """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + (InferredDiscrete(),) end """ @@ -334,22 +347,20 @@ function output_timedomain(s::Shift, arg = nothing) InferredDiscrete() end -input_timedomain(::Sample, _ = nothing) = ContinuousClock() +input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),) output_timedomain(s::Sample, _ = nothing) = s.clock function input_timedomain(h::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() # the Hold accepts any discrete + (InferredDiscrete(),) # the Hold accepts any discrete end output_timedomain(::Hold, _ = nothing) = ContinuousClock() sampletime(op::Sample, _ = nothing) = sampletime(op.clock) sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock) -changes_domain(op) = isoperator(op, Union{Sample, Hold}) - function output_timedomain(x) if isoperator(x, Operator) return output_timedomain(operation(x), arguments(x)[]) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index ff2d77f19b..24ac8c11ee 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -1,3 +1,9 @@ +@data ClockVertex begin + Variable(Int) + Equation(Int) + Clock(SciMLBase.AbstractClock) +end + struct ClockInference{S} """Tearing state.""" ts::S @@ -5,6 +11,7 @@ struct ClockInference{S} eq_domain::Vector{TimeDomain} """The output time domain (discrete clock, continuous) of each variable.""" var_domain::Vector{TimeDomain} + inference_graph::HyperGraph{ClockVertex.Type} """The set of variables with concrete domains.""" inferred::BitSet end @@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState) var_domain[i] = d end end - ClockInference(ts, eq_domain, var_domain, inferred) + inference_graph = HyperGraph{ClockVertex.Type}() + for i in 1:nsrcs(graph) + add_vertex!(inference_graph, ClockVertex.Equation(i)) + end + for i in 1:ndsts(graph) + varvert = ClockVertex.Variable(i) + add_vertex!(inference_graph, varvert) + v = ts.fullvars[i] + d = get_time_domain(v) + is_concrete_time_domain(d) || continue + dvert = ClockVertex.Clock(d) + add_vertex!(inference_graph, dvert) + add_edge!(inference_graph, (varvert, dvert)) + end + ClockInference(ts, eq_domain, var_domain, inference_graph, inferred) end struct NotInferredTimeDomain end @@ -75,47 +96,147 @@ end Update the equation-to-time domain mapping by inferring the time domain from the variables. """ function infer_clocks!(ci::ClockInference) - @unpack ts, eq_domain, var_domain, inferred = ci + @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci @unpack var_to_diff, graph = ts.structure fullvars = get_fullvars(ts) isempty(inferred) && return ci - # TODO: add a graph type to do this lazily - var_graph = SimpleGraph(ndsts(graph)) - for eq in 𝑠vertices(graph) - vvs = 𝑠neighbors(graph, eq) - if !isempty(vvs) - fv, vs = Iterators.peel(vvs) - for v in vs - add_edge!(var_graph, fv, v) - end - end + + var_to_idx = Dict(fullvars .=> eachindex(fullvars)) + + # all shifted variables have the same clock as the unshifted variant + for (i, v) in enumerate(fullvars) + iscall(v) || continue + operation(v) isa Shift || continue + unshifted = only(arguments(v)) + add_edge!(inference_graph, (ClockVertex.Variable(i), ClockVertex.Variable(var_to_idx[unshifted]))) end - for v in vertices(var_to_diff) - if (v′ = var_to_diff[v]) !== nothing - add_edge!(var_graph, v, v′) + + # preallocated buffers: + # variables in each equation + varsbuf = Set() + # variables in each argument to an operator + arg_varsbuf = Set() + # hyperedge for each equation + hyperedge = Set{ClockVertex.Type}() + # hyperedge for each argument to an operator + arg_hyperedge = Set{ClockVertex.Type}() + # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition + relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}() + + for (ieq, eq) in enumerate(equations(ts)) + empty!(varsbuf) + empty!(hyperedge) + # get variables in equation + vars!(varsbuf, eq; op = Symbolics.Operator) + # add the equation to the hyperedge + push!(hyperedge, ClockVertex.Equation(ieq)) + for var in varsbuf + idx = get(var_to_idx, var, nothing) + # if this is just a single variable, add it to the hyperedge + if idx isa Int + push!(hyperedge, ClockVertex.Variable(idx)) + # we don't immediately `continue` here because this variable might be a + # `Sample` or similar and we want the clock information from it if it is. + end + # now we only care about synchronous operators + iscall(var) || continue + op = operation(var) + is_synchronous_operator(op) || continue + + # arguments and corresponding time domains + args = arguments(var) + tdomains = input_timedomain(op) + nargs = length(args) + ndoms = length(tdomains) + if nargs != ndoms + throw(ArgumentError(""" + Operator $op applied to $nargs arguments $args but only returns $ndoms \ + domains $tdomains from `input_timedomain`. + """)) + end + + # each relative clock mapping is only valid per operator application + empty!(relative_hyperedges) + for (arg, domain) in zip(args, tdomains) + empty!(arg_varsbuf) + empty!(arg_hyperedge) + # get variables in argument + vars!(arg_varsbuf, arg; op = Union{Differential, Shift}) + # get hyperedge for involved variables + for v in arg_varsbuf + vidx = get(var_to_idx, v, nothing) + vidx === nothing && continue + push!(arg_hyperedge, ClockVertex.Variable(vidx)) + end + + Moshi.Match.@match domain begin + # If the time domain for this argument is a clock, then all variables in this edge have that clock. + x::SciMLBase.AbstractClock => begin + # add the clock to the edge + push!(arg_hyperedge, ClockVertex.Clock(x)) + # add the edge to the graph + add_edge!(inference_graph, arg_hyperedge) + end + # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the + # involved variables have the same clock. + InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge) + # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't + # add the edge, and instead add this to the `relative_hyperedges` mapping. + InferredClock.InferredDiscrete(i) => begin + relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i) + union!(relative_edge, arg_hyperedge) + end + end + end + + outdomain = output_timedomain(op) + Moshi.Match.@match outdomain begin + x::SciMLBase.AbstractClock => begin + push!(hyperedge, ClockVertex.Clock(x)) + end + InferredClock.Inferred() => nothing + InferredClock.InferredDiscrete(i) => begin + buffer = get(relative_hyperedges, i, nothing) + if buffer !== nothing + union!(hyperedge, buffer) + delete!(relative_hyperedges, i) + end + end + end + + for (_, relative_edge) in relative_hyperedges + add_edge!(inference_graph, relative_edge) + end end + + add_edge!(inference_graph, hyperedge) end - cc = connected_components(var_graph) - for c′ in cc - c = BitSet(c′) - idxs = intersect(c, inferred) - isempty(idxs) && continue - if !allequal(iscontinuous(var_domain[i]) for i in idxs) - display(fullvars[c′]) - throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) + + clock_partitions = connectionsets(inference_graph) + for partition in clock_partitions + clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition) + if isempty(clockidxs) + vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)] + throw(ArgumentError(""" + Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]). + """)) end - vd = var_domain[first(idxs)] - for v in c′ - var_domain[v] = vd + if length(clockidxs) > 1 + vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)] + clks = [vert.:1 for vert in view(partition, clockidxs)] + throw(ArgumentError(""" + Found clock partition with multiple associated clocks. Involved variables: \ + $(fullvars[vidxs]). Involved clocks: $(clks). + """)) end - end - for v in 𝑑vertices(graph) - vd = var_domain[v] - eqs = 𝑑neighbors(graph, v) - isempty(eqs) && continue - for eq in eqs - eq_domain[eq] = vd + clock = partition[only(clockidxs)].:1 + for vert in partition + Moshi.Match.@match vert begin + ClockVertex.Variable(i) => (var_domain[i] = clock) + ClockVertex.Equation(i) => (eq_domain[i] = clock) + ClockVertex.Clock(_) => nothing + end end end diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl index 5c5e8716c6..99110e37e9 100644 --- a/src/systems/connectiongraph.jl +++ b/src/systems/connectiongraph.jl @@ -119,15 +119,15 @@ connection sets. $(TYPEDFIELDS) """ -struct ConnectionGraph +struct HyperGraph{V} """ Mapping from vertices to their integer ID. """ - labels::Dict{ConnectionVertex, Int} + labels::Dict{V, Int} """ Reverse mapping from integer ID to vertices. """ - invmap::Vector{ConnectionVertex} + invmap::Vector{V} """ Core data structure for storing the hypergraph. Each hyperedge is a source vertex and has bipartite edges to the connection vertices it is incident on. @@ -135,14 +135,16 @@ struct ConnectionGraph graph::BipartiteGraph{Int, Nothing} end +const ConnectionGraph = HyperGraph{ConnectionVertex} + """ $(TYPEDSIGNATURES) Create an empty `ConnectionGraph`. """ -function ConnectionGraph() +function HyperGraph{V}() where {V} graph = BipartiteGraph(0, 0, Val(true)) - return ConnectionGraph(Dict{ConnectionVertex, Int}(), ConnectionVertex[], graph) + return HyperGraph{V}(Dict{V, Int}(), V[], graph) end function Base.show(io::IO, graph::ConnectionGraph) @@ -178,7 +180,7 @@ end Add the given vertex to the connection graph. Return the integer ID of the added vertex. No-op if the vertex already exists. """ -function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) +function Graphs.add_vertex!(graph::HyperGraph{V}, dst::V) where {V} j = get(graph.labels, dst, 0) iszero(j) || return j j = Graphs.add_vertex!(graph.graph, DST) @@ -188,7 +190,8 @@ function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) return j end -const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{ConnectionVertex}}} +const HyperGraphEdge{V} = Union{Vector{V}, Tuple{Vararg{V}}, Set{V}} +const ConnectionGraphEdge = HyperGraphEdge{ConnectionVertex} """ $(TYPEDSIGNATURES) @@ -196,7 +199,7 @@ const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{Connect Add the given hyperedge to the connection graph. Adds all vertices in the given edge if they do not exist. Returns the integer ID of the added edge. """ -function Graphs.add_edge!(graph::ConnectionGraph, src::ConnectionGraphEdge) +function Graphs.add_edge!(graph::HyperGraph{V}, src::HyperGraphEdge{V}) where {V} i = Graphs.add_vertex!(graph.graph, SRC) for vert in src j = Graphs.add_vertex!(graph, vert) @@ -447,7 +450,7 @@ end Return the merged connection sets in `graph` as a `Vector{Vector{ConnectionVertex}}`. These are equivalent to the connected components of `graph`. """ -function connectionsets(graph::ConnectionGraph) +function connectionsets(graph::HyperGraph{V}) where {V} bigraph = graph.graph invmap = graph.invmap @@ -465,11 +468,11 @@ function connectionsets(graph::ConnectionGraph) # maps the root of a vertex in `disjoint_sets` to the index of the corresponding set # in `vertex_sets` root_to_set = Dict{Int, Int}() - vertex_sets = Vector{ConnectionVertex}[] + vertex_sets = Vector{V}[] for (vert_i, vert) in enumerate(invmap) root = find_root!(disjoint_sets, vert_i) set_i = get!(root_to_set, root) do - push!(vertex_sets, ConnectionVertex[]) + push!(vertex_sets, V[]) return length(vertex_sets) end push!(vertex_sets[set_i], vert)