Skip to content

Commit 73eddc6

Browse files
committed
Add domain_connect
1 parent a69570a commit 73eddc6

File tree

2 files changed

+36
-121
lines changed

2 files changed

+36
-121
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ using Reexport
5555
using Symbolics: degree
5656
@reexport using Symbolics
5757
export @derivatives
58+
export domain_connect
5859
using Symbolics: _parse_vars, value, @derivatives, get_variables,
5960
exprs_occur_in, solve_for, build_expr, unwrap, wrap,
6061
VariableSource, getname, variable, Connection, connect,

src/systems/connectors.jl

Lines changed: 35 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
function domain_connect(sys1, sys2, syss...)
2+
syss = (sys1, sys2, syss...)
3+
length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!")
4+
Equation(Connection(:domain), Connection(syss)) # the RHS are connected systems
5+
end
6+
17
function get_connection_type(s)
28
s = unwrap(s)
39
if istree(s) && operation(s) === getindex
@@ -260,23 +266,22 @@ end
260266

261267
function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
262268
connectionsets = ConnectionSet[]
263-
sys = generate_connection_set!(connectionsets, sys, find, replace)
264-
domain_free_connectionsets = filter(connectionsets) do cset
265-
!any(s -> is_domain_connector(s.sys.sys), cset.set)
266-
end
269+
domain_csets = ConnectionSet[]
270+
sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace)
267271

268-
sys, (merge(domain_free_connectionsets), connectionsets)
272+
sys, (merge(connectionsets), merge([connectionsets; domain_csets]))
269273
end
270274

271-
function generate_connection_set!(connectionsets, sys::AbstractSystem, find, replace,
272-
namespace = nothing)
275+
function generate_connection_set!(connectionsets, domain_csets,
276+
sys::AbstractSystem, find, replace, namespace = nothing)
273277
subsys = get_systems(sys)
274278

275279
isouter = generate_isouter(sys)
276280
eqs′ = get_eqs(sys)
277281
eqs = Equation[]
278282

279283
cts = [] # connections
284+
domain_cts = [] # connections
280285
extra_states = []
281286
for eq in eqs′
282287
lhs = eq.lhs
@@ -292,8 +297,14 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
292297
else
293298
if lhs isa Number || lhs isa Symbolic
294299
push!(eqs, eq) # split connections and equations
300+
elseif lhs isa Connect
301+
if get_systems(lhs) === :domain
302+
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
303+
else
304+
push!(cts, get_systems(rhs))
305+
end
295306
else
296-
push!(cts, get_systems(rhs))
307+
error("$eq is not a legal equation!")
297308
end
298309
end
299310
end
@@ -355,78 +366,6 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet})
355366
mcsets
356367
end
357368

358-
struct SystemDomainGraph{T, C <: AbstractVector{<:ConnectionSet}} <:
359-
Graphs.AbstractGraph{Int}
360-
ts::T
361-
lineqs::BitSet
362-
var2idx::Dict{Any, Int}
363-
id2cset::Vector{NTuple{2, Int}}
364-
cset2id::Vector{Vector{Int}}
365-
csets::C
366-
sys2id::Dict{Symbol, Int}
367-
outne::Vector{Union{Nothing, Vector{Int}}}
368-
end
369-
370-
Graphs.nv(g::SystemDomainGraph) = length(g.id2cset)
371-
function Graphs.outneighbors(g::SystemDomainGraph, n::Int)
372-
i, j = g.id2cset[n]
373-
ids = copy(g.cset2id[i])
374-
@unpack ts, lineqs, var2idx = g
375-
@unpack fullvars, structure = ts
376-
@unpack graph = structure
377-
visited = BitSet(n)
378-
for s in g.csets[i].set
379-
s.sys.namespace === nothing && continue
380-
sys = s.sys.sys
381-
is_domain_connector(sys) && continue
382-
vidx = get(var2idx, states(s.sys.namespace, states(sys, s.v)), 0)
383-
iszero(vidx) && continue
384-
ies = 𝑑neighbors(graph, vidx)
385-
for ie in ies
386-
ie in lineqs || continue
387-
for iv in 𝑠neighbors(graph, ie)
388-
iv == vidx && continue
389-
fv = ts.fullvars[iv]
390-
vtype = get_connection_type(fv)
391-
vtype === Flow || continue
392-
n′ = get(g.sys2id, getname(fv), nothing)
393-
n′ === nothing && continue
394-
n′ in visited && continue
395-
push!(visited, n′)
396-
append!(ids, g.cset2id[g.id2cset[n′][1]])
397-
end
398-
end
399-
end
400-
ids
401-
end
402-
function rooted_system_domain_graph!(ts, csets::AbstractVector{<:ConnectionSet})
403-
id2cset = NTuple{2, Int}[]
404-
cset2id = Vector{Int}[]
405-
sys2id = Dict{Symbol, Int}()
406-
roots = BitSet()
407-
for (i, c) in enumerate(csets)
408-
cset2id′ = Int[]
409-
for (j, s) in enumerate(c.set)
410-
ij = (i, j)
411-
push!(id2cset, ij)
412-
if !haskey(sys2id, nameof(s))
413-
n = length(id2cset)
414-
sys2id[nameof(s)] = n
415-
else
416-
n = sys2id[nameof(s)]
417-
end
418-
push!(cset2id′, n)
419-
is_domain_connector(s.sys.sys) && push!(roots, n)
420-
end
421-
push!(cset2id, cset2id′)
422-
end
423-
outne = Vector{Union{Nothing, Vector{Int}}}(undef, length(id2cset))
424-
mm = linear_subsys_adjmat!(ts)
425-
lineqs = BitSet(mm.nzrows)
426-
var2idx = Dict{Any, Int}(reverse(en) for en in enumerate(ts.fullvars))
427-
SystemDomainGraph(ts, lineqs, var2idx, id2cset, cset2id, csets, sys2id, outne), roots
428-
end
429-
430369
function generate_connection_equations_and_stream_connections(csets::AbstractVector{
431370
<:ConnectionSet,
432371
})
@@ -458,48 +397,23 @@ function generate_connection_equations_and_stream_connections(csets::AbstractVec
458397
end
459398

460399
function domain_defaults(sys, domain_csets)
461-
csets = merge(domain_csets)
462-
g, roots = rooted_system_domain_graph!(TearingState(sys), csets)
463-
# a simple way to make `_g` bidirectional
464-
simple_g = SimpleGraph(nv(g))
465-
for v in 1:nv(g), n in neighbors(g, v)
466-
add_edge!(simple_g, v => n)
467-
end
468-
domain_csets = []
469-
root_ijs = Set(g.id2cset[r] for r in roots)
470-
for r in roots
471-
nh = neighborhood(simple_g, r, Inf)
472-
sources_idxs = intersect(nh, roots)
473-
# TODO: error reporting when length(sources_idxs) > 1
474-
length(sources_idxs) > 1 && error()
475-
i′, j′ = g.id2cset[r]
476-
source = csets[i′].set[j′]
477-
domain = source => []
478-
push!(domain_csets, domain)
479-
# get unique cset indices that `r` is (implicitly) connected to.
480-
idxs = BitSet(g.id2cset[i][1] for i in nh)
481-
for i in idxs
482-
for (j, ele) in enumerate(csets[i].set)
483-
(i, j) == (i′, j′) && continue
484-
if (i, j) in root_ijs
485-
error("Domain source $(nameof(source)) and $(nameof(ele)) are connected!")
486-
end
487-
push!(domain[2], ele)
488-
end
489-
end
490-
end
491-
492400
def = Dict()
493-
for (s, mods) in domain_csets
494-
s_def = defaults(s.sys.sys)
495-
for m in mods
496-
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
497-
for p in parameters(m.sys.namespace)
498-
d_p = get(ns_s_def, p, nothing)
499-
if d_p !== nothing
500-
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
501-
parameters(s.sys.sys,
502-
d_p))
401+
for c in domain_csets
402+
cset = c.set
403+
idx = findfirst(s -> is_domain_connector(s.sys.sys), cset)
404+
s = cset[idx]
405+
for (j, m) in enumerate(cset)
406+
if j == idx
407+
error("Domain sources $(nameof(root)) and $(nameof(m)) are connected!")
408+
else
409+
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
410+
for p in parameters(m.sys.namespace)
411+
d_p = get(ns_s_def, p, nothing)
412+
if d_p !== nothing
413+
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
414+
parameters(s.sys.sys,
415+
d_p))
416+
end
503417
end
504418
end
505419
end

0 commit comments

Comments
 (0)