diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 0ad1080965..d23d173b9a 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -161,6 +161,7 @@ include("systems/index_cache.jl") include("systems/parameter_buffer.jl") include("systems/abstractsystem.jl") include("systems/model_parsing.jl") +include("systems/connectiongraph.jl") include("systems/connectors.jl") include("systems/state_machines.jl") include("systems/analysis_points.jl") diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index e643be904d..f95cf90eea 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1368,104 +1368,6 @@ struct IgnoredAnalysisPoint outputs::Vector{Union{BasicSymbolic, AbstractSystem}} end -const HierarchyVariableT = Vector{Union{BasicSymbolic, Symbol}} -const HierarchySystemT = Vector{Union{AbstractSystem, Symbol}} -""" -The type returned from `analysis_point_common_hierarchy`. -""" -const HierarchyAnalysisPointT = Vector{Union{IgnoredAnalysisPoint, Symbol}} -""" -The type returned from `as_hierarchy`. -""" -const HierarchyT = Union{HierarchyVariableT, HierarchySystemT} - -""" - $(TYPEDSIGNATURES) - -The inverse operation of `as_hierarchy`. -""" -function from_hierarchy(hierarchy::HierarchyT) - namefn = hierarchy[1] isa AbstractSystem ? nameof : getname - foldl(@view hierarchy[2:end]; init = hierarchy[1]) do sys, name - rename(sys, Symbol(name, NAMESPACE_SEPARATOR, namefn(sys))) - end -end - -""" - $(TYPEDSIGNATURES) - -Represent an ignored analysis point as a namespaced hierarchy. The hierarchy is built -using the common hierarchy of all involved systems/variables. -""" -function analysis_point_common_hierarchy(ap::IgnoredAnalysisPoint)::HierarchyAnalysisPointT - isys = as_hierarchy(ap.input) - osyss = as_hierarchy.(ap.outputs) - suffix = Symbol[] - while isys[end] == osyss[1][end] && allequal(last.(osyss)) - push!(suffix, isys[end]) - pop!(isys) - pop!.(osyss) - end - isys = from_hierarchy(isys) - osyss = from_hierarchy.(osyss) - newap = IgnoredAnalysisPoint(isys, osyss) - hierarchy = HierarchyAnalysisPointT([suffix; newap]) - reverse!(hierarchy) - return hierarchy -end - -""" - $(TYPEDSIGNATURES) - -Represent a namespaced system (or variable) `sys` as a hierarchy. Return a vector, where -the first element is the unnamespaced system (variable) and subsequent elements are -`Symbol`s representing the parents of the unnamespaced system (variable) in order from -inner to outer. -""" -function as_hierarchy(sys::Union{AbstractSystem, BasicSymbolic})::HierarchyT - namefn = sys isa AbstractSystem ? nameof : getname - # get the hierarchy - hierarchy = namespace_hierarchy(namefn(sys)) - # rename the system with unnamespaced name - newsys = rename(sys, hierarchy[end]) - # and remove it from the list - pop!(hierarchy) - # reverse it to go from inner to outer - reverse!(hierarchy) - # concatenate - T = sys isa AbstractSystem ? AbstractSystem : BasicSymbolic - return Union{Symbol, T}[newsys; hierarchy] -end - -""" - $(TYPEDSIGNATURES) - -Get the analysis points to ignore for `sys` and its subsystems. The returned value is a -`Tuple` similar in structure to the `ignored_connections` field. -""" -function ignored_connections(sys::AbstractSystem) - has_ignored_connections(sys) || - return (HierarchyAnalysisPointT[], HierarchyAnalysisPointT[]) - - ics = get_ignored_connections(sys) - if ics === nothing - ics = (HierarchyAnalysisPointT[], HierarchyAnalysisPointT[]) - end - # turn into hierarchies - ics = (map(analysis_point_common_hierarchy, ics[1]), - map(analysis_point_common_hierarchy, ics[2])) - systems = get_systems(sys) - # for each subsystem, get its ignored connections, add the name of the subsystem - # to the hierarchy and concatenate corresponding buffers of the result - result = mapreduce(Broadcast.BroadcastFunction(vcat), systems; init = ics) do subsys - sub_ics = ignored_connections(subsys) - (map(Base.Fix2(push!, nameof(subsys)), sub_ics[1]), - map(Base.Fix2(push!, nameof(subsys)), sub_ics[2])) - end - return (Vector{HierarchyAnalysisPointT}(result[1]), - Vector{HierarchyAnalysisPointT}(result[2])) -end - """ $(TYPEDSIGNATURES) @@ -1993,35 +1895,20 @@ function n_expanded_connection_equations(sys::AbstractSystem) # TODO: what about inputs? isconnector(sys) && return length(get_unknowns(sys)) sys = remove_analysis_points(sys) - n_variable_connect_eqs = 0 - for eq in equations(sys) - is_causal_variable_connection(eq.rhs) || continue - n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1 - end - sys, (csets, _) = generate_connection_set(sys) - ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) - n_outer_stream_variables = 0 - for cset in instream_csets - n_outer_stream_variables += count(x -> x.isouter, cset.set) - end - - #n_toplevel_unused_flows = 0 - #toplevel_flows = Set() - #for cset in csets - # e1 = first(cset.set) - # e1.sys.namespace === nothing || continue - # for e in cset.set - # get_connection_type(e.v) === Flow || continue - # push!(toplevel_flows, e.v) - # end - #end - #for m in get_systems(sys) - # isconnector(m) || continue - # n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m)) - #end - - nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs + + n_extras = 0 + for cset in csets + rep = cset[1] + if rep.type <: Union{InputVar, OutputVar, Equality} + n_extras += length(cset) - 1 + elseif rep.type == Flow + n_extras += 1 + elseif rep.type == Stream + n_extras += count(x -> x.isouter, cset) + end + end + return n_extras end Base.show(io::IO, sys::AbstractSystem; kws...) = show(io, MIME"text/plain"(), sys; kws...) diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index caa51f0740..4171042908 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -432,19 +432,15 @@ function with_analysis_point_ignored(sys::AbstractSystem, ap::AnalysisPoint) has_ignored_connections(sys) || return sys ignored = get_ignored_connections(sys) if ignored === nothing - ignored = (IgnoredAnalysisPoint[], IgnoredAnalysisPoint[]) + ignored = Connection[] else - ignored = copy.(ignored) + ignored = copy(ignored) end if ap.outputs === nothing error("Empty analysis point") end - if ap.input isa AbstractSystem && all(x -> x isa AbstractSystem, ap.outputs) - push!(ignored[1], IgnoredAnalysisPoint(ap.input, ap.outputs)) - else - push!(ignored[2], IgnoredAnalysisPoint(unwrap(ap.input), unwrap.(ap.outputs))) - end + push!(ignored, Connection([unwrap(ap.input); unwrap.(ap.outputs)])) return @set sys.ignored_connections = ignored end diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl new file mode 100644 index 0000000000..e99f200732 --- /dev/null +++ b/src/systems/connectiongraph.jl @@ -0,0 +1,492 @@ +""" + $(TYPEDEF) + +A vertex in the connection hypergraph. + +## Fields + +$(TYPEDFIELDS) +""" +struct ConnectionVertex + """ + The name of the variable or subsystem represented by this connection vertex. Stored as + a list of names denoting the path from the root system to this variable/subsystem. The + name of the root system is not included. + """ + name::Vector{Symbol} + """ + Boolean indicating whether this is an outside connector. + """ + isouter::Bool + """ + A type indicating what kind of connector it is. One of: + - `Stream` + - `Equality` + - `Flow` + - `InputVar` + - `OutputVar` + """ + type::DataType + """ + The cached hash value of this struct. Should never be passed manually. + """ + hash::UInt +end + +""" + $(TYPEDSIGNATURES) + +Create a `ConnectionVertex` given +- `namespace`: the path from the root to the variable/subsystem. Does not include the root + system. +- `var`: the variable/subsystem. + +`isouter` is the same as the struct field. Uses `get_connection_type` to find the type to +use for this connection. +""" +function ConnectionVertex( + namespace::Vector{Symbol}, var::Union{BasicSymbolic, AbstractSystem}, isouter::Bool) + if var isa BasicSymbolic + name = getname(var) + else + name = nameof(var) + end + var_ns = namespace_hierarchy(name) + type = get_connection_type(var) + name = vcat(namespace, var_ns) + return ConnectionVertex(name, isouter, type; alias = true) +end + +""" + $(TYPEDSIGNATURES) + +Create a connection vertex for the given path. Typically used for domain connection graphs, +where the type of connection doesn't matter. Uses `isouter = true` and `type = Flow`. +""" +function ConnectionVertex(name::Vector{Symbol}) + return ConnectionVertex(name, true, Flow) +end + +""" + $(TYPEDSIGNATURES) + +Create a connection vertex for the given path `name` using the provided `isouter` and +`type`. `alias` denotes whether `name` can be stored by this vertex without copying. +""" +function ConnectionVertex( + name::Vector{Symbol}, isouter::Bool, type::DataType; alias = false) + if !alias + name = copy(name) + end + h = foldr(hash, name; init = zero(UInt)) + h = hash(type, h) + h = hash(isouter, h) + return ConnectionVertex(name, isouter, type, h) +end + +Base.hash(x::ConnectionVertex, h::UInt) = h ⊻ x.hash + +function Base.:(==)(a::ConnectionVertex, b::ConnectionVertex) + length(a.name) == length(b.name) || return false + for (x, y) in zip(a.name, b.name) + x == y || return false + end + a.isouter == b.isouter || return false + a.type == b.type || return false + if a.hash != b.hash + error(""" + This should never happen. Please open an issue in ModelingToolkit.jl with an MWE. + """) + end + return true +end + +function Base.show(io::IO, vert::ConnectionVertex) + for name in @view(vert.name[1:(end - 1)]) + print(io, name, ".") + end + print(io, vert.name[end], "::", vert.isouter ? "outer" : "inner") +end + +""" + $(TYPEDEF) + +A hypergraph used to represent the connection sets in a system. Vertices of this graph are +of type `ConnectionVertex`. The connected components of a connection graph are the merged +connection sets. + +## Fields + +$(TYPEDFIELDS) +""" +struct ConnectionGraph + """ + Mapping from vertices to their integer ID. + """ + labels::Dict{ConnectionVertex, Int} + """ + Reverse mapping from integer ID to vertices. + """ + invmap::Vector{ConnectionVertex} + """ + Core data structure for storing the hypergraph. Each hyperedge is a source vertex and + has bipartite edges to the connection vertices it is incident on. + """ + graph::BipartiteGraph{Int, Nothing} +end + +""" + $(TYPEDSIGNATURES) + +Create an empty `ConnectionGraph`. +""" +function ConnectionGraph() + graph = BipartiteGraph(0, 0, Val(true)) + return ConnectionGraph(Dict{ConnectionVertex, Int}(), ConnectionVertex[], graph) +end + +function Base.show(io::IO, graph::ConnectionGraph) + printstyled(io, get(io, :cgraph_name, "ConnectionGraph"); color = :blue, bold = true) + println(io, " with ", length(graph.labels), + " vertices and ", nsrcs(graph.graph), " hyperedges") + compact = get(io, :compact, false) + for edge_i in 𝑠vertices(graph.graph) + if compact && edge_i > 5 + println(io, "⋮") + break + end + edge_idxs = 𝑠neighbors(graph.graph, edge_i) + type = graph.invmap[edge_idxs[1]].type + if type <: Union{InputVar, OutputVar} + type = "Causal" + elseif type == Equality + # otherwise it prints `ModelingToolkit.Equality` + type = "Equality" + end + printstyled(io, " ", type; bold = true, color = :yellow) + print(io, "<") + for vi in @view(edge_idxs[1:(end - 1)]) + print(io, graph.invmap[vi], ", ") + end + println(io, graph.invmap[edge_idxs[end]], ">") + end +end + +""" + $(TYPEDSIGNATURES) + +Add the given vertex to the connection graph. Return the integer ID of the added vertex. +No-op if the vertex already exists. +""" +function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) + j = get(graph.labels, dst, 0) + iszero(j) || return j + j = Graphs.add_vertex!(graph.graph, DST) + push!(graph.invmap, dst) + @assert length(graph.invmap) == j + graph.labels[dst] = j + return j +end + +const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{ConnectionVertex}}} + +""" + $(TYPEDSIGNATURES) + +Add the given hyperedge to the connection graph. Adds all vertices in the given edge if +they do not exist. Returns the integer ID of the added edge. +""" +function Graphs.add_edge!(graph::ConnectionGraph, src::ConnectionGraphEdge) + i = Graphs.add_vertex!(graph.graph, SRC) + for vert in src + j = Graphs.add_vertex!(graph, vert) + Graphs.add_edge!(graph.graph, i, j) + end + return i +end + +""" + $(TYPEDEF) + +A connection state is a combination of two `ConnectionGraph`s, one for the connection sets +and the other for the domain network. The domain network is a graph of connected +subsystems. The connected components of the domain network denote connected domains that +share properties. +""" +abstract type AbstractConnectionState end + +""" + $(TYPEDEF) + +The most trivial `AbstractConnectionState`. + +## Fields + +$(TYPEDFIELDS) +""" +struct ConnectionState <: AbstractConnectionState + """ + The connection graph for connection sets. + """ + connection_graph::ConnectionGraph + """ + The connection graph for the domain network. + """ + domain_connection_graph::ConnectionGraph +end + +""" + $(TYPEDSIGNATURES) + +Create an empty `ConnectionState` with empty graphs. +""" +ConnectionState() = ConnectionState(ConnectionGraph(), ConnectionGraph()) + +function Base.show(io::IO, state::AbstractConnectionState) + printstyled(io, typeof(state); bold = true, color = :green) + println(io, " comprising of") + ctx1 = IOContext(io, :cgraph_name => "Connection Network", :compact => true) + show(ctx1, state.connection_graph) + println(io) + println(io, "And") + println(io) + ctx2 = IOContext(io, :cgraph_name => "Domain Network", :compact => true) + show(ctx2, state.domain_connection_graph) +end + +""" + $(TYPEDSIGNATURES) + +Add the given edge to the connection network. Does not affect the domain network. +""" +function add_connection_edge!(state::ConnectionState, edge::ConnectionGraphEdge) + Graphs.add_edge!(state.connection_graph, edge) + return nothing +end + +""" + $(TYPEDSIGNATURES) + +Add the given edge to the domain network. Does not affect the connection network. +""" +function add_domain_connection_edge!(state::ConnectionState, edge::ConnectionGraphEdge) + Graphs.add_edge!(state.domain_connection_graph, edge) + return nothing +end + +""" + $(TYPEDSIGNATURES) + +An `AbstractConnectionState` that is used to remove edges from the main connection state. +Transformed analysis points add to the list of removed connections, and the list of removed +connections builds this connection state. This allows ensuring that the removed connections +are not present in the final network even if they are connected multiple times. This state +also tracks which vertex in each hyperedge is the input, since the removed connections are +causal. + +## Fields + +$(TYPEDFIELDS) +""" +struct NegativeConnectionState <: AbstractConnectionState + """ + The connection graph for connection sets. + """ + connection_graph::ConnectionGraph + """ + The connection graph for the domain network. + """ + domain_connection_graph::ConnectionGraph + """ + Mapping from the integer ID of each hyperedge in `connection_graph` to the integer ID + of the "input" in that hyperedge. + """ + connection_hyperedge_inputs::Vector{Int} + """ + Mapping from the integer ID of each hyperedge in `domain_connection_graph` to the + integer ID of the "input" in that hyperedge. + """ + domain_hyperedge_inputs::Vector{Int} +end + +""" + $(TYPEDSIGNATURES) + +Create an empty `NegativeConnectionState` with empty graphs. +""" +function NegativeConnectionState() + NegativeConnectionState(ConnectionGraph(), ConnectionGraph(), Int[], Int[]) +end + +""" + $(TYPEDSIGNATURES) + +Add the given edge to the connection network. Does not affect the domain network. Assumes +that the first vertex in `edge` is the input. +""" +function add_connection_edge!(state::NegativeConnectionState, edge::ConnectionGraphEdge) + i = Graphs.add_edge!(state.connection_graph, edge) + j = state.connection_graph.labels[first(edge)] + push!(state.connection_hyperedge_inputs, j) + @assert length(state.connection_hyperedge_inputs) == i + return nothing +end + +""" + $(TYPEDSIGNATURES) + +Add the given edge to the domain network. Does not affect the connection network. Assumes +that the first vertex in `edge` is the input. +""" +function add_domain_connection_edge!( + state::NegativeConnectionState, edge::ConnectionGraphEdge) + i = Graphs.add_edge!(state.domain_connection_graph, edge) + j = state.domain_connection_graph.labels[first(edge)] + push!(state.domain_hyperedge_inputs, j) + @assert length(state.domain_hyperedge_inputs) == i + return nothing +end + +""" + $(TYPEDSIGNATURES) + +Modify `graph` such that no hyperedge is a superset of any (causal) hyerpedge in `neg_graph`. + +For each "negative" hyperedge in `neg_graph` with integer ID `neg_edge_id`, +`neg_edge_inputs[neg_edge_id]` denotes the vertex the negative hyperedge is incident on +which is considered the input of the negative hyperedge. If any hyperedge in `graph` +contains this input as well as at least one other vertex in the negative hyperedge, all +vertices common between the hyperedge and negative hyperedge are removed from the hyperedge. + +`graph` is modified in-place. Note that `graph` and `neg_graph` may not have the same +ordering of vertices, and thus all comparisons should be done by comparing the +`ConnectionVertex`. +""" +function remove_negative_connections!( + graph::ConnectionGraph, neg_graph::ConnectionGraph, neg_edge_inputs::Vector{Int}) + # _i means index in neg_graph + # _j means index in graph + + # get all edges in `graph` as bitsets + graph_hyperedgesets = map(𝑠vertices(graph.graph)) do edge_j + hyperedge_jdxs = 𝑠neighbors(graph.graph, edge_j) + return BitSet(hyperedge_jdxs) + end + + # indexes in each hyperedge to remove + idxs_to_rm = [BitSet() for _ in graph_hyperedgesets] + # iterate over negative edges and the corresponding input vertex in each edge + for (input_i, edge_i) in zip(neg_edge_inputs, 𝑠vertices(neg_graph.graph)) + # get the hyperedge as integer indexes in `neg_graph` + neg_hyperedge_idxs = 𝑠neighbors(neg_graph.graph, edge_i) + # the hyperedge as `ConnectionVar`s + neg_hyperedge = map(Base.Fix1(getindex, neg_graph.invmap), neg_hyperedge_idxs) + # The hyperedge as integer indexes in `graph` + # *j*dxs. See what I did there? + neg_hyperedge_jdxs = map(cvar -> get(graph.labels, cvar, 0), neg_hyperedge) + # the edge to remove is between variables that aren't connected, so ignore it + if any(iszero, neg_hyperedge_jdxs) + continue + end + + # The input vertex as a `ConnectionVar` + input_v = neg_graph.invmap[input_i] + # The input vertex as an index in `graph` + input_j = graph.labels[input_v] + # Iterate over hyperedges in `graph` + for edge_j in 𝑠vertices(graph.graph) + # The bitset of nodes this edge is incident on + edgeset = graph_hyperedgesets[edge_j] + # the input must be in this hyperedge + input_j in edgeset || continue + # now, if any other vertex apart from this input is also in the hyperedge + # we remove all the indices in `neg_hyperedge_jdxs` also present in this + # hyperedge + + # should_rm tracks if any other vertex apart from `input_j` is in the hyperedge + should_rm = false + # iterate over the negative hyperedge + for var_j in neg_hyperedge_jdxs + var_j == input_j && continue + # check if the variable which is not `input_j` is in the hyperedge + should_rm |= var_j in edgeset + should_rm || continue + # if there is any other variable, start removing + push!(idxs_to_rm[edge_j], var_j) + end + if should_rm + # if there was any other variable, also remove `input_j` + push!(idxs_to_rm, input_j) + end + end + end + + # for each edge and list of vertices to remove from the edge + for (edge_j, neg_vertices) in enumerate(idxs_to_rm) + for vert_j in neg_vertices + # remove those vertices + Graphs.rem_edge!(graph.graph, edge_j, vert_j) + end + end +end + +""" + $(TYPEDSIGNATURES) + +Remove negative hyperedges given by `neg_state` from the connection and domain networks of +`state`. +""" +function remove_negative_connections!( + state::ConnectionState, neg_state::NegativeConnectionState) + remove_negative_connections!(state.connection_graph, neg_state.connection_graph, + neg_state.connection_hyperedge_inputs) + remove_negative_connections!( + state.domain_connection_graph, neg_state.domain_connection_graph, + neg_state.domain_hyperedge_inputs) +end + +""" + $(TYPEDSIGNATURES) + +Return the merged connection sets in `graph` as a `Vector{Vector{ConnectionVertex}}`. These +are equivalent to the connected components of `graph`. +""" +function connectionsets(graph::ConnectionGraph) + bigraph = graph.graph + invmap = graph.invmap + + # union all of the hyperedges + disjoint_sets = IntDisjointSets(length(invmap)) + for edge_i in 𝑠vertices(bigraph) + hyperedge = 𝑠neighbors(bigraph, edge_i) + root, rest = Iterators.peel(hyperedge) + for vert in rest + union!(disjoint_sets, root, vert) + end + end + + # maps the root of a vertex in `disjoint_sets` to the index of the corresponding set + # in `vertex_sets` + root_to_set = Dict{Int, Int}() + vertex_sets = Vector{ConnectionVertex}[] + for (vert_i, vert) in enumerate(invmap) + root = find_root!(disjoint_sets, vert_i) + set_i = get!(root_to_set, root) do + push!(vertex_sets, ConnectionVertex[]) + return length(vertex_sets) + end + push!(vertex_sets[set_i], vert) + end + + return vertex_sets +end + +""" + $(TYPEDSIGNATURES) + +Return the connection sets of the connection graph and domain network. +""" +function connectionsets(state::ConnectionState) + return connectionsets(state.connection_graph), + connectionsets(state.domain_connection_graph) +end diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 5d6227d4c1..14f90ff14e 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -28,6 +28,8 @@ function connect(sys1::AbstractSystem, sys2::AbstractSystem, syss::AbstractSyste Equation(Connection(), Connection(syss)) # the RHS are connected systems end +const _debug_mode = Base.JLOptions().check_bounds == 1 + function Base.show(io::IO, c::Connection) print(io, "connect(") if c.systems isa AbstractArray || c.systems isa Tuple @@ -52,18 +54,25 @@ end isconnection(_) = false isconnection(_::Connection) = true + """ - domain_connect(sys1, sys2, syss...) + $(TYPEDSIGNATURES) Adds a domain only connection equation, through and across state equations are not generated. """ -function domain_connect(sys1, sys2, syss...) +function domain_connect(sys1::AbstractSystem, sys2::AbstractSystem, syss::AbstractSystem...) syss = (sys1, sys2, syss...) length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") Equation(Connection(:domain), Connection(syss)) # the RHS are connected systems end -function get_connection_type(s) +""" + $(TYPEDSIGNATURES) + +Get the connection type of symbolic variable `s` from the `VariableConnectType` metadata. +Defaults to `Equality` if not present. +""" +function get_connection_type(s::Symbolic) s = unwrap(s) if iscall(s) && operation(s) === getindex s = arguments(s)[1] @@ -111,6 +120,14 @@ struct StreamConnector <: AbstractConnectorType end struct RegularConnector <: AbstractConnectorType end struct DomainConnector <: AbstractConnectorType end +""" + $(TYPEDSIGNATURES) + +Return an `AbstractConnectorType` denoting the type of connector that `sys` is. +Domain connectors have a single `Flow` unknown. Stream connectors have a single +`Flow` variable and multiple `Stream` variables. Any other type of connector is +a "regular" connector. +""" function connector_type(sys::AbstractSystem) unkvars = get_unknowns(sys) n_stream = 0 @@ -180,22 +197,15 @@ end const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr} -const CAUSAL_CONNECTION_ERR = """ -Only causal variables can be used in a `connect` statement. The first argument must \ -be a single output variable and all subsequent variables must be input variables. -""" - -function VariableNotOutputError(var) +function NonCausalVariableError(vars) + names = join(map(var -> " " * string(var), vars), "\n") ArgumentError(""" - $CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \ - in the variable metadata. - """) -end + Only causal variables can be used in a `connect` statement. Each variable must be \ + either an input or an output. Mark a variable as input using the `[input = true]` \ + variable metadata or as an output using the `[output = true]` variable metadata. -function VariableNotInputError(var) - ArgumentError(""" - $CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \ - in the variable metadata. + The following variables were found to be non-causal: + $names """) end @@ -220,11 +230,10 @@ function validate_causal_variables_connection(allvars) if !allequal(allsizes) throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively.")) end - isoutput(var1) || throw(VariableNotOutputError(var1)) - isinput(var2) || throw(VariableNotInputError(var2)) - for var in vars - isinput(var) || throw(VariableNotInputError(var)) + non_causal_variables = filter(allvars) do var + !isinput(var) && !isoutput(var) end + isempty(non_causal_variables) || throw(NonCausalVariableError(non_causal_variables)) end """ @@ -246,15 +255,11 @@ function connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, return Equation(Connection(), Connection(map(SymbolicWithNameof, unwrap.(allvars)))) end -function flowvar(sys::AbstractSystem) - sts = get_unknowns(sys) - for s in sts - vtype = get_connection_type(s) - vtype === Flow && return s - end - error("There in no flow variable in $(nameof(sys))") -end +""" + $(METHODLIST) +Add all `instream(..)` expressions to `set`. +""" function collect_instream!(set, eq::Equation) collect_instream!(set, eq.lhs) | collect_instream!(set, eq.rhs) end @@ -297,6 +302,12 @@ mydiv(num, den) = end @register_symbolic mydiv(n, d) +""" + $(TYPEDSIGNATURES) + +Return a function which checks whether the connector (system) passed to it is an outside +connector of `sys`. The function can also be given the name of a system as a `Symbol`. +""" function generate_isouter(sys::AbstractSystem) outer_connectors = Symbol[] for s in get_systems(sys) @@ -309,82 +320,14 @@ function generate_isouter(sys::AbstractSystem) isconnector(sys) || error("$s is not a connector!") idx = findfirst(isequal(NAMESPACE_SEPARATOR), s) parent_name = Symbol(idx === nothing ? s : s[1:prevind(s, idx)]) - parent_name in outer_connectors + isouter(parent_name) + end + function isouter(name::Symbol)::Bool + return name in outer_connectors end end end -struct LazyNamespace - namespace::Union{Nothing, AbstractSystem} - sys::Any -end - -_getname(::Nothing) = nothing -_getname(sys) = nameof(sys) -Base.copy(l::LazyNamespace) = renamespace(_getname(l.namespace), l.sys) -Base.nameof(l::LazyNamespace) = renamespace(_getname(l.namespace), nameof(l.sys)) - -struct ConnectionElement - sys::LazyNamespace - v::Any - isouter::Bool - h::UInt -end -function _hash_impl(sys, v, isouter) - hashcore = hash(nameof(sys)::Symbol) ⊻ hash(getname(v)::Symbol) - hashouter = isouter ? hash(true) : hash(false) - hashcore ⊻ hashouter -end -function ConnectionElement(sys::LazyNamespace, v, isouter::Bool) - ConnectionElement(sys, v, isouter, _hash_impl(sys, v, isouter)) -end -Base.nameof(l::ConnectionElement) = renamespace(nameof(l.sys), getname(l.v)) -Base.isequal(l1::ConnectionElement, l2::ConnectionElement) = l1 == l2 -function Base.:(==)(l1::ConnectionElement, l2::ConnectionElement) - l1.isouter == l2.isouter && nameof(l1.sys) == nameof(l2.sys) && isequal(l1.v, l2.v) -end - -const _debug_mode = Base.JLOptions().check_bounds == 1 - -function Base.show(io::IO, c::ConnectionElement) - @unpack sys, v, isouter = c - print(io, nameof(sys), ".", v, "::", isouter ? "outer" : "inner") -end - -function Base.hash(e::ConnectionElement, salt::UInt) - if _debug_mode - @assert e.h === _hash_impl(e.sys, e.v, e.isouter) - end - e.h ⊻ salt -end -namespaced_var(l::ConnectionElement) = unknowns(l, l.v) -unknowns(l::ConnectionElement, v) = unknowns(copy(l.sys), v) - -function withtrueouter(e::ConnectionElement) - e.isouter && return e - # we undo the xor - newhash = (e.h ⊻ hash(false)) ⊻ hash(true) - ConnectionElement(e.sys, e.v, true, newhash) -end - -struct ConnectionSet - set::Vector{ConnectionElement} # namespace.sys, var, isouter -end -ConnectionSet() = ConnectionSet(ConnectionElement[]) -Base.copy(c::ConnectionSet) = ConnectionSet(copy(c.set)) -Base.:(==)(a::ConnectionSet, b::ConnectionSet) = a.set == b.set -Base.sort(a::ConnectionSet) = ConnectionSet(sort(a.set, by = string)) - -function Base.show(io::IO, c::ConnectionSet) - print(io, "<") - for i in 1:(length(c.set) - 1) - @unpack sys, v, isouter = c.set[i] - print(io, nameof(sys), ".", v, "::", isouter ? "outer" : "inner", ", ") - end - @unpack sys, v, isouter = last(c.set) - print(io, nameof(sys), ".", v, "::", isouter ? "outer" : "inner", ">") -end - @noinline function connection_error(ss) error("Different types of connectors are in one connection statement: <$(map(nameof, ss))>") end @@ -406,384 +349,460 @@ function ori(sys) end end +""" +Connection type used in `ConnectionVertex` for a causal input variable. `I` is an object +that can be passed to `getindex` as an index denoting the index in the variable for +causal array variables. For non-array variables this should be `1`. +""" +abstract type InputVar{I} end +""" +Connection type used in `ConnectionVertex` for a causal output variable. `I` is an object +that can be passed to `getindex` as an index denoting the index in the variable for +causal array variables. For non-array variables this should be `1`. +""" +abstract type OutputVar{I} end + +""" + $(METHODLIST) + +Get the contained index in an `InputVar` or `OutputVar` type. +""" +index_from_type(::Type{InputVar{I}}) where {I} = I +index_from_type(::Type{OutputVar{I}}) where {I} = I + +""" + $(TYPEDSIGNATURES) + +Chain `getproperty` calls on sys in the order given by `names` and return the unwrapped +result. +""" +function iterative_getproperty(sys::AbstractSystem, names::AbstractVector{Symbol}) + # we don't want to namespace the first time + result = toggle_namespacing(sys, false) + for name in names + result = getproperty(result, name) + end + return unwrap(result) +end + """ $(TYPEDSIGNATURES) -Populate `connectionsets` with connections between the connectors `ss`, all of which are -namespaced by `namespace`. +Return the variable/subsystem of `sys` referred to by vertex `vert`. +""" +function variable_from_vertex(sys::AbstractSystem, vert::ConnectionVertex) + value = iterative_getproperty(sys, vert.name) + value isa AbstractSystem && return value + vert.type <: Union{InputVar, OutputVar} || return value + # index possibly array causal variable + unwrap(wrap(value)[index_from_type(vert.type)]) +end + +""" + $(TYPEDSIGNATURES) + +Given `connected`, the list of connected variables/systems, generate the appropriate +connection sets and add them to `connection_state`. Update both the connection network and +domain network as necessary. `namespace` is the path from the root system to the system in +which the [`connect`](@ref) equation containing `connected` is located. `isouter` is the +function returned from [`generate_isouter`](@ref) for the system referred to by +`namespace`. -# Keyword Arguments -- `ignored_connects`: A tuple of the systems and variables for which connections should be - ignored. Of the format returned from `as_hierarchy`. -- `namespaced_ignored_systems`: The `from_hierarchy` versions of entries in - `ignored_connects[1]`, purely to avoid unnecessary recomputation. +`namespace` must not contain the name of the root system. """ -function connection2set!(connectionsets, namespace, ss, isouter; - ignored_systems = HierarchySystemT[], ignored_variables = HierarchyVariableT[]) - ns_ignored_systems = from_hierarchy.(ignored_systems) - ns_ignored_variables = from_hierarchy.(ignored_variables) - # ignore specified systems - ss = filter(ss) do s - !any(x -> nameof(x) == nameof(s), ns_ignored_systems) +function generate_connectionsets!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, connected, isouter) + initial_len = length(namespace) + _generate_connectionsets!(connection_state, namespace, connected, isouter) + # Enforce postcondition as a sanity check that the namespacing is implemented correctly + length(namespace) == initial_len || throw(NotPossibleError()) + return nothing +end + +function _generate_connectionsets!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, + connected_vars::Union{ + AbstractVector{SymbolicWithNameof}, Tuple{Vararg{SymbolicWithNameof}}}, + isouter) + # unwrap the `SymbolicWithNameof` into the contained symbolic variables. + connected_vars = map(x -> x.var, connected_vars) + _generate_connectionsets!(connection_state, namespace, connected_vars, isouter) +end + +function _generate_connectionsets!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, + connected_vars::Union{ + AbstractVector{<:BasicSymbolic}, Tuple{Vararg{BasicSymbolic}}}, + isouter) + # NOTE: variable connections don't populate the domain network + + # wrap to be able to call `eachindex` on a non-array variable + representative = wrap(first(connected_vars)) + # all of them have the same size, but may have different axes/shape + # so we iterate over `eachindex(eachindex(..))` since that is identical for all + for sz_i in eachindex(eachindex(representative)) + hyperedge = map(connected_vars) do var + var = unwrap(var) + var_ns = namespace_hierarchy(getname(var)) + i = eachindex(wrap(var))[sz_i] + + is_input = isinput(var) + is_output = isoutput(var) + if is_input && is_output + names = join(string.(connected_vars), ", ") + throw(ArgumentError(""" + Variable $var in connection `connect($names)` is both input and output. + """)) + elseif is_input + type = InputVar{i} + elseif is_output + type = OutputVar{i} + else + names = join(string.(connected_vars), ", ") + throw(ArgumentError(""" + Variable $var in connection `connect($names)` is neither input nor output. + """)) + end + + return ConnectionVertex( + [namespace; var_ns], length(var_ns) == 1 || isouter(var_ns[1]), type) + end + add_connection_edge!(connection_state, hyperedge) end - # `ignored_variables` for each `s` in `ss` - corresponding_ignored_variables = map( - Base.Fix2(ignored_systems_for_subsystem, ignored_variables), ss) - corresponding_namespaced_ignored_variables = map( - Broadcast.BroadcastFunction(from_hierarchy), corresponding_ignored_variables) - - regular_ss = [] - domain_ss = nothing - for s in ss +end + +function _generate_connectionsets!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, + systems::Union{AbstractVector{<:AbstractSystem}, Tuple{Vararg{AbstractSystem}}}, + isouter) + regular_systems = System[] + domain_system = nothing + for s in systems if is_domain_connector(s) - if domain_ss === nothing - domain_ss = s + if domain_system === nothing + domain_system = s else - names = join(map(string ∘ nameof, ss), ",") + names = join(map(string ∘ nameof, systems), ",") error("connect($names) contains multiple source domain connectors. There can only be one!") end else - push!(regular_ss, s) + push!(regular_systems, s) end end - T = ConnectionElement - @assert !isempty(regular_ss) - ss = regular_ss - # domain connections don't generate any equations - if domain_ss !== nothing - cset = ConnectionElement[] - dv = only(unknowns(domain_ss)) - for (i, s) in enumerate(ss) - sts = unknowns(s) - io = isouter(s) - _ignored_variables = corresponding_ignored_variables[i] - _namespaced_ignored_variables = corresponding_namespaced_ignored_variables[i] + + @assert !isempty(regular_systems) + + systems = regular_systems + # There is a domain being connected here. In such a case, we only connect the + # flow variable common between the domain setter and all other connectors in the + # normal connection graph. The domain graph connects all these subsystems. + if domain_system !== nothing + hyperedge = ConnectionVertex[] + domain_hyperedge = ConnectionVertex[] + sizehint!(hyperedge, length(systems) + 1) + sizehint!(domain_hyperedge, length(systems) + 1) + + dv = only(unknowns(domain_system)) + push!(namespace, nameof(domain_system)) + dv_vertex = ConnectionVertex(namespace, dv, false) + domain_vertex = ConnectionVertex(namespace) + pop!(namespace) + + push!(domain_hyperedge, domain_vertex) + push!(hyperedge, dv_vertex) + + for (i, sys) in enumerate(systems) + sts = unknowns(sys) + sys_is_outer = isouter(sys) + + # add this system to the namespace so all vertices created from its unknowns + # are properly namespaced + sysname = nameof(sys) + sys_ns = namespace_hierarchy(sysname) + append!(namespace, sys_ns) for v in sts vtype = get_connection_type(v) + # ignore all non-flow vertices in connectors (vtype === Flow && isequal(v, dv)) || continue - any(isequal(v), _namespaced_ignored_variables) && continue - push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false)) - push!(cset, T(LazyNamespace(namespace, s), v, io)) + + vertex = ConnectionVertex(namespace, v, sys_is_outer) + # vertices in the domain graph are systems with isouter=true and type=Flow + sys_vertex = ConnectionVertex(namespace) + push!(hyperedge, vertex) + push!(domain_hyperedge, sys_vertex) end + # remember to remove the added namespace! + foreach(_ -> pop!(namespace), sys_ns) end - @assert length(cset) > 0 - push!(connectionsets, ConnectionSet(cset)) - return connectionsets + @assert length(hyperedge) > 1 + @assert length(domain_hyperedge) == length(hyperedge) + + add_connection_edge!(connection_state, hyperedge) + add_domain_connection_edge!(connection_state, domain_hyperedge) + return end - s1 = first(ss) - sts1v = unknowns(s1) - if isframe(s1) # Multibody - O = ori(s1) + sys1 = first(systems) + sys1_dvs = unknowns(sys1) + # Add 9 orientation variables if connection is between multibody frames + if isframe(sys1) # Multibody + O = ori(sys1) orientation_vars = Symbolics.unwrap.(collect(vec(O.R))) - sts1v = [sts1v; orientation_vars] + sys1_dvs = [sys1_dvs; orientation_vars] end - sts1 = Set(sts1v) - num_unknowns = length(sts1) - - # we don't filter here because `csets` should include the full set of unknowns. - # not all of `ss` will have the same (or any) variables filtered so the ones - # that aren't should still go in the right cset. Since `sts1` is only used for - # validating that all systems being connected are of the same type, it has - # unfiltered entries. - csets = [T[] for _ in 1:num_unknowns] # Add 9 orientation variables if connection is between multibody frames - for (i, s) in enumerate(ss) - unknown_vars = unknowns(s) - if isframe(s) # Multibody - O = ori(s) + sys1_dvs_set = Set(sys1_dvs) + num_unknowns = length(sys1_dvs) + + # We first build sets of all vertices that are connected together + var_sets = [ConnectionVertex[] for _ in 1:num_unknowns] + domain_hyperedge = ConnectionVertex[] + for (i, sys) in enumerate(systems) + unknown_vars = unknowns(sys) + # Add 9 orientation variables if connection is between multibody frames + if isframe(sys) # Multibody + O = ori(sys) orientation_vars = Symbolics.unwrap.(vec(O.R)) unknown_vars = [unknown_vars; orientation_vars] end - i != 1 && ((num_unknowns == length(unknown_vars) && - all(Base.Fix2(in, sts1), unknown_vars)) || - connection_error(ss)) - io = isouter(s) - # don't `filter!` here so that `j` points to the correct cset regardless of - # which variables are filtered. + # Error if any subsequent systems do not have the same number of unknowns + # or have unknowns not in the others. + if i != 1 && + (num_unknowns != length(unknown_vars) || any(!in(sys1_dvs_set), unknown_vars)) + connection_error(systems) + end + # add this system to the namespace so all vertices created from its unknowns + # are properly namespaced + sysname = nameof(sys) + sys_ns = namespace_hierarchy(sysname) + append!(namespace, sys_ns) + sys_is_outer = isouter(sys) for (j, v) in enumerate(unknown_vars) - any(isequal(v), corresponding_namespaced_ignored_variables[i]) && continue - push!(csets[j], T(LazyNamespace(namespace, s), v, io)) + push!(var_sets[j], ConnectionVertex(namespace, v, sys_is_outer)) end + domain_vertex = ConnectionVertex(namespace) + push!(domain_hyperedge, domain_vertex) + # remember to remove the added namespace! + foreach(_ -> pop!(namespace), sys_ns) end - for cset in csets - v = first(cset).v - vtype = get_connection_type(v) - if domain_ss !== nothing && vtype === Flow && - (dv = only(unknowns(domain_ss)); isequal(v, dv)) - push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false)) - end - for k in 2:length(cset) - vtype === get_connection_type(cset[k].v) || connection_error(ss) + for var_set in var_sets + # all connected variables should have the same type + if !allequal(Iterators.map(cvert -> cvert.type, var_set)) + connection_error(systems) end - push!(connectionsets, ConnectionSet(cset)) + # add edges + add_connection_edge!(connection_state, var_set) end -end - -function generate_connection_set( - sys::AbstractSystem; scalarize = false) - connectionsets = ConnectionSet[] - domain_csets = ConnectionSet[] - sys = generate_connection_set!( - connectionsets, domain_csets, sys, scalarize, nothing, ignored_connections(sys)) - csets = merge(connectionsets) - domain_csets = merge([csets; domain_csets], true) - - sys, (csets, domain_csets) + add_domain_connection_edge!(connection_state, domain_hyperedge) end """ $(TYPEDSIGNATURES) -For a list of `systems` in a connect equation, return the subset of it to ignore (as a -list of hierarchical systems) based on `ignored_system_aps`, the analysis points to be -ignored. All analysis points in `ignored_system_aps` must contain systems (connectors) -as their input/outputs. +Generate the merged connection sets and connected domain sets for system `sys`. Also +removes all `connect` equations in `sys`. Return the modified system and a tuple of the +connection sets and domain sets. Also scalarizes array equations in the system. """ -function systems_to_ignore(ignored_system_aps::Vector{HierarchyAnalysisPointT}, - systems::Union{Vector{S}, Tuple{Vararg{S}}}) where {S <: AbstractSystem} - to_ignore = HierarchySystemT[] - for ap in ignored_system_aps - # if `systems` contains the input of the AP, ignore any outputs of the AP present in it. - isys_hierarchy = HierarchySystemT([ap[1].input; @view ap[2:end]]) - isys = from_hierarchy(isys_hierarchy) - any(x -> nameof(x) == nameof(isys), systems) || continue - - for outsys in ap[1].outputs - osys_hierarchy = HierarchySystemT([outsys; @view ap[2:end]]) - osys = from_hierarchy(osys_hierarchy) - any(x -> nameof(x) == nameof(osys), systems) || continue - push!(to_ignore, HierarchySystemT(osys_hierarchy)) - end - end +function generate_connection_set(sys::AbstractSystem) + # generate the states + connection_state = ConnectionState() + negative_connection_state = NegativeConnectionState() + # the root system isn't added to the namespace, which we handle by not namespacing it + sys = toggle_namespacing(sys, false) + sys = generate_connection_set!( + connection_state, negative_connection_state, sys, Symbol[]) + remove_negative_connections!(connection_state, negative_connection_state) - return to_ignore + return sys, connectionsets(connection_state) end """ $(TYPEDSIGNATURES) -For a list of `systems` in a connect equation, return the subset of their variables to -ignore (as a list of hierarchical variables) based on `ignored_system_aps`, the analysis -points to be ignored. All analysis points in `ignored_system_aps` must contain variables -as their input/outputs. +Appropriately handle the equation `eq` depending on whether it is a normal or connection +equation. For normal equations, it is expected that `eqs` is a buffer to which the equation +can be pushed, unmodified. Connection equations update the given `state`. The equation is +present at the path in the hierarchical system given by `namespace`. `isouter` is the +function returned from `generate_isouter`. """ -function variables_to_ignore(ignored_variable_aps::Vector{HierarchyAnalysisPointT}, - systems::Union{Vector{S}, Tuple{Vararg{S}}}) where {S <: AbstractSystem} - to_ignore = HierarchyVariableT[] - for ap in ignored_variable_aps - ivar_hierarchy = HierarchyVariableT([ap[1].input; @view ap[2:end]]) - ivar = from_hierarchy(ivar_hierarchy) - any(x -> any(isequal(ivar), renamespace.((x,), unknowns(x))), systems) || continue - - for outvar in ap[1].outputs - ovar_hierarchy = HierarchyVariableT([as_hierarchy(outvar); @view ap[2:end]]) - ovar = from_hierarchy(ovar_hierarchy) - any(x -> any(isequal(ovar), renamespace.((x,), unknowns(x))), systems) || - continue - push!(to_ignore, HierarchyVariableT(ovar_hierarchy)) +function handle_maybe_connect_equation!(eqs, state::AbstractConnectionState, + eq::Equation, namespace::Vector{Symbol}, isouter) + lhs = eq.lhs + rhs = eq.rhs + + if !(lhs isa Connection) + # split connections and equations + if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray + append!(eqs, Symbolics.scalarize(eq)) + else + push!(eqs, eq) end + return end - return to_ignore -end -""" - $(TYPEDSIGNATURES) - -For a list of variables `vars` in a connect equation, return the subset of them ignore -(as a list of symbolic variables) based on `ignored_system_aps`, the analysis points to -be ignored. All analysis points in `ignored_system_aps` must contain variables as their -input/outputs. -""" -function variables_to_ignore(ignored_variable_aps::Vector{HierarchyAnalysisPointT}, - vars::Union{Vector{S}, Tuple{Vararg{S}}}) where {S <: BasicSymbolic} - to_ignore = eltype(vars)[] - for ap in ignored_variable_aps - ivar_hierarchy = HierarchyVariableT([ap[1].input; @view ap[2:end]]) - ivar = from_hierarchy(ivar_hierarchy) - any(isequal(ivar), vars) || continue - - for outvar in ap[1].outputs - ovar_hierarchy = HierarchyVariableT([outvar; @view ap[2:end]]) - ovar = from_hierarchy(ovar_hierarchy) - any(isequal(ovar), vars) || continue - push!(to_ignore, ovar) + if get_systems(lhs) === :domain + # This is a domain connection, so we only update the domain connection graph + hyperedge = map(get_systems(rhs)) do sys + sys isa AbstractSystem || error("Domain connections can only connect systems!") + sysname = nameof(sys) + sys_ns = namespace_hierarchy(sysname) + append!(namespace, sys_ns) + vertex = ConnectionVertex(namespace) + foreach(_ -> pop!(namespace), sys_ns) + return vertex end + add_domain_connection_edge!(state, hyperedge) + else + connected_systems = get_systems(rhs) + generate_connectionsets!(state, namespace, connected_systems, isouter) end - - return to_ignore + return nothing end """ $(TYPEDSIGNATURES) -Generate connection sets from `connect` equations. +Generate the appropriate connection sets from `connect` equations present in the +hierarchical system `sys`. This is a recursive function that descends the hierarchy. If +`sys` is the root system, then `does_namespacing(sys)` must be `false` and `namespace` +should be empty. It is essential that the traversal is preorder. -# Arguments +## Arguments -- `connectionsets` is the list of connection sets to be populated by recursively - descending `sys`. -- `domain_csets` is the list of connection sets for domain connections. -- `sys` is the system whose equations are to be searched. -- `namespace` is a system representing the namespace in which `sys` exists, or `nothing` - for no namespace (if `sys` is top-level). +- `connection_state`: The connection state keeping track of the connection network and the + domain network. +- `negative_connection_state`: The connection state that tracks connections removed by + analysis point transformations. These removed connections are stored in the + `ignored_connections` field of the system. +- `namespace`: The path of names from the root system to the current system. This should + not include the name of the root system. """ -function generate_connection_set!(connectionsets, domain_csets, - sys::AbstractSystem, scalarize, namespace = nothing, - ignored_connects = (HierarchyAnalysisPointT[], HierarchyAnalysisPointT[])) +function generate_connection_set!(connection_state::ConnectionState, + negative_connection_state::NegativeConnectionState, + sys::AbstractSystem, namespace::Vector{Symbol}) + initial_len = length(namespace) + res = _generate_connection_set!( + connection_state, negative_connection_state, sys, namespace) + # Enforce postcondition as a sanity check that the recursion is implemented correctly + length(namespace) == initial_len || throw(NotPossibleError()) + return res +end + +function _generate_connection_set!(connection_state::ConnectionState, + negative_connection_state::NegativeConnectionState, + sys::AbstractSystem, namespace::Vector{Symbol}) + # This function recurses down the system tree. Each system adds its name and pops + # it before returning. We don't add the root system, which is handled by assuming + # it doesn't do namespacing. + does_namespacing(sys) && push!(namespace, nameof(sys)) subsys = get_systems(sys) - ignored_system_aps, ignored_variable_aps = ignored_connects isouter = generate_isouter(sys) eqs′ = get_eqs(sys) eqs = Equation[] - cts = [] # connections - domain_cts = [] # connections - extra_unknowns = [] + # generate connection equations and separate out non-connection equations for eq in eqs′ - lhs = eq.lhs - rhs = eq.rhs - - # causal variable connections will be expanded before we get here, - # but this guard is useful for `n_expanded_connection_equations`. - is_causal_variable_connection(rhs) && continue - if lhs isa Connection && get_systems(lhs) === :domain - connected_systems = get_systems(rhs) - connection2set!(domain_csets, namespace, connected_systems, isouter; - ignored_systems = systems_to_ignore( - ignored_system_aps, connected_systems), - ignored_variables = variables_to_ignore( - ignored_variable_aps, connected_systems)) - elseif isconnection(rhs) - push!(cts, get_systems(rhs)) - else - # split connections and equations - if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray - append!(eqs, Symbolics.scalarize(eq)) - else - push!(eqs, eq) - end - end + handle_maybe_connect_equation!(eqs, connection_state, eq, namespace, isouter) end - # all connectors are eventually inside connectors. - T = ConnectionElement - # only generate connection sets for systems that are not ignored + # go through the removed connections and update the negative graph + for conn in something(get_ignored_connections(sys), ()) + eq = Equation(Connection(), conn) + # there won't be any standard equations, so we can pass `nothing` instead of + # `eqs`. + handle_maybe_connect_equation!( + nothing, negative_connection_state, eq, namespace, isouter) + end + + # all connectors are eventually inside connectors, and all flow variables + # need at least a singleton connectionset (hyperedge) with the inside variant for s in subsys isconnector(s) || continue is_domain_connector(s) && continue + push!(namespace, nameof(s)) for v in unknowns(s) Flow === get_connection_type(v) || continue - push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)])) + add_connection_edge!(connection_state, (ConnectionVertex(namespace, v, false),)) end + pop!(namespace) end - for ct in cts - connection2set!(connectionsets, namespace, ct, isouter; - ignored_systems = systems_to_ignore(ignored_system_aps, ct), - ignored_variables = variables_to_ignore(ignored_variable_aps, ct)) - end - - # pre order traversal - if !isempty(extra_unknowns) - @set! sys.unknowns = [get_unknowns(sys); extra_unknowns] + # recurse down the hierarchy + @set! sys.systems = map(subsys) do s + generate_connection_set!(connection_state, negative_connection_state, s, namespace) end - @set! sys.systems = map( - s -> generate_connection_set!(connectionsets, domain_csets, s, - scalarize, renamespace(namespace, s), - ignored_systems_for_subsystem.((s,), ignored_connects)), - subsys) @set! sys.eqs = eqs + # Remember to pop the name at the end! + does_namespacing(sys) && pop!(namespace) + return sys end """ $(TYPEDSIGNATURES) -Given a subsystem `subsys` of a parent system and a list of systems (variables) to be -ignored by `generate_connection_set!` (`expand_variable_connections`), filter -`ignored_systems` to only include those present in the subtree of `subsys` and update -their hierarchy to not include `subsys`. +Generate connection equations for the connection sets given by `csets`. This does not +handle stream connections. Return the generated equations and the stream connection sets. """ -function ignored_systems_for_subsystem( - subsys::AbstractSystem, ignored_systems::Vector{<:Union{ - HierarchyT, HierarchyAnalysisPointT}}) - result = eltype(ignored_systems)[] - # in case `subsys` is namespaced, get its hierarchy and compare suffixes - # instead of the just the last element - suffix = reverse!(namespace_hierarchy(nameof(subsys))) - N = length(suffix) - for igsys in ignored_systems - if length(igsys) > N && igsys[(end - N + 1):end] == suffix - push!(result, copy(igsys)) - for i in 1:N - pop!(result[end]) - end - end - end - return result -end - -function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false) - ele2idx = Dict{ConnectionElement, Int}() - idx2ele = ConnectionElement[] - union_find = IntDisjointSets(0) - prev_id = Ref(-1) - for cset in csets, (j, s) in enumerate(cset.set) - v = allouter ? withtrueouter(s) : s - id = let ele2idx = ele2idx, idx2ele = idx2ele - get!(ele2idx, v) do - push!(idx2ele, v) - id = length(idx2ele) - id′ = push!(union_find) - @assert id == id′ - id - end - end - # isequal might not be equal? lol - if v.sys.namespace !== nothing - idx2ele[id] = v - end - if j > 1 - union!(union_find, prev_id[], id) - end - prev_id[] = id - end - id2set = Dict{Int, Int}() - merged_set = ConnectionSet[] - for (id, ele) in enumerate(idx2ele) - rid = find_root!(union_find, id) - set_idx = get!(id2set, rid) do - set = ConnectionSet() - push!(merged_set, set) - length(merged_set) - end - push!(merged_set[set_idx].set, ele) - end - merged_set -end - -function generate_connection_equations_and_stream_connections(csets::AbstractVector{ - <:ConnectionSet, -}) +function generate_connection_equations_and_stream_connections( + sys::AbstractSystem, csets::Vector{Vector{ConnectionVertex}}) eqs = Equation[] - stream_connections = ConnectionSet[] + stream_connections = Vector{ConnectionVertex}[] for cset in csets - v = cset.set[1].v - v = getparent(v, v) - vtype = get_connection_type(v) - if vtype === Stream + cvert = cset[1] + var = variable_from_vertex(sys, cvert)::BasicSymbolic + vtype = cvert.type + if vtype <: Union{InputVar, OutputVar} + inner_output = nothing + outer_input = nothing + for cvert in cset + if cvert.isouter && cvert.type <: InputVar + if outer_input !== nothing + error(""" + Found two outer input connectors `$outer_input` and `$cvert` in the + same connection set. + """) + end + outer_input = cvert + elseif !cvert.isouter && cvert.type <: OutputVar + if inner_output !== nothing + error(""" + Found two inner output connectors `$inner_output` and `$cvert` in + the same connection set. + """) + end + inner_output = cvert + end + end + root, rest = Iterators.peel(cset) + root_var = variable_from_vertex(sys, root) + for cvert in rest + var = variable_from_vertex(sys, cvert) + push!(eqs, root_var ~ var) + end + elseif vtype === Stream push!(stream_connections, cset) elseif vtype === Flow - rhs = 0 - for ele in cset.set - v = namespaced_var(ele) - rhs += ele.isouter ? -v : v + # arrays have to be broadcasted to be added/subtracted/negated which leads + # to bad-looking equations. Just generate scalar equations instead since + # mtkcompile will scalarize anyway. + representative = variable_from_vertex(sys, cset[1]) + # each variable can have different axes, but they all have the same size + for sz_i in eachindex(eachindex(wrap(representative))) + rhs = 0 + for cvert in cset + # all of this wrapping/unwrapping is necessary because the relevant + # methods are defined on `Arr/Num` and not `BasicSymbolic`. + v = variable_from_vertex(sys, cvert)::BasicSymbolic + idxs = eachindex(wrap(v)) + v = unwrap(wrap(v)[idxs[sz_i]]) + rhs += cvert.isouter ? unwrap(-wrap(v)) : v + end + push!(eqs, 0 ~ rhs) end - push!(eqs, 0 ~ rhs) else # Equality - base = namespaced_var(cset.set[1]) - for i in 2:length(cset.set) - v = namespaced_var(cset.set[i]) + base = variable_from_vertex(sys, cset[1]) + for i in 2:length(cset) + v = variable_from_vertex(sys, cset[i]) push!(eqs, base ~ v) end end @@ -791,295 +810,212 @@ function generate_connection_equations_and_stream_connections(csets::AbstractVec eqs, stream_connections end -function domain_defaults(sys, domain_csets) - def = Dict() - for c in domain_csets - cset = c.set - idx = findfirst(s -> is_domain_connector(s.sys.sys), cset) - idx === nothing && continue - s = cset[idx] - root = s.sys - s_def = defaults(root.sys) - for (j, m) in enumerate(cset) - if j == idx - continue - elseif is_domain_connector(m.sys.sys) - error("Domain sources $(nameof(root)) and $(nameof(m)) are connected!") - else - ns_s_def = Dict(unknowns(m.sys.sys, n) => n for (n, v) in s_def) - for p in parameters(m.sys.namespace) - d_p = get(ns_s_def, p, nothing) - if d_p !== nothing - def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace, - parameters(s.sys.sys, - d_p)) - end - end - end - end - end - def -end - """ $(TYPEDSIGNATURES) -Recursively descend through the hierarchy of `sys` and expand all connection equations -of causal variables. Return the modified system. +Generate the defaults for parameters in the domain sets given by `domain_csets`. """ -function expand_variable_connections(sys::AbstractSystem; ignored_variables = nothing) - if ignored_variables === nothing - ignored_variables = ignored_connections(sys)[2] - end - eqs = copy(get_eqs(sys)) - valid_idxs = trues(length(eqs)) - additional_eqs = Equation[] - - for (i, eq) in enumerate(eqs) - eq.lhs isa Connection || continue - connection = eq.rhs - elements = get_systems(connection) - is_causal_variable_connection(connection) || continue - - valid_idxs[i] = false - elements = map(x -> x.var, elements) - to_ignore = variables_to_ignore(ignored_variables, elements) - elements = setdiff(elements, to_ignore) - outvar = first(elements) - for invar in Iterators.drop(elements, 1) - push!(additional_eqs, outvar ~ invar) +function domain_defaults( + sys::AbstractSystem, domain_csets::Vector{Vector{ConnectionVertex}}) + defs = Dict() + for cset in domain_csets + systems = map(Base.Fix1(variable_from_vertex, sys), cset) + @assert all(x -> x isa AbstractSystem, systems) + idx = findfirst(is_domain_connector, systems) + idx === nothing && continue + domain_sys = systems[idx] + # note that these will not be namespaced with `domain_sys`. + domain_defs = defaults(domain_sys) + for (j, csys) in enumerate(systems) + j == idx && continue + if is_domain_connector(csys) + throw(ArgumentError(""" + Domain sources $(nameof(domain_sys)) and $(nameof(csys)) are connected! + """)) + end + for par in parameters(csys) + defval = get(domain_defs, par, nothing) + defval === nothing && continue + defs[parameters(csys, par)] = parameters(domain_sys, par) + end end end - eqs = [eqs[valid_idxs]; additional_eqs] - subsystems = map(get_systems(sys)) do subsys - expand_variable_connections(subsys; - ignored_variables = ignored_systems_for_subsystem(subsys, ignored_variables)) - end - @set! sys.eqs = eqs - @set! sys.systems = subsystems - return sys + return defs end """ - function expand_connections(sys::AbstractSystem) + $(TYPEDSIGNATURES) Given a hierarchical system with [`connect`](@ref) equations, expand the connection -equations and return the new system. +equations and return the new system. `tol` is the tolerance for handling the singularities +in stream connection equations that happen when a flow variable approaches zero. """ -function expand_connections(sys::AbstractSystem; - debug = false, tol = 1e-10, scalarize = true) +function expand_connections(sys::AbstractSystem; tol = 1e-10) + # turn analysis points into standard connection equations sys = remove_analysis_points(sys) - sys = expand_variable_connections(sys) - sys, (csets, domain_csets) = generate_connection_set(sys; scalarize) - ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) - _sys = expand_instream(instream_csets, sys; debug = debug, tol = tol) - sys = flatten(sys, true) - @set! sys.eqs = [equations(_sys); ceqs] + # generate the connection sets + sys, (csets, domain_csets) = generate_connection_set(sys) + # generate equations, and stream equations + ceqs, instream_csets = generate_connection_equations_and_stream_connections(sys, csets) + stream_eqs, instream_subs = expand_instream(instream_csets, sys; tol = tol) + + eqs = [equations(sys); ceqs; stream_eqs] + # substitute `instream(..)` expressions with their new values + for i in eachindex(eqs) + eqs[i] = fixpoint_sub(eqs[i], instream_subs; maxiters = length(instream_subs)) + end + # get the defaults for domain networks d_defs = domain_defaults(sys, domain_csets) + # build the new system + sys = flatten(sys, true) + @set! sys.eqs = eqs @set! sys.defaults = merge(get_defaults(sys), d_defs) end -function unnamespace(root, namespace) - root === nothing && return namespace - root = string(root) - namespace = string(namespace) - if length(namespace) > length(root) - @assert root == namespace[1:length(root)] - Symbol(namespace[nextind(namespace, length(root)):end]) - else - @assert root == namespace - nothing +""" + $(TYPEDSIGNATURES) + +Given a connection vertex `cvert` referring to a variable in a connector in `sys`, return +the flow variable in that connector. +""" +function get_flowvar(sys::AbstractSystem, cvert::ConnectionVertex) + parent_names = @view cvert.name[1:(end - 1)] + parent_sys = iterative_getproperty(sys, parent_names) + for var in unknowns(parent_sys) + type = get_connection_type(var) + type == Flow || continue + return unwrap(unknowns(parent_sys, var)) end + throw(ArgumentError("There is no flow variable in system `$(nameof(parent_sys))`")) end -function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSystem, - namespace = nothing, prevnamespace = nothing; debug = false, - tol = 1e-8) - subsys = get_systems(sys) - # post order traversal - @set! sys.systems = map( - s -> expand_instream(csets, s, - renamespace(namespace, nameof(s)), - namespace; debug, tol), - subsys) - subsys = get_systems(sys) +""" + $(TYPEDSIGNATURES) - if debug - @info "Expanding" namespace +Given connection sets of stream variables in `sys`, return the additional equations to add +to the system and the substitutions to make to handle `instream(..)` expressions. `tol` is +the tolerance for handling singularities in stream connection equations when the flow +variable approaches zero. +""" +function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractSystem; + tol = 1e-8) + eqs = equations(sys) + # collect all `instream` terms in the equations + instream_exprs = Set{BasicSymbolic}() + for eq in eqs + collect_instream!(instream_exprs, eq) end - sub = Dict() - eqs = Equation[] - instream_eqs = Equation[] - instream_exprs = Set() - for s in subsys - for eq in get_eqs(s) - eq = namespace_equation(eq, s) - if collect_instream!(instream_exprs, eq) - push!(instream_eqs, eq) - else - push!(eqs, eq) - end - end + # specifically substitute `instream(x[i]) => instream(x)[i]` + instream_subs = Dict{BasicSymbolic, BasicSymbolic}() + for expr in instream_exprs + stream_var = only(arguments(expr)) + iscall(stream_var) && operation(stream_var) === getindex || continue + args = arguments(stream_var) + new_expr = Symbolics.array_term( + instream, args[1]; size = size(args[1]), ndims = ndims(args[1]))[args[2:end]...] + instream_subs[expr] = new_expr end - if !isempty(instream_exprs) - # map from a namespaced stream variable to a ConnectionSet - expr_cset = Dict() - for cset in csets - crep = first(cset.set) - current = namespace == _getname(crep.sys.namespace) - for v in cset.set - if (current || !v.isouter) - expr_cset[namespaced_var(v)] = cset.set - end - end - end + # for all the newly added `instream(x)[i]`, add `instream(x)` to `instream_exprs` + # also remove all `instream(x[i])` + for (k, v) in instream_subs + push!(instream_exprs, arguments(v)[1]) + delete!(instream_exprs, k) end - for ex in instream_exprs - ns_sv = only(arguments(ex)) - full_name_sv = renamespace(namespace, ns_sv) - cset = get(expr_cset, full_name_sv, nothing) - cset === nothing && error("$ns_sv is not a variable inside stream connectors") - idx_in_set, sv = get_cset_sv(full_name_sv, cset) - - n_inners = n_outers = 0 - for (i, e) in enumerate(cset) - if e.isouter - n_outers += 1 - else - n_inners += 1 - end - end - if debug - @info "Expanding at [$idx_in_set]" ex ConnectionSet(cset) n_inners n_outers - end - if n_inners == 1 && n_outers == 0 - sub[ex] = sv - elseif n_inners == 2 && n_outers == 0 - other = idx_in_set == 1 ? 2 : 1 - sub[ex] = get_current_var(namespace, cset[other], sv) - elseif n_inners == 1 && n_outers == 1 - if !cset[idx_in_set].isouter - other = idx_in_set == 1 ? 2 : 1 - outerstream = get_current_var(namespace, cset[other], sv) - sub[ex] = instream(outerstream) + # This is an implementation of the modelica spec + # https://specification.modelica.org/maint/3.6/stream-connectors.html + additional_eqs = Equation[] + for cset in csets + n_outer = count(cvert -> cvert.isouter, cset) + n_inner = length(cset) - n_outer + if n_inner == 1 && n_outer == 0 + cvert = only(cset) + stream_var = variable_from_vertex(sys, cvert)::BasicSymbolic + instream_subs[instream(stream_var)] = stream_var + elseif n_inner == 2 && n_outer == 0 + cvert1, cvert2 = cset + stream_var1 = variable_from_vertex(sys, cvert1)::BasicSymbolic + stream_var2 = variable_from_vertex(sys, cvert2)::BasicSymbolic + instream_subs[instream(stream_var1)] = stream_var2 + instream_subs[instream(stream_var2)] = stream_var1 + elseif n_inner == 1 && n_outer == 1 + cvert_inner, cvert_outer = cset + if cvert_inner.isouter + cvert_inner, cvert_outer = cvert_outer, cvert_inner end + streamvar_inner = variable_from_vertex(sys, cvert_inner)::BasicSymbolic + streamvar_outer = variable_from_vertex(sys, cvert_outer)::BasicSymbolic + instream_subs[instream(streamvar_inner)] = instream(streamvar_outer) + push!(additional_eqs, (streamvar_outer ~ streamvar_inner)) + elseif n_inner == 0 && n_outer == 2 + cvert1, cvert2 = cset + stream_var1 = variable_from_vertex(sys, cvert1)::BasicSymbolic + stream_var2 = variable_from_vertex(sys, cvert2)::BasicSymbolic + push!(additional_eqs, (stream_var1 ~ instream(stream_var2)), + (stream_var2 ~ instream(stream_var1))) else - if !cset[idx_in_set].isouter - fv = flowvar(first(cset).sys.sys) - # mj.c.m_flow - innerfvs = [get_current_var(namespace, s, fv) - for (j, s) in enumerate(cset) if j != idx_in_set && !s.isouter] - innersvs = [get_current_var(namespace, s, sv) - for (j, s) in enumerate(cset) if j != idx_in_set && !s.isouter] - # ck.m_flow - outerfvs = [get_current_var(namespace, s, fv) for s in cset if s.isouter] - outersvs = [get_current_var(namespace, s, sv) for s in cset if s.isouter] - - sub[ex] = term(instream_rt, Val(length(innerfvs)), Val(length(outerfvs)), - innerfvs..., innersvs..., outerfvs..., outersvs...) + # Currently just implements the "else" case for `instream(..)` in the suggested + # implementation of stream connectors in the Modelica spec v3.6 section 15.2. + # https://specification.modelica.org/maint/3.6/stream-connectors.html#instream-and-connection-equations + # We could implement the "if" case using variable bounds? It would be nice to + # move that metadata to the system (storing it similar to `defaults`). + outer_cverts = filter(cvert -> cvert.isouter, cset) + inner_cverts = filter(cvert -> !cvert.isouter, cset) + + outer_streamvars = map(Base.Fix1(variable_from_vertex, sys), outer_cverts) + inner_streamvars = map(Base.Fix1(variable_from_vertex, sys), inner_cverts) + + outer_flowvars = map(Base.Fix1(get_flowvar, sys), outer_cverts) + inner_flowvars = map(Base.Fix1(get_flowvar, sys), inner_cverts) + + mask = trues(length(inner_cverts)) + for inner_i in eachindex(inner_cverts) + # mask out the current variable + mask[inner_i] = false + svar = inner_streamvars[inner_i] + instream_subs[instream(svar)] = term( + instream_rt, Val(n_inner - 1), Val(n_outer), inner_flowvars[mask]..., + inner_streamvars[mask]..., outer_flowvars..., outer_streamvars...) + # make sure to reset the mask + mask[inner_i] = true end - end - end - # additional equations - additional_eqs = Equation[] - csets = filter(cset -> any(e -> _getname(e.sys.namespace) === namespace, cset.set), - csets) - for cset′ in csets - cset = cset′.set - connectors = Vector{Any}(undef, length(cset)) - n_inners = n_outers = 0 - for (i, e) in enumerate(cset) - connectors[i] = e.sys.sys - if e.isouter - n_outers += 1 - else - n_inners += 1 - end - end - iszero(n_outers) && continue - connector_representative = first(cset).sys.sys - fv = flowvar(connector_representative) - sv = first(cset).v - vtype = get_connection_type(sv) - vtype === Stream || continue - if n_inners == 1 && n_outers == 1 - push!(additional_eqs, - unknowns(cset[1].sys.sys, sv) ~ unknowns(cset[2].sys.sys, sv)) - elseif n_inners == 0 && n_outers == 2 - # we don't expand `instream` in this case. - v1 = unknowns(cset[1].sys.sys, sv) - v2 = unknowns(cset[2].sys.sys, sv) - push!(additional_eqs, v1 ~ instream(v2)) - push!(additional_eqs, v2 ~ instream(v1)) - else - sq = 0 - s_inners = (s for s in cset if !s.isouter) - s_outers = (s for s in cset if s.isouter) - for (q, oscq) in enumerate(s_outers) - sq += sum(s -> max(-unknowns(s, fv), 0), s_inners, init = 0) - for (k, s) in enumerate(s_outers) - k == q && continue - f = unknowns(s.sys.sys, fv) - sq += max(f, 0) + for q in 1:n_outer + sq = mapreduce(+, inner_flowvars) do fvar + max(-fvar, 0) end + sq += mapreduce(+, enumerate(outer_flowvars)) do (outer_i, fvar) + outer_i == q && return 0 + max(fvar, 0) + end + # sanity check to make sure it isn't going to codegen a `mapreduce` + @assert operation(sq) == (+) - num = 0 - den = 0 - for s in s_inners - f = unknowns(s.sys.sys, fv) - tmp = positivemax(-f, sq; tol = tol) - den += tmp - num += tmp * unknowns(s.sys.sys, sv) + num = mapreduce(+, inner_flowvars, inner_streamvars) do fvar, svar + positivemax(-fvar, sq; tol) * svar end - for (k, s) in enumerate(s_outers) - k == q && continue - f = unknowns(s.sys.sys, fv) - tmp = positivemax(f, sq; tol = tol) - den += tmp - num += tmp * instream(unknowns(s.sys.sys, sv)) + num += mapreduce( + +, enumerate(outer_flowvars), outer_streamvars) do (outer_i, fvar), svar + outer_i == q && return 0 + positivemax(fvar, sq; tol) * instream(svar) end - push!(additional_eqs, unknowns(oscq.sys.sys, sv) ~ num / den) - end - end - end - - subed_eqs = substitute(instream_eqs, sub) - if debug && !(isempty(csets) && isempty(additional_eqs) && isempty(instream_eqs)) - println("======================================") - @info "Additional equations" csets - display(additional_eqs) - println("======================================") - println("Substitutions") - display(sub) - println("======================================") - println("Substituted equations") - foreach(i -> println(instream_eqs[i] => subed_eqs[i]), eachindex(subed_eqs)) - println("======================================") - end + @assert operation(num) == (+) - @set! sys.systems = [] - @set! sys.eqs = [get_eqs(sys); eqs; subed_eqs; additional_eqs] - sys -end - -function get_current_var(namespace, cele, sv) - unknowns( - renamespace(unnamespace(namespace, _getname(cele.sys.namespace)), - cele.sys.sys), - sv) -end + den = mapreduce(+, inner_flowvars) do fvar + positivemax(-fvar, sq; tol) + end + den += mapreduce(+, enumerate(outer_flowvars)) do (outer_i, fvar) + outer_i == q && return 0 + positivemax(fvar, sq; tol) + end -function get_cset_sv(full_name_sv, cset) - for (idx_in_set, v) in enumerate(cset) - if isequal(namespaced_var(v), full_name_sv) - return idx_in_set, v.v + push!(additional_eqs, (outer_streamvars[q] ~ num / den)) + end end end - error("$ns_sv is not a variable inside stream connectors") + return additional_eqs, instream_subs end # instream runtime diff --git a/src/systems/system.jl b/src/systems/system.jl index af721a6ff3..863b840b37 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -214,8 +214,7 @@ struct System <: AbstractSystem (ones between connector systems) and the second contains all such causal variable connections. """ - ignored_connections::Union{ - Nothing, Tuple{Vector{IgnoredAnalysisPoint}, Vector{IgnoredAnalysisPoint}}} + ignored_connections::Union{Nothing, Vector{Connection}} """ `SymbolicUtils.Code.Assignment`s to prepend to all code generated from this system. """ diff --git a/src/utils.jl b/src/utils.jl index 2b8ec4a7a0..48a4e57d71 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1050,3 +1050,11 @@ function flatten_equations(eqs::Vector{Equation}) end const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump} + +struct NotPossibleError <: Exception end + +function Base.showerror(io::IO, ::NotPossibleError) + print(io, """ + This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE. + """) +end diff --git a/test/causal_variables_connection.jl b/test/causal_variables_connection.jl index 222db540de..eb922879e1 100644 --- a/test/causal_variables_connection.jl +++ b/test/causal_variables_connection.jl @@ -25,13 +25,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D @test_throws ["same size"] connect(xarr, yarr) - @test_throws ["Expected", "x", "output = true", "metadata"] connect(x, y) - @test_throws ["Expected", "y", "output = true", "metadata"] connect(y, v) - - @test_throws ["Expected", "x", "input = true", "metadata"] connect(z, x) - @test_throws ["Expected", "x", "input = true", "metadata"] connect(z, y, x) - @test_throws ["Expected", "u", "input = true", "metadata"] connect(z, u) - @test_throws ["Expected", "u", "input = true", "metadata"] connect(z, y, u) + @test_throws ArgumentError connect(x, y) end @testset "Connection expansion" begin @@ -96,3 +90,34 @@ end @test matrices.D[] == 0 end end + +@testset "Outside input to inside input connection" begin + @mtkmodel Inner begin + @variables begin + x(t), [input = true] + y(t), [output = true] + end + @equations begin + y ~ x + end + end + @mtkmodel Outer begin + @variables begin + u(t), [input = true] + v(t), [output = true] + end + @components begin + inner = Inner() + end + @equations begin + connect(u, inner.x) + connect(inner.y, v) + end + end + @named sys = Outer() + ss = toggle_namespacing(sys, false) + eqs = equations(expand_connections(sys)) + @test issetequal(eqs, [ss.u ~ ss.inner.x + ss.inner.y ~ ss.inner.x + ss.inner.y ~ ss.v]) +end diff --git a/test/components.jl b/test/components.jl index e71ef3fa45..8e5747c750 100644 --- a/test/components.jl +++ b/test/components.jl @@ -91,7 +91,7 @@ end @named sys′ = System(eqs, t) @named sys_inner_outer = compose(sys′, [ground, shape, source, rc_comp]) @test_nowarn show(IOBuffer(), MIME"text/plain"(), sys_inner_outer) - expand_connections(sys_inner_outer, debug = true) + expand_connections(sys_inner_outer) sys_inner_outer = mtkcompile(sys_inner_outer) @test !isempty(ModelingToolkit.defaults(sys_inner_outer)) u0 = [rc_comp.capacitor.v => 0.0] diff --git a/test/stream_connectors.jl b/test/stream_connectors.jl index 6e49cd1fb3..dfca3306f4 100644 --- a/test/stream_connectors.jl +++ b/test/stream_connectors.jl @@ -126,7 +126,7 @@ end eqns = [connect(n1m1.port_a, pipe.port_a) connect(pipe.port_b, sink.port)] -@named sys = System(eqns, t) +@named sys = System(eqns, t; systems = [n1m1, pipe, sink]) eqns = [domain_connect(fluid, n1m1.port_a) connect(n1m1.port_a, pipe.port_a) @@ -141,7 +141,7 @@ ssort(eqs) = sort(eqs, by = string) 0 ~ source.port1.m_flow - port_a.m_flow source.port1.P ~ port_a.P source.port1.P ~ source.P - source.port1.h_outflow ~ port_a.h_outflow + port_a.h_outflow ~ source.port1.h_outflow source.port1.h_outflow ~ source.h]) @unpack port_a, port_b = pipe @test ssort(equations(expand_connections(pipe))) == @@ -151,11 +151,11 @@ ssort(eqs) = sort(eqs, by = string) port_a.P ~ port_b.P port_a.h_outflow ~ instream(port_b.h_outflow) port_b.h_outflow ~ instream(port_a.h_outflow)]) -@test ssort(equations(expand_connections(sys))) == - ssort([0 ~ n1m1.port_a.m_flow + pipe.port_a.m_flow - 0 ~ pipe.port_b.m_flow + sink.port.m_flow - n1m1.port_a.P ~ pipe.port_a.P - pipe.port_b.P ~ sink.port.P]) +@test equations(expand_connections(sys)) ⊇ + [0 ~ n1m1.port_a.m_flow + pipe.port_a.m_flow + 0 ~ pipe.port_b.m_flow + sink.port.m_flow + n1m1.port_a.P ~ pipe.port_a.P + pipe.port_b.P ~ sink.port.P] @test ssort(equations(expand_connections(n1m1Test))) == ssort([0 ~ -pipe.port_a.m_flow - pipe.port_b.m_flow 0 ~ n1m1.source.port1.m_flow - n1m1.port_a.m_flow @@ -165,7 +165,7 @@ ssort(eqs) = sort(eqs, by = string) n1m1.port_a.P ~ pipe.port_a.P n1m1.source.port1.P ~ n1m1.port_a.P n1m1.source.port1.P ~ n1m1.source.P - n1m1.source.port1.h_outflow ~ n1m1.port_a.h_outflow + n1m1.port_a.h_outflow ~ n1m1.source.port1.h_outflow n1m1.source.port1.h_outflow ~ n1m1.source.h pipe.port_a.P ~ pipe.port_b.P pipe.port_a.h_outflow ~ sink.port.h_outflow @@ -278,10 +278,8 @@ sys = expand_connections(compose(simple, [vp1, vp2, vp3])) @test ssort(equations(sys)) == ssort([0 .~ collect(vp1.i) 0 .~ collect(vp2.i) 0 .~ collect(vp3.i) - vp1.v[1] ~ vp2.v[1] - vp1.v[2] ~ vp2.v[2] - vp1.v[1] ~ vp3.v[1] - vp1.v[2] ~ vp3.v[2] + vp1.v ~ vp2.v + vp1.v ~ vp3.v 0 ~ -vp1.i[1] - vp2.i[1] - vp3.i[1] 0 ~ -vp1.i[2] - vp2.i[2] - vp3.i[2]])