@@ -68,6 +68,93 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real
6868
6969isconnector (s:: AbstractSystem ) = has_connector_type (s) && get_connector_type (s) != = nothing
7070
71+ """
72+ $(TYPEDEF)
73+
74+ Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
75+ to work.
76+ """
77+ struct SymbolicWithNameof
78+ var:: Any
79+ end
80+
81+ function Base. nameof (x:: SymbolicWithNameof )
82+ return Symbol (x. var)
83+ end
84+
85+ is_causal_variable_connection (c) = false
86+ function is_causal_variable_connection (c:: Connection )
87+ all (x -> x isa SymbolicWithNameof, get_systems (c))
88+ end
89+
90+ const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics. Arr}
91+
92+ const CAUSAL_CONNECTION_ERR = """
93+ Only causal variables can be used in a `connect` statement. The first argument must \
94+ be a single output variable and all subsequent variables must be input variables.
95+ """
96+
97+ function VariableNotOutputError (var)
98+ ArgumentError ("""
99+ $CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
100+ in the variable metadata.
101+ """ )
102+ end
103+
104+ function VariableNotInputError (var)
105+ ArgumentError ("""
106+ $CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
107+ in the variable metadata.
108+ """ )
109+ end
110+
111+ """
112+ $(TYPEDSIGNATURES)
113+
114+ Perform validation for a connect statement involving causal variables.
115+ """
116+ function validate_causal_variables_connection (allvars)
117+ var1 = allvars[1 ]
118+ var2 = allvars[2 ]
119+ vars = Base. tail (Base. tail (allvars))
120+ for var in allvars
121+ vtype = getvariabletype (var)
122+ vtype === VARIABLE ||
123+ throw (ArgumentError (" Expected $var to be of kind `$VARIABLE `. Got `$vtype `." ))
124+ end
125+ if length (unique (allvars)) != = length (allvars)
126+ throw (ArgumentError (" Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries." ))
127+ end
128+ allsizes = map (size, allvars)
129+ if ! allequal (allsizes)
130+ throw (ArgumentError (" Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively." ))
131+ end
132+ isoutput (var1) || throw (VariableNotOutputError (var1))
133+ isinput (var2) || throw (VariableNotInputError (var2))
134+ for var in vars
135+ isinput (var) || throw (VariableNotInputError (var))
136+ end
137+ end
138+
139+ """
140+ $(TYPEDSIGNATURES)
141+
142+ Connect multiple causal variables. The first variable must be an output, and all subsequent
143+ variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:
144+
145+ ```julia
146+ var1 ~ var2
147+ var1 ~ var3
148+ # ...
149+ ```
150+ """
151+ function Symbolics. connect (var1:: ConnectableSymbolicT , var2:: ConnectableSymbolicT ,
152+ vars:: ConnectableSymbolicT... )
153+ allvars = (var1, var2, vars... )
154+ validate_causal_variables_connection (allvars)
155+ return Equation (Connection (), Connection (map (SymbolicWithNameof, allvars)))
156+ end
157+
71158function flowvar (sys:: AbstractSystem )
72159 sts = get_unknowns (sys)
73160 for s in sts
@@ -329,6 +416,10 @@ function generate_connection_set!(connectionsets, domain_csets,
329416 for eq in eqs′
330417 lhs = eq. lhs
331418 rhs = eq. rhs
419+
420+ # causal variable connections will be expanded before we get here,
421+ # but this guard is useful for `n_expanded_connection_equations`.
422+ is_causal_variable_connection (rhs) && continue
332423 if find != = nothing && find (rhs, _getname (namespace))
333424 neweq, extra_unknown = replace (rhs, _getname (namespace))
334425 if extra_unknown isa AbstractArray
@@ -479,9 +570,41 @@ function domain_defaults(sys, domain_csets)
479570 def
480571end
481572
573+ """
574+ $(TYPEDSIGNATURES)
575+
576+ Recursively descend through the hierarchy of `sys` and expand all connection equations
577+ of causal variables. Return the modified system.
578+ """
579+ function expand_variable_connections (sys:: AbstractSystem )
580+ eqs = copy (get_eqs (sys))
581+ valid_idxs = trues (length (eqs))
582+ additional_eqs = Equation[]
583+
584+ for (i, eq) in enumerate (eqs)
585+ eq. lhs isa Connection || continue
586+ connection = eq. rhs
587+ elements = connection. systems
588+ is_causal_variable_connection (connection) || continue
589+
590+ valid_idxs[i] = false
591+ elements = map (x -> x. var, elements)
592+ outvar = first (elements)
593+ for invar in Iterators. drop (elements, 1 )
594+ push! (additional_eqs, outvar ~ invar)
595+ end
596+ end
597+ eqs = [eqs[valid_idxs]; additional_eqs]
598+ subsystems = map (expand_variable_connections, get_systems (sys))
599+ @set! sys. eqs = eqs
600+ @set! sys. systems = subsystems
601+ return sys
602+ end
603+
482604function expand_connections (sys:: AbstractSystem , find = nothing , replace = nothing ;
483605 debug = false , tol = 1e-10 , scalarize = true )
484606 sys = remove_analysis_points (sys)
607+ sys = expand_variable_connections (sys)
485608 sys, (csets, domain_csets) = generate_connection_set (sys, find, replace; scalarize)
486609 ceqs, instream_csets = generate_connection_equations_and_stream_connections (csets)
487610 _sys = expand_instream (instream_csets, sys; debug = debug, tol = tol)
0 commit comments