Skip to content

Commit b7e4cef

Browse files
committed
More performance
1 parent 9c166cc commit b7e4cef

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

src/systems/connectors.jl

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,32 @@ end
283283

284284
function Base.merge(csets::AbstractVector{<:ConnectionSet})
285285
mcsets = ConnectionSet[]
286-
# FIXME: this is O(m n^3)
286+
ele2idx = Dict{ConnectionElement,Int}()
287+
cacheset = Set{ConnectionElement}()
287288
for cset in csets
288-
idx = findfirst(mcset->any(s->any(z->z == s, cset.set), mcset.set), mcsets)
289+
idx = nothing
290+
for e in cset.set
291+
idx = get(ele2idx, e, nothing)
292+
idx !== nothing && break
293+
end
289294
if idx === nothing
290295
push!(mcsets, cset)
296+
for e in cset.set
297+
ele2idx[e] = length(mcsets)
298+
end
291299
else
292-
union!(mcsets[idx].set, cset.set)
300+
for e in mcsets[idx].set
301+
push!(cacheset, e)
302+
end
303+
for e in cset.set
304+
push!(cacheset, e)
305+
end
306+
empty!(mcsets[idx].set)
307+
for e in cacheset
308+
ele2idx[e] = idx
309+
push!(mcsets[idx].set, e)
310+
end
311+
empty!(cacheset)
293312
end
294313
end
295314
mcsets
@@ -330,8 +349,8 @@ function expand_connections(sys::AbstractSystem; debug=false, tol=1e-10)
330349
sys, csets = generate_connection_set(sys)
331350
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
332351
additional_eqs = Equation[]
333-
_sys = expand_instream2(instream_csets, sys; debug=debug, tol=tol)
334-
sys = flatten(sys)
352+
_sys = expand_instream(instream_csets, sys; debug=debug, tol=tol)
353+
sys = flatten(sys, true)
335354
@set! sys.eqs = [equations(_sys); ceqs; additional_eqs]
336355
end
337356

@@ -348,12 +367,10 @@ function unnamespace(root, namespace)
348367
end
349368
end
350369

351-
function expand_instream2(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSystem, namespace=nothing, prevnamespace=nothing; debug=false, tol=1e-8)
370+
function expand_instream(csets::AbstractVector{<:ConnectionSet}, sys::AbstractSystem, namespace=nothing, prevnamespace=nothing; debug=false, tol=1e-8)
352371
subsys = get_systems(sys)
353-
# no connectors if there are no subsystems
354-
#isempty(subsys) && return sys
355372
# post order traversal
356-
@set! sys.systems = map(s->expand_instream2(csets, s, renamespace(namespace, nameof(s)), namespace; debug, tol), subsys)
373+
@set! sys.systems = map(s->expand_instream(csets, s, renamespace(namespace, nameof(s)), namespace; debug, tol), subsys)
357374
subsys = get_systems(sys)
358375

359376
if debug
@@ -365,8 +382,8 @@ function expand_instream2(csets::AbstractVector{<:ConnectionSet}, sys::AbstractS
365382
instream_eqs = Equation[]
366383
instream_exprs = Set()
367384
for s in subsys
368-
seqs = map(Base.Fix2(namespace_equation, s), get_eqs(s))
369-
for eq in seqs
385+
for eq in get_eqs(s)
386+
eq = namespace_equation(eq, s)
370387
if collect_instream!(instream_exprs, eq)
371388
push!(instream_eqs, eq)
372389
else

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,13 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
229229
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
230230
end
231231

232-
function flatten(sys::ODESystem)
232+
function flatten(sys::ODESystem, noeqs=false)
233233
systems = get_systems(sys)
234234
if isempty(systems)
235235
return sys
236236
else
237237
return ODESystem(
238-
equations(sys),
238+
noeqs ? Equation[] : equations(sys),
239239
get_iv(sys),
240240
states(sys),
241241
parameters(sys),

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,13 @@ function NonlinearProblemExpr{iip}(sys::NonlinearSystem,u0map,
335335
!linenumbers ? striplines(ex) : ex
336336
end
337337

338-
function flatten(sys::NonlinearSystem)
338+
function flatten(sys::NonlinearSystem, noeqs=false)
339339
systems = get_systems(sys)
340340
if isempty(systems)
341341
return sys
342342
else
343343
return NonlinearSystem(
344-
equations(sys),
344+
noeqs ? Equation[] : equations(sys),
345345
states(sys),
346346
parameters(sys),
347347
observed=observed(sys),

0 commit comments

Comments
 (0)