@@ -780,11 +780,11 @@ function generate_connection_equations_and_stream_connections(
780
780
inner_output = cvert
781
781
end
782
782
end
783
- root, rest = Iterators . peel (cset )
784
- root_var = variable_from_vertex (sys, root )
785
- for cvert in rest
786
- var = variable_from_vertex (sys, cvert)
787
- push! (eqs, root_var ~ var )
783
+ root_vert = something (inner_output, outer_input )
784
+ root_var = variable_from_vertex (sys, root_vert )
785
+ for cvert in cset
786
+ isequal (cvert, root_vert) && continue
787
+ push! (eqs, variable_from_vertex (sys, cvert) ~ root_var )
788
788
end
789
789
elseif vtype === Stream
790
790
push! (stream_connections, cset)
@@ -807,10 +807,37 @@ function generate_connection_equations_and_stream_connections(
807
807
push! (eqs, 0 ~ rhs)
808
808
end
809
809
else # Equality
810
- base = variable_from_vertex (sys, cset[1 ])
811
- for i in 2 : length (cset)
812
- v = variable_from_vertex (sys, cset[i])
813
- push! (eqs, base ~ v)
810
+ vars = map (Base. Fix1 (variable_from_vertex, sys), cset)
811
+ outer_input = inner_output = nothing
812
+ all_io = true
813
+ # attempt to interpret the equality as a causal connectionset if
814
+ # possible
815
+ for (cvert, vert) in zip (cset, vars)
816
+ is_i = isinput (vert)
817
+ is_o = isoutput (vert)
818
+ all_io &= is_i || is_o
819
+ all_io || break
820
+ if cvert. isouter && is_i && outer_input === nothing
821
+ outer_input = cvert
822
+ elseif ! cvert. isouter && is_o && inner_output === nothing
823
+ inner_output = cvert
824
+ end
825
+ end
826
+ # this doesn't necessarily mean this is a well-structured causal connection,
827
+ # but it is sufficient and we're generating equalities anyway.
828
+ if all_io && xor (outer_input != = nothing , inner_output != = nothing )
829
+ root_vert = something (inner_output, outer_input)
830
+ root_var = variable_from_vertex (sys, root_vert)
831
+ for (cvert, var) in zip (cset, vars)
832
+ isequal (cvert, root_vert) && continue
833
+ push! (eqs, var ~ root_var)
834
+ end
835
+ else
836
+ base = variable_from_vertex (sys, cset[1 ])
837
+ for i in 2 : length (cset)
838
+ v = vars[i]
839
+ push! (eqs, base ~ v)
840
+ end
814
841
end
815
842
end
816
843
end
0 commit comments