Skip to content

Commit d046e88

Browse files
feat: ensure causal connectors generate causally ordered equations
1 parent 2426464 commit d046e88

File tree

2 files changed

+58
-9
lines changed

2 files changed

+58
-9
lines changed

src/systems/connectors.jl

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -780,11 +780,11 @@ function generate_connection_equations_and_stream_connections(
780780
inner_output = cvert
781781
end
782782
end
783-
root, rest = Iterators.peel(cset)
784-
root_var = variable_from_vertex(sys, root)
785-
for cvert in rest
786-
var = variable_from_vertex(sys, cvert)
787-
push!(eqs, root_var ~ var)
783+
root_vert = something(inner_output, outer_input)
784+
root_var = variable_from_vertex(sys, root_vert)
785+
for cvert in cset
786+
isequal(cvert, root_vert) && continue
787+
push!(eqs, variable_from_vertex(sys, cvert) ~ root_var)
788788
end
789789
elseif vtype === Stream
790790
push!(stream_connections, cset)
@@ -807,10 +807,37 @@ function generate_connection_equations_and_stream_connections(
807807
push!(eqs, 0 ~ rhs)
808808
end
809809
else # Equality
810-
base = variable_from_vertex(sys, cset[1])
811-
for i in 2:length(cset)
812-
v = variable_from_vertex(sys, cset[i])
813-
push!(eqs, base ~ v)
810+
vars = map(Base.Fix1(variable_from_vertex, sys), cset)
811+
outer_input = inner_output = nothing
812+
all_io = true
813+
# attempt to interpret the equality as a causal connectionset if
814+
# possible
815+
for (cvert, vert) in zip(cset, vars)
816+
is_i = isinput(vert)
817+
is_o = isoutput(vert)
818+
all_io &= is_i || is_o
819+
all_io || break
820+
if cvert.isouter && is_i && outer_input === nothing
821+
outer_input = cvert
822+
elseif !cvert.isouter && is_o && inner_output === nothing
823+
inner_output = cvert
824+
end
825+
end
826+
# this doesn't necessarily mean this is a well-structured causal connection,
827+
# but it is sufficient and we're generating equalities anyway.
828+
if all_io && xor(outer_input !== nothing, inner_output !== nothing)
829+
root_vert = something(inner_output, outer_input)
830+
root_var = variable_from_vertex(sys, root_vert)
831+
for (cvert, var) in zip(cset, vars)
832+
isequal(cvert, root_vert) && continue
833+
push!(eqs, var ~ root_var)
834+
end
835+
else
836+
base = variable_from_vertex(sys, cset[1])
837+
for i in 2:length(cset)
838+
v = vars[i]
839+
push!(eqs, base ~ v)
840+
end
814841
end
815842
end
816843
end

test/components.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,25 @@ end
335335
sys = complete(outer)
336336
@test getmetadata(sys, Int, nothing) == "test"
337337
end
338+
339+
@testset "Causal connections generate causal equations" begin
340+
# test interpretation of `Equality` cset as causal connection
341+
@named input = RealInput()
342+
@named comp1 = System(Equation[], t; systems = [input])
343+
@named output = RealOutput()
344+
@named comp2 = System(Equation[], t; systems = [output])
345+
@named sys = System([connect(comp2.output, comp1.input)], t; systems = [comp1, comp2])
346+
eq = only(equations(expand_connections(sys)))
347+
# as opposed to `output.u ~ input.u`
348+
@test isequal(eq, comp1.input.u ~ comp2.output.u)
349+
350+
# test causal ordering of true causal cset
351+
@named input = RealInput()
352+
@named comp1 = System(Equation[], t; systems = [input])
353+
@named output = RealOutput()
354+
@named comp2 = System(Equation[], t; systems = [output])
355+
@named sys = System([connect(comp2.output.u, comp1.input.u)], t; systems = [comp1, comp2])
356+
eq = only(equations(expand_connections(sys)))
357+
# as opposed to `output.u ~ input.u`
358+
@test isequal(eq, comp1.input.u ~ comp2.output.u)
359+
end

0 commit comments

Comments
 (0)