Skip to content

Commit e57e56f

Browse files
feat: add support for causal connections of variables
1 parent 083a639 commit e57e56f

File tree

3 files changed

+143
-2
lines changed

3 files changed

+143
-2
lines changed

src/systems/abstractsystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq ->
18751875
function n_expanded_connection_equations(sys::AbstractSystem)
18761876
# TODO: what about inputs?
18771877
isconnector(sys) && return length(get_unknowns(sys))
1878+
sys = remove_analysis_points(sys)
1879+
n_variable_connect_eqs = 0
1880+
for eq in equations(sys)
1881+
is_causal_variable_connection(eq.rhs) || continue
1882+
n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1
1883+
end
1884+
18781885
sys, (csets, _) = generate_connection_set(sys)
18791886
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
18801887
n_outer_stream_variables = 0
@@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem)
18971904
# n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m))
18981905
#end
18991906

1900-
nextras = n_outer_stream_variables + length(ceqs)
1907+
nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs
19011908
end
19021909

19031910
function Base.show(

src/systems/analysis_points.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@ function Symbolics.connect(in::AbstractSystem, name::Symbol, out, outs...; verbo
208208
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
209209
end
210210

211+
function Symbolics.connect(in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT, outs::ConnectableSymbolicT...; verbose = true)
212+
allvars = (in, out, outs...)
213+
validate_causal_variables_connection(allvars)
214+
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
215+
end
216+
211217
"""
212218
$(TYPEDSIGNATURES)
213219
@@ -240,7 +246,7 @@ connection. This is the variable named `u` if present, and otherwise the only
240246
variable in the system. If the system does not have a variable named `u` and
241247
contains multiple variables, throw an error.
242248
"""
243-
function ap_var(sys)
249+
function ap_var(sys::AbstractSystem)
244250
if hasproperty(sys, :u)
245251
return sys.u
246252
end
@@ -249,6 +255,15 @@ function ap_var(sys)
249255
error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.")
250256
end
251257

258+
"""
259+
$(TYPEDSIGNATURES)
260+
261+
For an `AnalysisPoint` involving causal variables. Simply return the variable.
262+
"""
263+
function ap_var(var::ConnectableSymbolicT)
264+
return var
265+
end
266+
252267
"""
253268
$(TYPEDEF)
254269

src/systems/connectors.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,89 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real
6868

6969
isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing
7070

71+
"""
72+
$(TYPEDEF)
73+
74+
Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
75+
to work.
76+
"""
77+
struct SymbolicWithNameof
78+
var::Any
79+
end
80+
81+
function Base.nameof(x::SymbolicWithNameof)
82+
return Symbol(x.var)
83+
end
84+
85+
is_causal_variable_connection(c) = false
86+
is_causal_variable_connection(c::Connection) = all(x -> x isa SymbolicWithNameof, get_systems(c))
87+
88+
const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr}
89+
90+
const CAUSAL_CONNECTION_ERR = """
91+
Only causal variables can be used in a `connect` statement. The first argument must \
92+
be a single output variable and all subsequent variables must be input variables.
93+
"""
94+
95+
function VariableNotOutputError(var)
96+
ArgumentError("""
97+
$CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
98+
in the variable metadata.
99+
""")
100+
end
101+
102+
function VariableNotInputError(var)
103+
ArgumentError("""
104+
$CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
105+
in the variable metadata.
106+
""")
107+
end
108+
109+
"""
110+
$(TYPEDSIGNATURES)
111+
112+
Perform validation for a connect statement involving causal variables.
113+
"""
114+
function validate_causal_variables_connection(allvars)
115+
var1 = allvars[1]
116+
var2 = allvars[2]
117+
vars = Base.tail(Base.tail(allvars))
118+
for var in allvars
119+
vtype = getvariabletype(var)
120+
vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
121+
end
122+
if length(unique(allvars)) !== length(allvars)
123+
throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
124+
end
125+
allsizes = map(size, allvars)
126+
if !allequal(allsizes)
127+
throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively."))
128+
end
129+
isoutput(var1) || throw(VariableNotOutputError(var1))
130+
isinput(var2) || throw(VariableNotInputError(var2))
131+
for var in vars
132+
isinput(var) || throw(VariableNotInputError(var))
133+
end
134+
end
135+
136+
"""
137+
$(TYPEDSIGNATURES)
138+
139+
Connect multiple causal variables. The first variable must be an output, and all subsequent
140+
variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:
141+
142+
```julia
143+
var1 ~ var2
144+
var1 ~ var3
145+
# ...
146+
```
147+
"""
148+
function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...)
149+
allvars = (var1, var2, vars...)
150+
validate_causal_variables_connection(allvars)
151+
return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars)))
152+
end
153+
71154
function flowvar(sys::AbstractSystem)
72155
sts = get_unknowns(sys)
73156
for s in sts
@@ -329,6 +412,10 @@ function generate_connection_set!(connectionsets, domain_csets,
329412
for eq in eqs′
330413
lhs = eq.lhs
331414
rhs = eq.rhs
415+
416+
# causal variable connections will be expanded before we get here,
417+
# but this guard is useful for `n_expanded_connection_equations`.
418+
is_causal_variable_connection(rhs) && continue
332419
if find !== nothing && find(rhs, _getname(namespace))
333420
neweq, extra_unknown = replace(rhs, _getname(namespace))
334421
if extra_unknown isa AbstractArray
@@ -479,9 +566,41 @@ function domain_defaults(sys, domain_csets)
479566
def
480567
end
481568

569+
"""
570+
$(TYPEDSIGNATURES)
571+
572+
Recursively descend through the hierarchy of `sys` and expand all connection equations
573+
of causal variables. Return the modified system.
574+
"""
575+
function expand_variable_connections(sys::AbstractSystem)
576+
eqs = copy(get_eqs(sys))
577+
valid_idxs = trues(length(eqs))
578+
additional_eqs = Equation[]
579+
580+
for (i, eq) in enumerate(eqs)
581+
eq.lhs isa Connection || continue
582+
connection = eq.rhs
583+
elements = connection.systems
584+
is_causal_variable_connection(connection) || continue
585+
586+
valid_idxs[i] = false
587+
elements = map(x -> x.var, elements)
588+
outvar = first(elements)
589+
for invar in Iterators.drop(elements, 1)
590+
push!(additional_eqs, outvar ~ invar)
591+
end
592+
end
593+
eqs = [eqs[valid_idxs]; additional_eqs]
594+
subsystems = map(expand_variable_connections, get_systems(sys))
595+
@set! sys.eqs = eqs
596+
@set! sys.systems = subsystems
597+
return sys
598+
end
599+
482600
function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
483601
debug = false, tol = 1e-10, scalarize = true)
484602
sys = remove_analysis_points(sys)
603+
sys = expand_variable_connections(sys)
485604
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
486605
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
487606
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)

0 commit comments

Comments
 (0)