Skip to content

Commit 1344c1d

Browse files
authored
Merge pull request #2233 from SciML/myb/new_domain
Add `domain_connect`
2 parents a1369f3 + e4b9235 commit 1344c1d

File tree

5 files changed

+226
-131
lines changed

5 files changed

+226
-131
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ export SteadyStateProblem, SteadyStateProblemExpr
191191
export JumpProblem, DiscreteProblem
192192
export NonlinearSystem, OptimizationSystem, ConstraintsSystem
193193
export alias_elimination, flatten
194-
export connect, @connector, Connection, Flow, Stream, instream
194+
export connect, domain_connect, @connector, Connection, Flow, Stream, instream
195195
export @component, @mtkmodel
196196
export isinput, isoutput, getbounds, hasbounds, isdisturbance, istunable, getdist, hasdist,
197197
tunable_parameters, isirreducible, getdescription, hasdescription, isbinaryvar,

src/systems/connectors.jl

Lines changed: 68 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
function domain_connect(sys1, sys2, syss...)
2+
syss = (sys1, sys2, syss...)
3+
length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!")
4+
Equation(Connection(:domain), Connection(syss)) # the RHS are connected systems
5+
end
6+
17
function get_connection_type(s)
28
s = unwrap(s)
39
if istree(s) && operation(s) === getindex
@@ -260,23 +266,24 @@ end
260266

261267
function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
262268
connectionsets = ConnectionSet[]
263-
sys = generate_connection_set!(connectionsets, sys, find, replace)
264-
domain_free_connectionsets = filter(connectionsets) do cset
265-
!any(s -> is_domain_connector(s.sys.sys), cset.set)
266-
end
269+
domain_csets = ConnectionSet[]
270+
sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace)
271+
csets = merge(connectionsets)
272+
domain_csets = merge([csets; domain_csets], true)
267273

268-
sys, (merge(domain_free_connectionsets), connectionsets)
274+
sys, (csets, domain_csets)
269275
end
270276

271-
function generate_connection_set!(connectionsets, sys::AbstractSystem, find, replace,
272-
namespace = nothing)
277+
function generate_connection_set!(connectionsets, domain_csets,
278+
sys::AbstractSystem, find, replace, namespace = nothing)
273279
subsys = get_systems(sys)
274280

275281
isouter = generate_isouter(sys)
276282
eqs′ = get_eqs(sys)
277283
eqs = Equation[]
278284

279285
cts = [] # connections
286+
domain_cts = [] # connections
280287
extra_states = []
281288
for eq in eqs′
282289
lhs = eq.lhs
@@ -292,8 +299,14 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
292299
else
293300
if lhs isa Number || lhs isa Symbolic
294301
push!(eqs, eq) # split connections and equations
302+
elseif lhs isa Connection
303+
if get_systems(lhs) === :domain
304+
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
305+
else
306+
push!(cts, get_systems(rhs))
307+
end
295308
else
296-
push!(cts, get_systems(rhs))
309+
error("$eq is not a legal equation!")
297310
end
298311
end
299312
end
@@ -302,6 +315,7 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
302315
T = ConnectionElement
303316
for s in subsys
304317
isconnector(s) || continue
318+
is_domain_connector(s) && continue
305319
for v in states(s)
306320
Flow === get_connection_type(v) || continue
307321
push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)]))
@@ -316,24 +330,42 @@ function generate_connection_set!(connectionsets, sys::AbstractSystem, find, rep
316330
if !isempty(extra_states)
317331
@set! sys.states = [get_states(sys); extra_states]
318332
end
319-
@set! sys.systems = map(s -> generate_connection_set!(connectionsets, s, find, replace,
333+
@set! sys.systems = map(s -> generate_connection_set!(connectionsets, domain_csets, s,
334+
find, replace,
320335
renamespace(namespace, s)),
321336
subsys)
322337
@set! sys.eqs = eqs
323338
end
324339

325-
function Base.merge(csets::AbstractVector{<:ConnectionSet})
340+
function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
341+
csets, merged = partial_merge(csets, allouter)
342+
while merged
343+
csets, merged = partial_merge(csets)
344+
end
345+
csets
346+
end
347+
348+
function partial_merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
326349
mcsets = ConnectionSet[]
327350
ele2idx = Dict{ConnectionElement, Int}()
328351
cacheset = Set{ConnectionElement}()
329-
for cset in csets
352+
merged = false
353+
for (j, cset) in enumerate(csets)
354+
if allouter
355+
cset = ConnectionSet(map(cset.set) do e
356+
@set! e.isouter = true
357+
end)
358+
end
330359
idx = nothing
331360
for e in cset.set
332361
idx = get(ele2idx, e, nothing)
333-
idx !== nothing && break
362+
if idx !== nothing
363+
merged = true
364+
break
365+
end
334366
end
335367
if idx === nothing
336-
push!(mcsets, cset)
368+
push!(mcsets, copy(cset))
337369
for e in cset.set
338370
ele2idx[e] = length(mcsets)
339371
end
@@ -352,79 +384,7 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet})
352384
empty!(cacheset)
353385
end
354386
end
355-
mcsets
356-
end
357-
358-
struct SystemDomainGraph{T, C <: AbstractVector{<:ConnectionSet}} <:
359-
Graphs.AbstractGraph{Int}
360-
ts::T
361-
lineqs::BitSet
362-
var2idx::Dict{Any, Int}
363-
id2cset::Vector{NTuple{2, Int}}
364-
cset2id::Vector{Vector{Int}}
365-
csets::C
366-
sys2id::Dict{Symbol, Int}
367-
outne::Vector{Union{Nothing, Vector{Int}}}
368-
end
369-
370-
Graphs.nv(g::SystemDomainGraph) = length(g.id2cset)
371-
function Graphs.outneighbors(g::SystemDomainGraph, n::Int)
372-
i, j = g.id2cset[n]
373-
ids = copy(g.cset2id[i])
374-
@unpack ts, lineqs, var2idx = g
375-
@unpack fullvars, structure = ts
376-
@unpack graph = structure
377-
visited = BitSet(n)
378-
for s in g.csets[i].set
379-
s.sys.namespace === nothing && continue
380-
sys = s.sys.sys
381-
is_domain_connector(sys) && continue
382-
vidx = get(var2idx, states(s.sys.namespace, states(sys, s.v)), 0)
383-
iszero(vidx) && continue
384-
ies = 𝑑neighbors(graph, vidx)
385-
for ie in ies
386-
ie in lineqs || continue
387-
for iv in 𝑠neighbors(graph, ie)
388-
iv == vidx && continue
389-
fv = ts.fullvars[iv]
390-
vtype = get_connection_type(fv)
391-
vtype === Flow || continue
392-
n′ = get(g.sys2id, getname(fv), nothing)
393-
n′ === nothing && continue
394-
n′ in visited && continue
395-
push!(visited, n′)
396-
append!(ids, g.cset2id[g.id2cset[n′][1]])
397-
end
398-
end
399-
end
400-
ids
401-
end
402-
function rooted_system_domain_graph!(ts, csets::AbstractVector{<:ConnectionSet})
403-
id2cset = NTuple{2, Int}[]
404-
cset2id = Vector{Int}[]
405-
sys2id = Dict{Symbol, Int}()
406-
roots = BitSet()
407-
for (i, c) in enumerate(csets)
408-
cset2id′ = Int[]
409-
for (j, s) in enumerate(c.set)
410-
ij = (i, j)
411-
push!(id2cset, ij)
412-
if !haskey(sys2id, nameof(s))
413-
n = length(id2cset)
414-
sys2id[nameof(s)] = n
415-
else
416-
n = sys2id[nameof(s)]
417-
end
418-
push!(cset2id′, n)
419-
is_domain_connector(s.sys.sys) && push!(roots, n)
420-
end
421-
push!(cset2id, cset2id′)
422-
end
423-
outne = Vector{Union{Nothing, Vector{Int}}}(undef, length(id2cset))
424-
mm = linear_subsys_adjmat!(ts)
425-
lineqs = BitSet(mm.nzrows)
426-
var2idx = Dict{Any, Int}(reverse(en) for en in enumerate(ts.fullvars))
427-
SystemDomainGraph(ts, lineqs, var2idx, id2cset, cset2id, csets, sys2id, outne), roots
387+
mcsets, merged
428388
end
429389

430390
function generate_connection_equations_and_stream_connections(csets::AbstractVector{
@@ -458,48 +418,28 @@ function generate_connection_equations_and_stream_connections(csets::AbstractVec
458418
end
459419

460420
function domain_defaults(sys, domain_csets)
461-
csets = merge(domain_csets)
462-
g, roots = rooted_system_domain_graph!(TearingState(sys), csets)
463-
# a simple way to make `_g` bidirectional
464-
simple_g = SimpleGraph(nv(g))
465-
for v in 1:nv(g), n in neighbors(g, v)
466-
add_edge!(simple_g, v => n)
467-
end
468-
domain_csets = []
469-
root_ijs = Set(g.id2cset[r] for r in roots)
470-
for r in roots
471-
nh = neighborhood(simple_g, r, Inf)
472-
sources_idxs = intersect(nh, roots)
473-
# TODO: error reporting when length(sources_idxs) > 1
474-
length(sources_idxs) > 1 && error()
475-
i′, j′ = g.id2cset[r]
476-
source = csets[i′].set[j′]
477-
domain = source => []
478-
push!(domain_csets, domain)
479-
# get unique cset indices that `r` is (implicitly) connected to.
480-
idxs = BitSet(g.id2cset[i][1] for i in nh)
481-
for i in idxs
482-
for (j, ele) in enumerate(csets[i].set)
483-
(i, j) == (i′, j′) && continue
484-
if (i, j) in root_ijs
485-
error("Domain source $(nameof(source)) and $(nameof(ele)) are connected!")
486-
end
487-
push!(domain[2], ele)
488-
end
489-
end
490-
end
491-
492421
def = Dict()
493-
for (s, mods) in domain_csets
494-
s_def = defaults(s.sys.sys)
495-
for m in mods
496-
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
497-
for p in parameters(m.sys.namespace)
498-
d_p = get(ns_s_def, p, nothing)
499-
if d_p !== nothing
500-
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
501-
parameters(s.sys.sys,
502-
d_p))
422+
for c in domain_csets
423+
cset = c.set
424+
idx = findfirst(s -> is_domain_connector(s.sys.sys), cset)
425+
idx === nothing && continue
426+
s = cset[idx]
427+
root = s.sys
428+
s_def = defaults(root.sys)
429+
for (j, m) in enumerate(cset)
430+
if j == idx
431+
continue
432+
elseif is_domain_connector(m.sys.sys)
433+
error("Domain sources $(nameof(root)) and $(nameof(m)) are connected!")
434+
else
435+
ns_s_def = Dict(states(m.sys.sys, n) => n for (n, v) in s_def)
436+
for p in parameters(m.sys.namespace)
437+
d_p = get(ns_s_def, p, nothing)
438+
if d_p !== nothing
439+
def[parameters(m.sys.namespace, p)] = parameters(s.sys.namespace,
440+
parameters(s.sys.sys,
441+
d_p))
442+
end
503443
end
504444
end
505445
end
@@ -656,7 +596,7 @@ function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSy
656596
s_inners = (s for s in cset if !s.isouter)
657597
s_outers = (s for s in cset if s.isouter)
658598
for (q, oscq) in enumerate(s_outers)
659-
sq += sum(s -> max(-states(s, fv), 0), s_inners)
599+
sq += sum(s -> max(-states(s, fv), 0), s_inners, init = 0)
660600
for (k, s) in enumerate(s_outers)
661601
k == q && continue
662602
f = states(s.sys.sys, fv)

0 commit comments

Comments
 (0)