Skip to content

Commit ef6c4bf

Browse files
authored
Merge pull request #1368 from SciML/myb/arr
Add array variable support for connect
2 parents 047dcfd + a04e02a commit ef6c4bf

File tree

3 files changed

+64
-17
lines changed

3 files changed

+64
-17
lines changed

src/systems/connectors.jl

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
get_connection_type(s) = getmetadata(unwrap(s), VariableConnectType, Equality)
2+
13
function with_connector_type(expr)
24
@assert expr isa Expr && (expr.head == :function || (expr.head == :(=) &&
35
expr.args[1] isa Expr &&
@@ -34,9 +36,12 @@ function connector_type(sys::AbstractSystem)
3436
n_stream = 0
3537
n_flow = 0
3638
for s in sts
37-
vtype = getmetadata(s, ModelingToolkit.VariableConnectType, nothing)
38-
vtype === Stream && (n_stream += 1)
39-
vtype === Flow && (n_flow += 1)
39+
vtype = get_connection_type(s)
40+
if vtype === Stream
41+
isarray(s) && error("Array stream variables are not supported. Got $s.")
42+
n_stream += 1
43+
end
44+
vtype === Flow && (n_flow += 1)
4045
end
4146
(n_stream > 0 && n_flow > 1) && error("There are multiple flow variables in $(nameof(sys))!")
4247
n_stream > 0 ? StreamConnector() : RegularConnector()
@@ -84,6 +89,7 @@ function connect(c::Connection; check=true)
8489
flow_eqs = Equation[]
8590
other_eqs = Equation[]
8691

92+
ncnts = length(inners) + length(outers)
8793
cnts = Iterators.flatten((inners, outers))
8894
fs, ss = Iterators.peel(cnts)
8995
splitting_idx = length(inners) # anything after the splitting_idx is outer.
@@ -94,26 +100,47 @@ function connect(c::Connection; check=true)
94100
Set(current_sts) == first_sts_set || error("$(nameof(sys)) ($current_sts) doesn't match the connection type of $(nameof(fs)) ($first_sts).")
95101
end
96102

103+
seen = Set()
97104
ceqs = Equation[]
98105
for s in first_sts
99106
name = getname(s)
100-
vtype = getmetadata(s, VariableConnectType, Equality)
107+
fix_val = getproperty(fs, name) # representative
108+
fix_val in seen && continue
109+
push!(seen, fix_val)
110+
111+
vtype = get_connection_type(fix_val)
101112
vtype === Stream && continue
102-
isflow = vtype === Flow
103-
rhs = 0 # only used for flow variables
104-
fix_val = getproperty(fs, name) # used for equality connections
105-
for (i, c) in enumerate(cnts)
106-
isinner = i <= splitting_idx
107-
# https://specification.modelica.org/v3.4/Ch15.html
108-
var = getproperty(c, name)
109-
if isflow
113+
114+
isarr = isarray(fix_val)
115+
116+
if vtype === Flow
117+
rhs = isarr ? zeros(Int, ncnts) : 0
118+
for (i, c) in enumerate(cnts)
119+
isinner = i <= splitting_idx
120+
# https://specification.modelica.org/v3.4/Ch15.html
121+
var = scalarize(getproperty(c, name))
110122
rhs += isinner ? var : -var
123+
end
124+
if isarr
125+
for r in rhs
126+
push!(ceqs, 0 ~ r)
127+
end
111128
else
112-
i == 1 && continue # skip the first iteration
113-
push!(ceqs, fix_val ~ getproperty(c, name))
129+
push!(ceqs, 0 ~ rhs)
130+
end
131+
else # Equality
132+
for c in ss
133+
var = getproperty(c, name)
134+
if isarr
135+
vs = scalarize(var)
136+
for (i, v) in enumerate(vs)
137+
push!(ceqs, fix_val[i] ~ v)
138+
end
139+
else
140+
push!(ceqs, fix_val ~ var)
141+
end
114142
end
115143
end
116-
isflow && push!(ceqs, 0 ~ rhs)
117144
end
118145

119146
return ceqs
@@ -144,7 +171,7 @@ end
144171
function flowvar(sys::AbstractSystem)
145172
sts = get_states(sys)
146173
for s in sts
147-
vtype = getmetadata(s, ModelingToolkit.VariableConnectType, nothing)
174+
vtype = get_connection_type(s)
148175
vtype === Flow && return s
149176
end
150177
error("There in no flow variable in $(nameof(sys))")
@@ -365,7 +392,7 @@ function expand_instream(instream_eqs, instream_exprs, connects; debug=false, to
365392
connector_representative = first(outer_sc)
366393
fv = flowvar(connector_representative)
367394
for sv in get_states(connector_representative)
368-
vtype = getmetadata(sv, ModelingToolkit.VariableConnectType, nothing)
395+
vtype = get_connection_type(sv)
369396
vtype === Stream || continue
370397
if n_inners == 1 && n_outers == 1
371398
innerstream = states(only(inner_sc), sv)

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,3 +371,5 @@ function find_duplicates(xs, ::Val{Ret}=Val(false)) where Ret
371371
end
372372
return Ret ? (appeared, duplicates) : duplicates
373373
end
374+
375+
isarray(x) = x isa AbstractArray || x isa Symbolics.Arr

test/stream_connectors.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,21 @@ eqns = [
215215
@named sys = ODESystem(eqns, t)
216216
@named n2m2Test = compose(sys, n2m2, source, sink)
217217
@test_nowarn structural_simplify(n2m2Test)
218+
219+
# array var
220+
@connector function VecPin(;name)
221+
sts = @variables v[1:2](t)=[1.0,0.0] i[1:2](t)=1.0 [connect = Flow]
222+
ODESystem(Equation[], t, [sts...;], []; name=name)
223+
end
224+
225+
@named vp1 = VecPin()
226+
@named vp2 = VecPin()
227+
228+
@named simple = ODESystem([connect(vp1, vp2)], t)
229+
sys = expand_connections(compose(simple, [vp1, vp2]))
230+
@test equations(sys) == [
231+
vp1.v[1] ~ vp2.v[1]
232+
vp1.v[2] ~ vp2.v[2]
233+
0 ~ -vp1.i[1] - vp2.i[1]
234+
0 ~ -vp1.i[2] - vp2.i[2]
235+
]

0 commit comments

Comments
 (0)