diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 4e0c2adb14..c0ddf5baee 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -759,6 +759,7 @@ function generate_connection_equations_and_stream_connections( var = variable_from_vertex(sys, cvert)::BasicSymbolic vtype = cvert.type if vtype <: Union{InputVar, OutputVar} + length(cset) > 1 || continue inner_output = nothing outer_input = nothing for cvert in cset @@ -780,11 +781,11 @@ function generate_connection_equations_and_stream_connections( inner_output = cvert end end - root, rest = Iterators.peel(cset) - root_var = variable_from_vertex(sys, root) - for cvert in rest - var = variable_from_vertex(sys, cvert) - push!(eqs, root_var ~ var) + root_vert = something(inner_output, outer_input) + root_var = variable_from_vertex(sys, root_vert) + for cvert in cset + isequal(cvert, root_vert) && continue + push!(eqs, variable_from_vertex(sys, cvert) ~ root_var) end elseif vtype === Stream push!(stream_connections, cset) @@ -807,10 +808,37 @@ function generate_connection_equations_and_stream_connections( push!(eqs, 0 ~ rhs) end else # Equality - base = variable_from_vertex(sys, cset[1]) - for i in 2:length(cset) - v = variable_from_vertex(sys, cset[i]) - push!(eqs, base ~ v) + vars = map(Base.Fix1(variable_from_vertex, sys), cset) + outer_input = inner_output = nothing + all_io = true + # attempt to interpret the equality as a causal connectionset if + # possible + for (cvert, vert) in zip(cset, vars) + is_i = isinput(vert) + is_o = isoutput(vert) + all_io &= is_i || is_o + all_io || break + if cvert.isouter && is_i && outer_input === nothing + outer_input = cvert + elseif !cvert.isouter && is_o && inner_output === nothing + inner_output = cvert + end + end + # this doesn't necessarily mean this is a well-structured causal connection, + # but it is sufficient and we're generating equalities anyway. + if all_io && xor(outer_input !== nothing, inner_output !== nothing) + root_vert = something(inner_output, outer_input) + root_var = variable_from_vertex(sys, root_vert) + for (cvert, var) in zip(cset, vars) + isequal(cvert, root_vert) && continue + push!(eqs, var ~ root_var) + end + else + base = variable_from_vertex(sys, cset[1]) + for i in 2:length(cset) + v = vars[i] + push!(eqs, base ~ v) + end end end end diff --git a/test/causal_variables_connection.jl b/test/causal_variables_connection.jl index eb922879e1..3124ac2f3b 100644 --- a/test/causal_variables_connection.jl +++ b/test/causal_variables_connection.jl @@ -36,13 +36,13 @@ end connect(C.output.u, P.input.u)] sys1 = System(eqs, t, systems = [P, C], name = :hej) sys = expand_connections(sys1) - @test any(isequal(P.output.u ~ C.input.u), equations(sys)) - @test any(isequal(C.output.u ~ P.input.u), equations(sys)) + @test any(isequal(C.input.u ~ P.output.u), equations(sys)) + @test any(isequal(P.input.u ~ C.output.u), equations(sys)) @named sysouter = System(Equation[], t; systems = [sys1]) sys = expand_connections(sysouter) - @test any(isequal(sys1.P.output.u ~ sys1.C.input.u), equations(sys)) - @test any(isequal(sys1.C.output.u ~ sys1.P.input.u), equations(sys)) + @test any(isequal(sys1.C.input.u ~ sys1.P.output.u), equations(sys)) + @test any(isequal(sys1.P.input.u ~ sys1.C.output.u), equations(sys)) end @testset "With Analysis Points" begin @@ -117,7 +117,7 @@ end @named sys = Outer() ss = toggle_namespacing(sys, false) eqs = equations(expand_connections(sys)) - @test issetequal(eqs, [ss.u ~ ss.inner.x + @test issetequal(eqs, [ss.inner.x ~ ss.u ss.inner.y ~ ss.inner.x - ss.inner.y ~ ss.v]) + ss.v ~ ss.inner.y]) end diff --git a/test/components.jl b/test/components.jl index 8e5747c750..a66725ca35 100644 --- a/test/components.jl +++ b/test/components.jl @@ -335,3 +335,25 @@ end sys = complete(outer) @test getmetadata(sys, Int, nothing) == "test" end + +@testset "Causal connections generate causal equations" begin + # test interpretation of `Equality` cset as causal connection + @named input = RealInput() + @named comp1 = System(Equation[], t; systems = [input]) + @named output = RealOutput() + @named comp2 = System(Equation[], t; systems = [output]) + @named sys = System([connect(comp2.output, comp1.input)], t; systems = [comp1, comp2]) + eq = only(equations(expand_connections(sys))) + # as opposed to `output.u ~ input.u` + @test isequal(eq, comp1.input.u ~ comp2.output.u) + + # test causal ordering of true causal cset + @named input = RealInput() + @named comp1 = System(Equation[], t; systems = [input]) + @named output = RealOutput() + @named comp2 = System(Equation[], t; systems = [output]) + @named sys = System([connect(comp2.output.u, comp1.input.u)], t; systems = [comp1, comp2]) + eq = only(equations(expand_connections(sys))) + # as opposed to `output.u ~ input.u` + @test isequal(eq, comp1.input.u ~ comp2.output.u) +end