diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index b8268e884e..73f47c9c64 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -322,14 +322,30 @@ function change_independent_variable( return ex::T end + function transform(rhs::Connection, systems_map) + new_syss = map(rhs.systems) do sys + # in the case of connect(a, b.c.d.e), our systems_map will hold a and b. + # Therefore, in the case of b.c.d.e, we will need to split the key into + # name = b and subnames = [c, d, e] to recreate the connection of the new systems + sname = string(getname(sys)) + parts = split(sname, NAMESPACE_SEPARATOR) + name, subnames = parts[1], parts[2:end] + if !(Symbol(name) in keys(systems_map)) + error("The system $name was not found in the systems map.") + end + new_sys = systems_map[Symbol(name)] + for sub in subnames + new_sys = getproperty(new_sys, Symbol(sub)) + end + return new_sys + end + return connect(new_syss...) + end + # overload to specifically handle equations, which can be an equation or a connection function transform(eq::Equation, systems_map) - if eq.rhs isa Connection - eq = connect((systems_map[nameof(s)] for s in eq.rhs.systems)...) - else - eq = transform(eq) - end - return eq::Equation + new_eq = eq.rhs isa Connection ? transform(eq.rhs, systems_map) : transform(eq) + return new_eq::Equation end # Use the utility function to transform everything in the system! diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index bafb5cf9e2..f813cb876a 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -264,10 +264,22 @@ end end @testset "Change of variables, connections" begin + @mtkmodel NestedConnect begin + @components begin + out = RealOutput() + end + end + @mtkmodel DoubleNestedConnect begin + @components begin + nested = NestedConnect() + end + end @mtkmodel ConnectSys begin @components begin in = RealInput() out = RealOutput() + nested = NestedConnect() + double_nested = DoubleNestedConnect() end @variables begin x(t) @@ -275,8 +287,10 @@ end end @equations begin connect(in, out) + connect(in, nested.out) + connect(in, double_nested.nested.out) in.u ~ x - D(x) ~ -out.u + D(x) ~ -double_nested.nested.out.u D(y) ~ 1 end end