diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index b53d3ec098..99d38c7b33 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -291,7 +291,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc hasunit, getunit, hasconnect, getconnect, hasmisc, getmisc, state_priority export ode_order_lowering, dae_order_lowering, liouville_transform, - change_independent_variable + change_independent_variable, substitute_component export PDESystem export Differential, expand_derivatives, @derivatives export Equation, ConstrainedEquation diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index fb692b4028..299990951a 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -3181,3 +3181,188 @@ has_diff_eqs(osys21) # returns `false`. ``` """ has_diff_eqs(sys::AbstractSystem) = any(is_diff_equation, get_eqs(sys)) + +""" + $(TYPEDSIGNATURES) + +Validate the rules for replacement of subcomponents as defined in `substitute_component`. +""" +function validate_replacement_rule( + rule::Pair{T, T}; namespace = []) where {T <: AbstractSystem} + lhs, rhs = rule + + iscomplete(lhs) && throw(ArgumentError("LHS of replacement rule cannot be completed.")) + iscomplete(rhs) && throw(ArgumentError("RHS of replacement rule cannot be completed.")) + + rhs_h = namespace_hierarchy(nameof(rhs)) + if length(rhs_h) != 1 + throw(ArgumentError("RHS of replacement rule must not be namespaced.")) + end + rhs_h[1] == namespace_hierarchy(nameof(lhs))[end] || + throw(ArgumentError("LHS and RHS must have the same name.")) + + if !isequal(get_iv(lhs), get_iv(rhs)) + throw(ArgumentError("LHS and RHS of replacement rule must have the same independent variable.")) + end + + lhs_u = get_unknowns(lhs) + rhs_u = Dict(get_unknowns(rhs) .=> nothing) + for u in lhs_u + if !haskey(rhs_u, u) + if isempty(namespace) + throw(ArgumentError("RHS of replacement rule does not contain unknown $u.")) + else + throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain unknown $u.")) + end + end + ru = getkey(rhs_u, u, nothing) + name = join([namespace; nameof(lhs); (hasname(u) ? getname(u) : Symbol(u))], + NAMESPACE_SEPARATOR) + l_connect = something(getconnect(u), Equality) + r_connect = something(getconnect(ru), Equality) + if l_connect != r_connect + throw(ArgumentError("Variable $(name) should have connection metadata $(l_connect),")) + end + + l_input = isinput(u) + r_input = isinput(ru) + if l_input != r_input + throw(ArgumentError("Variable $name has differing causality. Marked as `input = $l_input` in LHS and `input = $r_input` in RHS.")) + end + l_output = isoutput(u) + r_output = isoutput(ru) + if l_output != r_output + throw(ArgumentError("Variable $name has differing causality. Marked as `output = $l_output` in LHS and `output = $r_output` in RHS.")) + end + end + + lhs_p = get_ps(lhs) + rhs_p = Set(get_ps(rhs)) + for p in lhs_p + if !(p in rhs_p) + if isempty(namespace) + throw(ArgumentError("RHS of replacement rule does not contain parameter $p")) + else + throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain parameter $p.")) + end + end + end + + lhs_s = get_systems(lhs) + rhs_s = Dict(nameof(s) => s for s in get_systems(rhs)) + + for s in lhs_s + if haskey(rhs_s, nameof(s)) + rs = rhs_s[nameof(s)] + if isconnector(s) + name = join([namespace; nameof(lhs); nameof(s)], NAMESPACE_SEPARATOR) + if !isconnector(rs) + throw(ArgumentError("Subsystem $name of RHS is not a connector.")) + end + if (lct = get_connector_type(s)) !== (rct = get_connector_type(rs)) + throw(ArgumentError("Subsystem $name of RHS has connection type $rct but LHS has $lct.")) + end + end + validate_replacement_rule(s => rs; namespace = [namespace; nameof(rhs)]) + continue + end + name1 = join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR) + throw(ArgumentError("$name1 of replacement rule does not contain subsystem $(nameof(s)).")) + end +end + +""" + $(TYPEDSIGNATURES) + +Chain `getproperty` calls on `root` in the order given in `hierarchy`. + +# Keyword Arguments + +- `skip_namespace_first`: Whether to avoid namespacing in the first `getproperty` call. +""" +function recursive_getproperty( + root::AbstractSystem, hierarchy::Vector{Symbol}; skip_namespace_first = true) + cur = root + for (i, name) in enumerate(hierarchy) + cur = getproperty(cur, name; namespace = i > 1 || !skip_namespace_first) + end + return unwrap(cur) +end + +""" + $(TYPEDSIGNATURES) + +Recursively descend through `sys`, finding all connection equations and re-creating them +using the names of the involved variables/systems and finding the required variables/ +systems in the hierarchy. +""" +function recreate_connections(sys::AbstractSystem) + eqs = map(get_eqs(sys)) do eq + eq.lhs isa Union{Connection, AnalysisPoint} || return eq + if eq.lhs isa Connection + oldargs = get_systems(eq.rhs) + else + ap::AnalysisPoint = eq.rhs + oldargs = [ap.input; ap.outputs] + end + newargs = map(get_systems(eq.rhs)) do arg + rewrap_nameof = arg isa SymbolicWithNameof + if rewrap_nameof + arg = arg.var + end + name = arg isa AbstractSystem ? nameof(arg) : getname(arg) + hierarchy = namespace_hierarchy(name) + newarg = recursive_getproperty(sys, hierarchy) + if rewrap_nameof + newarg = SymbolicWithNameof(newarg) + end + return newarg + end + if eq.lhs isa Connection + return eq.lhs ~ Connection(newargs) + else + return eq.lhs ~ AnalysisPoint(newargs[1], eq.rhs.name, newargs[2:end]) + end + end + @set! sys.eqs = eqs + @set! sys.systems = map(recreate_connections, get_systems(sys)) + return sys +end + +""" + $(TYPEDSIGNATURES) + +Given a hierarchical system `sys` and a rule `lhs => rhs`, replace the subsystem `lhs` in +`sys` by `rhs`. The `lhs` must be the namespaced version of a subsystem of `sys` (e.g. +obtained via `sys.inner.component`). The `rhs` must be valid as per the following +conditions: + +1. `rhs` must not be namespaced. +2. The name of `rhs` must be the same as the unnamespaced name of `lhs`. +3. Neither one of `lhs` or `rhs` can be marked as complete. +4. Both `lhs` and `rhs` must share the same independent variable. +5. `rhs` must contain at least all of the unknowns and parameters present in + `lhs`. +6. Corresponding unknowns in `rhs` must share the same connection and causality + (input/output) metadata as their counterparts in `lhs`. +7. For each subsystem of `lhs`, there must be an identically named subsystem of `rhs`. + These two corresponding subsystems must satisfy conditions 3, 4, 5, 6, 7. If the + subsystem of `lhs` is a connector, the corresponding subsystem of `rhs` must also + be a connector of the same type. + +`sys` also cannot be marked as complete. +""" +function substitute_component(sys::T, rule::Pair{T, T}) where {T <: AbstractSystem} + iscomplete(sys) && + throw(ArgumentError("Cannot replace subsystems of completed systems")) + + validate_replacement_rule(rule) + + lhs, rhs = rule + hierarchy = namespace_hierarchy(nameof(lhs)) + + newsys, _ = modify_nested_subsystem(sys, hierarchy) do inner + return rhs, () + end + return recreate_connections(newsys) +end diff --git a/test/runtests.jl b/test/runtests.jl index 301134219e..04d99cc8b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,6 +98,7 @@ end @safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl") @safetestset "Debugging Test" include("debugging.jl") @safetestset "Namespacing test" include("namespacing.jl") + @safetestset "Subsystem replacement" include("substitute_component.jl") end end diff --git a/test/substitute_component.jl b/test/substitute_component.jl new file mode 100644 index 0000000000..9fb254136b --- /dev/null +++ b/test/substitute_component.jl @@ -0,0 +1,273 @@ +using ModelingToolkit, ModelingToolkitStandardLibrary, Test +using ModelingToolkitStandardLibrary.Blocks +using ModelingToolkitStandardLibrary.Electrical +using OrdinaryDiffEq +using ModelingToolkit: t_nounits as t, D_nounits as D, renamespace, + NAMESPACE_SEPARATOR as NS + +@mtkmodel SignalInterface begin + @components begin + output = RealOutput() + end +end + +@mtkmodel TwoComponent begin + @components begin + component1 = OnePort() + component2 = OnePort() + source = Voltage() + signal = SignalInterface() + ground = Ground() + end + @equations begin + connect(signal.output.u, source.V.u) + connect(source.p, component1.p) + connect(component1.n, component2.p) + connect(component2.n, source.n, ground.g) + end +end + +@mtkmodel RC begin + @parameters begin + R = 1.0 + C = 1.0 + V = 1.0 + end + @components begin + component1 = Resistor(R = R) + component2 = Capacitor(C = C, v = 0.0) + source = Voltage() + constant = Constant(k = V) + ground = Ground() + end + @equations begin + connect(constant.output, source.V) + connect(source.p, component1.p) + connect(component1.n, component2.p) + connect(component2.n, source.n, ground.g) + end +end + +@testset "Replacement with connections works" begin + @named templated = TwoComponent() + @named component1 = Resistor(R = 1.0) + @named component2 = Capacitor(C = 1.0, v = 0.0) + @named signal = Constant(k = 1.0) + rsys = substitute_component(templated, templated.component1 => component1) + rcsys = substitute_component(rsys, rsys.component2 => component2) + rcsys = substitute_component(rcsys, rcsys.signal => signal) + + @named reference = RC() + + sys1 = structural_simplify(rcsys) + sys2 = structural_simplify(reference) + @test isequal(unknowns(sys1), unknowns(sys2)) + @test isequal(equations(sys1), equations(sys2)) + + prob1 = ODEProblem(sys1, [], (0.0, 10.0)) + prob2 = ODEProblem(sys2, [], (0.0, 10.0)) + + sol1 = solve(prob1, Tsit5()) + sol2 = solve(prob2, Tsit5(); saveat = sol1.t) + @test sol1.u≈sol2.u atol=1e-8 +end + +@mtkmodel BadOnePort1 begin + @components begin + p = Pin() + n = Pin() + end + @variables begin + i(t) + end + @equations begin + 0 ~ p.i + n.i + i ~ p.i + end +end + +@connector BadPin1 begin + v(t) +end + +@mtkmodel BadOnePort2 begin + @components begin + p = BadPin1() + n = Pin() + end + @variables begin + v(t) + i(t) + end + @equations begin + 0 ~ p.v + n.v + v ~ p.v + end +end + +@connector BadPin2 begin + v(t) + i(t) +end + +@mtkmodel BadOnePort3 begin + @components begin + p = BadPin2() + n = Pin() + end + @variables begin + v(t) + i(t) + end + @equations begin + 0 ~ p.v + n.v + v ~ p.v + end +end + +@connector BadPin3 begin + v(t), [input = true] + i(t), [connect = Flow] +end + +@mtkmodel BadOnePort4 begin + @components begin + p = BadPin3() + n = Pin() + end + @variables begin + v(t) + i(t) + end + @equations begin + 0 ~ p.v + n.v + v ~ p.v + end +end + +@connector BadPin4 begin + v(t), [output = true] + i(t), [connect = Flow] +end + +@mtkmodel BadOnePort5 begin + @components begin + p = BadPin4() + n = Pin() + end + @variables begin + v(t) + i(t) + end + @equations begin + 0 ~ p.v + n.v + v ~ p.v + end +end + +@mtkmodel BadPin5 begin + @variables begin + v(t) + i(t), [connect = Flow] + end +end + +@mtkmodel BadOnePort6 begin + @components begin + p = BadPin5() + n = Pin() + end + @variables begin + v(t) + i(t) + end + @equations begin + 0 ~ p.v + n.v + v ~ p.v + end +end + +@connector BadPin6 begin + i(t), [connect = Flow] +end + +@mtkmodel BadOnePort7 begin + @components begin + p = BadPin6() + n = Pin() + end + @variables begin + v(t) + i(t) + end + @equations begin + 0 ~ p.i + n.i + i ~ p.i + end +end + +@mtkmodel BadOnePort8 begin + @components begin + n = Pin() + end + @variables begin + v(t) + i(t) + end +end + +@testset "Error checking" begin + @named templated = TwoComponent() + @named component1 = Resistor(R = 1.0) + @named component2 = Capacitor(C = 1.0, v = 0.0) + @test_throws ["LHS", "cannot be completed"] substitute_component( + templated, complete(templated.component1) => component1) + @test_throws ["RHS", "cannot be completed"] substitute_component( + templated, templated.component1 => complete(component1)) + @test_throws ["RHS", "not be namespaced"] substitute_component( + templated, templated.component1 => renamespace(templated, component1)) + @named resistor = Resistor(R = 1.0) + @test_throws ["RHS", "same name"] substitute_component( + templated, templated.component1 => resistor) + + @testset "Different indepvar" begin + @independent_variables tt + @named empty = ODESystem(Equation[], t) + @named outer = ODESystem(Equation[], t; systems = [empty]) + @named empty = ODESystem(Equation[], tt) + @test_throws ["independent variable"] substitute_component( + outer, outer.empty => empty) + end + + @named component1 = BadOnePort1() + @test_throws ["RHS", "unknown", "v(t)"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort2() + @test_throws ["component1$(NS)p", "i(t)"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort3() + @test_throws ["component1$(NS)p$(NS)i", "Flow"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort4() + @test_throws ["component1$(NS)p$(NS)v", "differing causality", "input"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort5() + @test_throws ["component1$(NS)p$(NS)v", "differing causality", "output"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort6() + @test_throws ["templated$(NS)component1$(NS)p", "not a connector"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort7() + @test_throws ["templated$(NS)component1$(NS)p", "DomainConnector", "RegularConnector"] substitute_component( + templated, templated.component1 => component1) + + @named component1 = BadOnePort8() + @test_throws ["templated$(NS)component1", "subsystem p"] substitute_component( + templated, templated.component1 => component1) +end