@@ -68,6 +68,89 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real
6868
6969isconnector (s:: AbstractSystem ) = has_connector_type (s) && get_connector_type (s) != = nothing
7070
71+ """
72+ $(TYPEDEF)
73+
74+ Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
75+ to work.
76+ """
77+ struct SymbolicWithNameof
78+ var:: Any
79+ end
80+
81+ function Base. nameof (x:: SymbolicWithNameof )
82+ return Symbol (x. var)
83+ end
84+
85+ is_causal_variable_connection (c) = false
86+ is_causal_variable_connection (c:: Connection ) = all (x -> x isa SymbolicWithNameof, get_systems (c))
87+
88+ const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics. Arr}
89+
90+ const CAUSAL_CONNECTION_ERR = """
91+ Only causal variables can be used in a `connect` statement. The first argument must \
92+ be a single output variable and all subsequent variables must be input variables.
93+ """
94+
95+ function VariableNotOutputError (var)
96+ ArgumentError ("""
97+ $CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
98+ in the variable metadata.
99+ """ )
100+ end
101+
102+ function VariableNotInputError (var)
103+ ArgumentError ("""
104+ $CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
105+ in the variable metadata.
106+ """ )
107+ end
108+
109+ """
110+ $(TYPEDSIGNATURES)
111+
112+ Perform validation for a connect statement involving causal variables.
113+ """
114+ function validate_causal_variables_connection (allvars)
115+ var1 = allvars[1 ]
116+ var2 = allvars[2 ]
117+ vars = Base. tail (Base. tail (allvars))
118+ for var in allvars
119+ vtype = getvariabletype (var)
120+ vtype === VARIABLE || throw (ArgumentError (" Expected $var to be of kind `$VARIABLE `. Got `$vtype `." ))
121+ end
122+ if length (unique (allvars)) != = length (allvars)
123+ throw (ArgumentError (" Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries." ))
124+ end
125+ allsizes = map (size, allvars)
126+ if ! allequal (allsizes)
127+ throw (ArgumentError (" Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively." ))
128+ end
129+ isoutput (var1) || throw (VariableNotOutputError (var1))
130+ isinput (var2) || throw (VariableNotInputError (var2))
131+ for var in vars
132+ isinput (var) || throw (VariableNotInputError (var))
133+ end
134+ end
135+
136+ """
137+ $(TYPEDSIGNATURES)
138+
139+ Connect multiple causal variables. The first variable must be an output, and all subsequent
140+ variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:
141+
142+ ```julia
143+ var1 ~ var2
144+ var1 ~ var3
145+ # ...
146+ ```
147+ """
148+ function Symbolics. connect (var1:: ConnectableSymbolicT , var2:: ConnectableSymbolicT , vars:: ConnectableSymbolicT... )
149+ allvars = (var1, var2, vars... )
150+ validate_causal_variables_connection (allvars)
151+ return Equation (Connection (), Connection (map (SymbolicWithNameof, allvars)))
152+ end
153+
71154function flowvar (sys:: AbstractSystem )
72155 sts = get_unknowns (sys)
73156 for s in sts
@@ -329,6 +412,10 @@ function generate_connection_set!(connectionsets, domain_csets,
329412 for eq in eqs′
330413 lhs = eq. lhs
331414 rhs = eq. rhs
415+
416+ # causal variable connections will be expanded before we get here,
417+ # but this guard is useful for `n_expanded_connection_equations`.
418+ is_causal_variable_connection (rhs) && continue
332419 if find != = nothing && find (rhs, _getname (namespace))
333420 neweq, extra_unknown = replace (rhs, _getname (namespace))
334421 if extra_unknown isa AbstractArray
@@ -479,9 +566,41 @@ function domain_defaults(sys, domain_csets)
479566 def
480567end
481568
569+ """
570+ $(TYPEDSIGNATURES)
571+
572+ Recursively descend through the hierarchy of `sys` and expand all connection equations
573+ of causal variables. Return the modified system.
574+ """
575+ function expand_variable_connections (sys:: AbstractSystem )
576+ eqs = copy (get_eqs (sys))
577+ valid_idxs = trues (length (eqs))
578+ additional_eqs = Equation[]
579+
580+ for (i, eq) in enumerate (eqs)
581+ eq. lhs isa Connection || continue
582+ connection = eq. rhs
583+ elements = connection. systems
584+ is_causal_variable_connection (connection) || continue
585+
586+ valid_idxs[i] = false
587+ elements = map (x -> x. var, elements)
588+ outvar = first (elements)
589+ for invar in Iterators. drop (elements, 1 )
590+ push! (additional_eqs, outvar ~ invar)
591+ end
592+ end
593+ eqs = [eqs[valid_idxs]; additional_eqs]
594+ subsystems = map (expand_variable_connections, get_systems (sys))
595+ @set! sys. eqs = eqs
596+ @set! sys. systems = subsystems
597+ return sys
598+ end
599+
482600function expand_connections (sys:: AbstractSystem , find = nothing , replace = nothing ;
483601 debug = false , tol = 1e-10 , scalarize = true )
484602 sys = remove_analysis_points (sys)
603+ sys = expand_variable_connections (sys)
485604 sys, (csets, domain_csets) = generate_connection_set (sys, find, replace; scalarize)
486605 ceqs, instream_csets = generate_connection_equations_and_stream_connections (csets)
487606 _sys = expand_instream (instream_csets, sys; debug = debug, tol = tol)
0 commit comments