Skip to content

Commit e591cd0

Browse files
feat: add support for causal connections of variables
1 parent 083a639 commit e591cd0

File tree

3 files changed

+149
-2
lines changed

3 files changed

+149
-2
lines changed

src/systems/abstractsystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq ->
18751875
function n_expanded_connection_equations(sys::AbstractSystem)
18761876
# TODO: what about inputs?
18771877
isconnector(sys) && return length(get_unknowns(sys))
1878+
sys = remove_analysis_points(sys)
1879+
n_variable_connect_eqs = 0
1880+
for eq in equations(sys)
1881+
is_causal_variable_connection(eq.rhs) || continue
1882+
n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1
1883+
end
1884+
18781885
sys, (csets, _) = generate_connection_set(sys)
18791886
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
18801887
n_outer_stream_variables = 0
@@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem)
18971904
# n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m))
18981905
#end
18991906

1900-
nextras = n_outer_stream_variables + length(ceqs)
1907+
nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs
19011908
end
19021909

19031910
function Base.show(

src/systems/analysis_points.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ function Symbolics.connect(in::AbstractSystem, name::Symbol, out, outs...; verbo
208208
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
209209
end
210210

211+
function Symbolics.connect(
212+
in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT,
213+
outs::ConnectableSymbolicT...; verbose = true)
214+
allvars = (in, out, outs...)
215+
validate_causal_variables_connection(allvars)
216+
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
217+
end
218+
211219
"""
212220
$(TYPEDSIGNATURES)
213221
@@ -240,7 +248,7 @@ connection. This is the variable named `u` if present, and otherwise the only
240248
variable in the system. If the system does not have a variable named `u` and
241249
contains multiple variables, throw an error.
242250
"""
243-
function ap_var(sys)
251+
function ap_var(sys::AbstractSystem)
244252
if hasproperty(sys, :u)
245253
return sys.u
246254
end
@@ -249,6 +257,15 @@ function ap_var(sys)
249257
error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.")
250258
end
251259

260+
"""
261+
$(TYPEDSIGNATURES)
262+
263+
For an `AnalysisPoint` involving causal variables. Simply return the variable.
264+
"""
265+
function ap_var(var::ConnectableSymbolicT)
266+
return var
267+
end
268+
252269
"""
253270
$(TYPEDEF)
254271

src/systems/connectors.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,93 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real
6868

6969
isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing
7070

71+
"""
72+
$(TYPEDEF)
73+
74+
Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
75+
to work.
76+
"""
77+
struct SymbolicWithNameof
78+
var::Any
79+
end
80+
81+
function Base.nameof(x::SymbolicWithNameof)
82+
return Symbol(x.var)
83+
end
84+
85+
is_causal_variable_connection(c) = false
86+
function is_causal_variable_connection(c::Connection)
87+
all(x -> x isa SymbolicWithNameof, get_systems(c))
88+
end
89+
90+
const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr}
91+
92+
const CAUSAL_CONNECTION_ERR = """
93+
Only causal variables can be used in a `connect` statement. The first argument must \
94+
be a single output variable and all subsequent variables must be input variables.
95+
"""
96+
97+
function VariableNotOutputError(var)
98+
ArgumentError("""
99+
$CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
100+
in the variable metadata.
101+
""")
102+
end
103+
104+
function VariableNotInputError(var)
105+
ArgumentError("""
106+
$CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
107+
in the variable metadata.
108+
""")
109+
end
110+
111+
"""
112+
$(TYPEDSIGNATURES)
113+
114+
Perform validation for a connect statement involving causal variables.
115+
"""
116+
function validate_causal_variables_connection(allvars)
117+
var1 = allvars[1]
118+
var2 = allvars[2]
119+
vars = Base.tail(Base.tail(allvars))
120+
for var in allvars
121+
vtype = getvariabletype(var)
122+
vtype === VARIABLE ||
123+
throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
124+
end
125+
if length(unique(allvars)) !== length(allvars)
126+
throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
127+
end
128+
allsizes = map(size, allvars)
129+
if !allequal(allsizes)
130+
throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively."))
131+
end
132+
isoutput(var1) || throw(VariableNotOutputError(var1))
133+
isinput(var2) || throw(VariableNotInputError(var2))
134+
for var in vars
135+
isinput(var) || throw(VariableNotInputError(var))
136+
end
137+
end
138+
139+
"""
140+
$(TYPEDSIGNATURES)
141+
142+
Connect multiple causal variables. The first variable must be an output, and all subsequent
143+
variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:
144+
145+
```julia
146+
var1 ~ var2
147+
var1 ~ var3
148+
# ...
149+
```
150+
"""
151+
function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT,
152+
vars::ConnectableSymbolicT...)
153+
allvars = (var1, var2, vars...)
154+
validate_causal_variables_connection(allvars)
155+
return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars)))
156+
end
157+
71158
function flowvar(sys::AbstractSystem)
72159
sts = get_unknowns(sys)
73160
for s in sts
@@ -329,6 +416,10 @@ function generate_connection_set!(connectionsets, domain_csets,
329416
for eq in eqs′
330417
lhs = eq.lhs
331418
rhs = eq.rhs
419+
420+
# causal variable connections will be expanded before we get here,
421+
# but this guard is useful for `n_expanded_connection_equations`.
422+
is_causal_variable_connection(rhs) && continue
332423
if find !== nothing && find(rhs, _getname(namespace))
333424
neweq, extra_unknown = replace(rhs, _getname(namespace))
334425
if extra_unknown isa AbstractArray
@@ -479,9 +570,41 @@ function domain_defaults(sys, domain_csets)
479570
def
480571
end
481572

573+
"""
574+
$(TYPEDSIGNATURES)
575+
576+
Recursively descend through the hierarchy of `sys` and expand all connection equations
577+
of causal variables. Return the modified system.
578+
"""
579+
function expand_variable_connections(sys::AbstractSystem)
580+
eqs = copy(get_eqs(sys))
581+
valid_idxs = trues(length(eqs))
582+
additional_eqs = Equation[]
583+
584+
for (i, eq) in enumerate(eqs)
585+
eq.lhs isa Connection || continue
586+
connection = eq.rhs
587+
elements = connection.systems
588+
is_causal_variable_connection(connection) || continue
589+
590+
valid_idxs[i] = false
591+
elements = map(x -> x.var, elements)
592+
outvar = first(elements)
593+
for invar in Iterators.drop(elements, 1)
594+
push!(additional_eqs, outvar ~ invar)
595+
end
596+
end
597+
eqs = [eqs[valid_idxs]; additional_eqs]
598+
subsystems = map(expand_variable_connections, get_systems(sys))
599+
@set! sys.eqs = eqs
600+
@set! sys.systems = subsystems
601+
return sys
602+
end
603+
482604
function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
483605
debug = false, tol = 1e-10, scalarize = true)
484606
sys = remove_analysis_points(sys)
607+
sys = expand_variable_connections(sys)
485608
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
486609
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
487610
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)

0 commit comments

Comments
 (0)