Skip to content

Commit 97a0cb8

Browse files
authored
Merge pull request #2100 from SciML/fb/frame_connector
RFC: handle multibody frames in `connection2set!`
2 parents 4bfdf0b + 9c3278f commit 97a0cb8

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

src/systems/connectors.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,42 @@ end
181181
error("Different types of connectors are in one connection statement: <$(map(nameof, ss))>")
182182
end
183183

184+
"Return true if the system is a 3D multibody frame, otherwise return false."
185+
function isframe(sys)
186+
sys.metadata isa Dict || return false
187+
get(sys.metadata, :frame, false)
188+
end
189+
190+
"Return orienation object of a multibody frame."
191+
function ori(sys)
192+
if sys.metadata isa Dict && (O = get(sys.metadata, :orientation, nothing)) !== nothing
193+
return O
194+
else
195+
error("System $(sys.name) does not have an orientation object.")
196+
end
197+
end
198+
184199
function connection2set!(connectionsets, namespace, ss, isouter)
185200
nn = map(nameof, ss)
186-
sts1 = Set(states(first(ss)))
201+
s1 = first(ss)
202+
sts1v = states(s1)
203+
if isframe(s1) # Multibody
204+
O = ori(s1)
205+
orientation_vars = Symbolics.unwrap.(collect(vec(O.R)))
206+
sts1v = [sts1v; orientation_vars]
207+
end
208+
sts1 = Set(sts1v)
187209
T = ConnectionElement
188-
csets = [T[] for _ in 1:length(sts1)]
210+
num_statevars = length(sts1)
211+
csets = [T[] for _ in 1:num_statevars] # Add 9 orientation variables if connection is between multibody frames
189212
for (i, s) in enumerate(ss)
190213
sts = states(s)
191-
i != 1 && ((length(sts1) == length(sts) && all(Base.Fix2(in, sts1), sts)) ||
214+
if isframe(s) # Multibody
215+
O = ori(s)
216+
orientation_vars = Symbolics.unwrap.(vec(O.R))
217+
sts = [sts; orientation_vars]
218+
end
219+
i != 1 && ((num_statevars == length(sts) && all(Base.Fix2(in, sts1), sts)) ||
192220
connection_error(ss))
193221
io = isouter(s)
194222
for (j, v) in enumerate(sts)

0 commit comments

Comments
 (0)