Skip to content

Commit 26668fd

Browse files
fix: handle causal variable AP ignoring part of normal connect
1 parent 5507939 commit 26668fd

File tree

2 files changed

+107
-27
lines changed

2 files changed

+107
-27
lines changed

src/systems/connectors.jl

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,34 @@ function ori(sys)
315315
end
316316
end
317317

318-
function connection2set!(connectionsets, namespace, ss, isouter)
318+
"""
319+
$(TYPEDSIGNATURES)
320+
321+
Populate `connectionsets` with connections between the connectors `ss`, all of which are
322+
namespaced by `namespace`.
323+
324+
# Keyword Arguments
325+
- `ignored_connects`: A tuple of the systems and variables for which connections should be
326+
ignored. Of the format returned from `as_hierarchy`.
327+
- `namespaced_ignored_systems`: The `from_hierarchy` versions of entries in
328+
`ignored_connects[1]`, purely to avoid unnecessary recomputation.
329+
"""
330+
function connection2set!(connectionsets, namespace, ss, isouter;
331+
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]),
332+
namespaced_ignored_systems = ODESystem[])
333+
ignored_systems, ignored_variables = ignored_connects
334+
# ignore specified systems
335+
ss = filter(ss) do s
336+
all(namespaced_ignored_systems) do igsys
337+
nameof(igsys) != nameof(s)
338+
end
339+
end
340+
# `ignored_variables` for each `s` in `ss`
341+
corresponding_ignored_variables = map(
342+
Base.Fix2(ignored_systems_for_subsystem, ignored_variables), ss)
343+
corresponding_namespaced_ignored_variables = map(
344+
Broadcast.BroadcastFunction(from_hierarchy), corresponding_ignored_variables)
345+
319346
regular_ss = []
320347
domain_ss = nothing
321348
for s in ss
@@ -340,9 +367,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
340367
for (i, s) in enumerate(ss)
341368
sts = unknowns(s)
342369
io = isouter(s)
343-
for (j, v) in enumerate(sts)
370+
_ignored_variables = corresponding_ignored_variables[i]
371+
_namespaced_ignored_variables = corresponding_namespaced_ignored_variables[i]
372+
for v in sts
344373
vtype = get_connection_type(v)
345374
(vtype === Flow && isequal(v, dv)) || continue
375+
any(isequal(v), _namespaced_ignored_variables) && continue
346376
push!(cset, T(LazyNamespace(namespace, domain_ss), dv, false))
347377
push!(cset, T(LazyNamespace(namespace, s), v, io))
348378
end
@@ -360,6 +390,12 @@ function connection2set!(connectionsets, namespace, ss, isouter)
360390
end
361391
sts1 = Set(sts1v)
362392
num_unknowns = length(sts1)
393+
394+
# we don't filter here because `csets` should include the full set of unknowns.
395+
# not all of `ss` will have the same (or any) variables filtered so the ones
396+
# that aren't should still go in the right cset. Since `sts1` is only used for
397+
# validating that all systems being connected are of the same type, it has
398+
# unfiltered entries.
363399
csets = [T[] for _ in 1:num_unknowns] # Add 9 orientation variables if connection is between multibody frames
364400
for (i, s) in enumerate(ss)
365401
unknown_vars = unknowns(s)
@@ -372,7 +408,10 @@ function connection2set!(connectionsets, namespace, ss, isouter)
372408
all(Base.Fix2(in, sts1), unknown_vars)) ||
373409
connection_error(ss))
374410
io = isouter(s)
411+
# don't `filter!` here so that `j` points to the correct cset regardless of
412+
# which variables are filtered.
375413
for (j, v) in enumerate(unknown_vars)
414+
any(isequal(v), corresponding_namespaced_ignored_variables[i]) && continue
376415
push!(csets[j], T(LazyNamespace(namespace, s), v, io))
377416
end
378417
end
@@ -397,7 +436,7 @@ function generate_connection_set(
397436
sys = generate_connection_set!(
398437
connectionsets, domain_csets, sys, find, replace, scalarize, nothing,
399438
# include systems to be ignored
400-
ignored_connections(sys)[1])
439+
ignored_connections(sys))
401440
csets = merge(connectionsets)
402441
domain_csets = merge([csets; domain_csets], true)
403442

@@ -417,18 +456,23 @@ Generate connection sets from `connect` equations.
417456
- `sys` is the system whose equations are to be searched.
418457
- `namespace` is a system representing the namespace in which `sys` exists, or `nothing`
419458
for no namespace (if `sys` is top-level).
420-
- `ignored_systems` is a list of systems (in the format returned by `as_hierarchy`) to
421-
be ignored when generating connections. This is typically because the connections
422-
they are used in were removed by analysis point transformations.
459+
- `ignored_connects` is a tuple. The first (second) element is a list of systems
460+
(variables) in the format returned by `as_hierarchy` to be ignored when generating
461+
connections. This is typically because the connections they are used in were removed by
462+
analysis point transformations.
423463
"""
424464
function generate_connection_set!(connectionsets, domain_csets,
425-
sys::AbstractSystem, find, replace, scalarize, namespace = nothing, ignored_systems = [])
465+
sys::AbstractSystem, find, replace, scalarize, namespace = nothing,
466+
ignored_connects = (HierarchySystemT[], HierarchyVariableT[]))
426467
subsys = get_systems(sys)
468+
ignored_systems, ignored_variables = ignored_connects
427469
# turn hierarchies into namespaced systems
428-
namespaced_ignored = from_hierarchy.(ignored_systems)
470+
namespaced_ignored_systems = from_hierarchy.(ignored_systems)
471+
namespaced_ignored_variables = from_hierarchy.(ignored_variables)
472+
namespaced_ignored = (namespaced_ignored_systems, namespaced_ignored_variables)
429473
# filter the subsystems of `sys` to exclude ignored ones
430474
filtered_subsys = filter(subsys) do ss
431-
all(namespaced_ignored) do igsys
475+
all(namespaced_ignored_systems) do igsys
432476
nameof(igsys) != nameof(ss)
433477
end
434478
end
@@ -457,21 +501,10 @@ function generate_connection_set!(connectionsets, domain_csets,
457501
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
458502
else
459503
if lhs isa Connection && get_systems(lhs) === :domain
460-
# don't consider systems that should be ignored
461-
systems_to_connect = filter(get_systems(rhs)) do ss
462-
all(namespaced_ignored) do igsys
463-
nameof(igsys) != nameof(ss)
464-
end
465-
end
466-
connection2set!(domain_csets, namespace, systems_to_connect, isouter)
504+
connection2set!(domain_csets, namespace, get_systems(rhs), isouter;
505+
ignored_connects, namespaced_ignored_systems)
467506
elseif isconnection(rhs)
468-
# ignore required systems
469-
systems_to_connect = filter(get_systems(rhs)) do ss
470-
all(namespaced_ignored) do igsys
471-
nameof(igsys) != nameof(ss)
472-
end
473-
end
474-
push!(cts, systems_to_connect)
507+
push!(cts, get_systems(rhs))
475508
else
476509
# split connections and equations
477510
if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray
@@ -489,14 +522,19 @@ function generate_connection_set!(connectionsets, domain_csets,
489522
for s in filtered_subsys
490523
isconnector(s) || continue
491524
is_domain_connector(s) && continue
525+
_ignored_variables = ignored_systems_for_subsystem(s, ignored_variables)
526+
_namespaced_ignored_variables = from_hierarchy.(_ignored_variables)
492527
for v in unknowns(s)
493528
Flow === get_connection_type(v) || continue
529+
# ignore specified variables
530+
any(isequal(v), _namespaced_ignored_variables) && continue
494531
push!(connectionsets, ConnectionSet([T(LazyNamespace(namespace, s), v, false)]))
495532
end
496533
end
497534

498535
for ct in cts
499-
connection2set!(connectionsets, namespace, ct, isouter)
536+
connection2set!(connectionsets, namespace, ct, isouter;
537+
ignored_connects, namespaced_ignored_systems)
500538
end
501539

502540
# pre order traversal
@@ -506,7 +544,7 @@ function generate_connection_set!(connectionsets, domain_csets,
506544
@set! sys.systems = map(
507545
s -> generate_connection_set!(connectionsets, domain_csets, s,
508546
find, replace, scalarize, renamespace(namespace, s),
509-
ignored_systems_for_subsystem(s, ignored_systems)),
547+
ignored_systems_for_subsystem.((s,), ignored_connects)),
510548
subsys)
511549
@set! sys.eqs = eqs
512550
end
@@ -522,10 +560,16 @@ their hierarchy to not include `subsys`.
522560
function ignored_systems_for_subsystem(
523561
subsys::AbstractSystem, ignored_systems::Vector{<:HierarchyT})
524562
result = eltype(ignored_systems)[]
563+
# in case `subsys` is namespaced, get its hierarchy and compare suffixes
564+
# instead of the just the last element
565+
suffix = reverse!(namespace_hierarchy(nameof(subsys)))
566+
N = length(suffix)
525567
for igsys in ignored_systems
526-
if igsys[end] == nameof(subsys)
568+
if igsys[(end - N + 1):end] == suffix
527569
push!(result, copy(igsys))
528-
pop!(result[end])
570+
for i in 1:N
571+
pop!(result[end])
572+
end
529573
end
530574
end
531575
return result

test/downstream/analysis_points.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,42 @@ end
152152
@test lsys == lsyso || lsys == -1 * lsyso * (-1) # Output and input sensitivities are equal for SISO systems
153153
end
154154

155+
@testset "Duplicate `connect` statements across subsystems with AP transforms - mixed `connect`" begin
156+
@named P = FirstOrder(k = 1, T = 1)
157+
@named C = Gain(; k = 1)
158+
@named add = Blocks.Add(k2 = -1)
159+
160+
eqs = [connect(P.output.u, :plant_output, add.input2.u)
161+
connect(add.output, C.input)
162+
connect(C.output, P.input)]
163+
164+
sys_inner = ODESystem(eqs, t, systems = [P, C, add], name = :inner)
165+
166+
@named r = Constant(k = 1)
167+
@named F = FirstOrder(k = 1, T = 3)
168+
169+
eqs = [connect(r.output, F.input)
170+
connect(sys_inner.P.output, sys_inner.add.input2)
171+
connect(sys_inner.C.output.u, :plant_input, sys_inner.P.input.u)
172+
connect(F.output, sys_inner.add.input1)]
173+
sys_outer = ODESystem(eqs, t, systems = [F, sys_inner, r], name = :outer)
174+
175+
# test first that the structural_simplify works correctly
176+
ssys = structural_simplify(sys_outer)
177+
prob = ODEProblem(ssys, Pair[], (0, 10))
178+
@test_nowarn solve(prob, Rodas5())
179+
180+
matrices, _ = get_sensitivity(sys_outer, sys_outer.plant_input)
181+
lsys = sminreal(ss(matrices...))
182+
@test lsys.A[] == -2
183+
@test lsys.B[] * lsys.C[] == -1 # either one negative
184+
@test lsys.D[] == 1
185+
186+
matrices_So, _ = get_sensitivity(sys_outer, sys_outer.inner.plant_output)
187+
lsyso = sminreal(ss(matrices_So...))
188+
@test lsys == lsyso || lsys == -1 * lsyso * (-1) # Output and input sensitivities are equal for SISO systems
189+
end
190+
155191
@testset "multilevel system with loop openings" begin
156192
@named P_inner = FirstOrder(k = 1, T = 1)
157193
@named feedback = Feedback()

0 commit comments

Comments
 (0)