Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ function generate_connection_equations_and_stream_connections(
var = variable_from_vertex(sys, cvert)::BasicSymbolic
vtype = cvert.type
if vtype <: Union{InputVar, OutputVar}
length(cset) > 1 || continue
inner_output = nothing
outer_input = nothing
for cvert in cset
Expand All @@ -780,11 +781,11 @@ function generate_connection_equations_and_stream_connections(
inner_output = cvert
end
end
root, rest = Iterators.peel(cset)
root_var = variable_from_vertex(sys, root)
for cvert in rest
var = variable_from_vertex(sys, cvert)
push!(eqs, root_var ~ var)
root_vert = something(inner_output, outer_input)
root_var = variable_from_vertex(sys, root_vert)
for cvert in cset
isequal(cvert, root_vert) && continue
push!(eqs, variable_from_vertex(sys, cvert) ~ root_var)
end
elseif vtype === Stream
push!(stream_connections, cset)
Expand All @@ -807,10 +808,37 @@ function generate_connection_equations_and_stream_connections(
push!(eqs, 0 ~ rhs)
end
else # Equality
base = variable_from_vertex(sys, cset[1])
for i in 2:length(cset)
v = variable_from_vertex(sys, cset[i])
push!(eqs, base ~ v)
vars = map(Base.Fix1(variable_from_vertex, sys), cset)
outer_input = inner_output = nothing
all_io = true
# attempt to interpret the equality as a causal connectionset if
# possible
for (cvert, vert) in zip(cset, vars)
is_i = isinput(vert)
is_o = isoutput(vert)
all_io &= is_i || is_o
all_io || break
if cvert.isouter && is_i && outer_input === nothing
outer_input = cvert
elseif !cvert.isouter && is_o && inner_output === nothing
inner_output = cvert
end
end
# this doesn't necessarily mean this is a well-structured causal connection,
# but it is sufficient and we're generating equalities anyway.
if all_io && xor(outer_input !== nothing, inner_output !== nothing)
root_vert = something(inner_output, outer_input)
root_var = variable_from_vertex(sys, root_vert)
for (cvert, var) in zip(cset, vars)
isequal(cvert, root_vert) && continue
push!(eqs, var ~ root_var)
end
else
base = variable_from_vertex(sys, cset[1])
for i in 2:length(cset)
v = vars[i]
push!(eqs, base ~ v)
end
end
end
end
Expand Down
12 changes: 6 additions & 6 deletions test/causal_variables_connection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ end
connect(C.output.u, P.input.u)]
sys1 = System(eqs, t, systems = [P, C], name = :hej)
sys = expand_connections(sys1)
@test any(isequal(P.output.u ~ C.input.u), equations(sys))
@test any(isequal(C.output.u ~ P.input.u), equations(sys))
@test any(isequal(C.input.u ~ P.output.u), equations(sys))
@test any(isequal(P.input.u ~ C.output.u), equations(sys))

@named sysouter = System(Equation[], t; systems = [sys1])
sys = expand_connections(sysouter)
@test any(isequal(sys1.P.output.u ~ sys1.C.input.u), equations(sys))
@test any(isequal(sys1.C.output.u ~ sys1.P.input.u), equations(sys))
@test any(isequal(sys1.C.input.u ~ sys1.P.output.u), equations(sys))
@test any(isequal(sys1.P.input.u ~ sys1.C.output.u), equations(sys))
end

@testset "With Analysis Points" begin
Expand Down Expand Up @@ -117,7 +117,7 @@ end
@named sys = Outer()
ss = toggle_namespacing(sys, false)
eqs = equations(expand_connections(sys))
@test issetequal(eqs, [ss.u ~ ss.inner.x
@test issetequal(eqs, [ss.inner.x ~ ss.u
ss.inner.y ~ ss.inner.x
ss.inner.y ~ ss.v])
ss.v ~ ss.inner.y])
end
22 changes: 22 additions & 0 deletions test/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,25 @@ end
sys = complete(outer)
@test getmetadata(sys, Int, nothing) == "test"
end

@testset "Causal connections generate causal equations" begin
# test interpretation of `Equality` cset as causal connection
@named input = RealInput()
@named comp1 = System(Equation[], t; systems = [input])
@named output = RealOutput()
@named comp2 = System(Equation[], t; systems = [output])
@named sys = System([connect(comp2.output, comp1.input)], t; systems = [comp1, comp2])
eq = only(equations(expand_connections(sys)))
# as opposed to `output.u ~ input.u`
@test isequal(eq, comp1.input.u ~ comp2.output.u)

# test causal ordering of true causal cset
@named input = RealInput()
@named comp1 = System(Equation[], t; systems = [input])
@named output = RealOutput()
@named comp2 = System(Equation[], t; systems = [output])
@named sys = System([connect(comp2.output.u, comp1.input.u)], t; systems = [comp1, comp2])
eq = only(equations(expand_connections(sys)))
# as opposed to `output.u ~ input.u`
@test isequal(eq, comp1.input.u ~ comp2.output.u)
end
Loading