Skip to content

Commit 365f5ed

Browse files
committed
handle connections in change_independent_variable
1 parent fafb47b commit 365f5ed

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

src/systems/diffeqs/basic_transformations.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ function liouville_transform(sys::AbstractODESystem; kwargs...)
5353
)
5454
end
5555

56+
function split_eqs_connections(eqs_in::Vector{<:Equation})
57+
eqs = Equation[]
58+
cons = Equation[]
59+
60+
for eq in eqs_in
61+
eq.lhs isa Connection ? push!(cons, eq) : push!(eqs, eq)
62+
end
63+
64+
return eqs, cons
65+
end
66+
5667
"""
5768
change_independent_variable(
5869
sys::AbstractODESystem, iv, eqs = [];
@@ -158,7 +169,7 @@ function change_independent_variable(
158169
function transform(ex::T) where {T}
159170
# 1) Replace the argument of every function; e.g. f(t) -> f(u(t))
160171
for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable)
161-
is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)?
172+
is_function_of_iv1 = iscall(var) && isequal(first(arguments(var)), iv1) # of the form f(t)?
162173
if is_function_of_iv1 && !isequal(var, iv2_of_iv1) # prevent e.g. u(t) -> u(u(t))
163174
var_of_iv1 = var # e.g. f(t)
164175
var_of_iv2_of_iv1 = substitute(var_of_iv1, iv1 => iv2_of_iv1) # e.g. f(u(t))
@@ -178,9 +189,22 @@ function change_independent_variable(
178189
return ex::T
179190
end
180191

192+
# overload to specifically handle equations, which can be an equation of a connection
193+
function transform(eq::Equation, systems_map)
194+
if eq.rhs isa Connection
195+
eq = connect((systems_map[nameof(s)] for s in eq.rhs.systems)...)
196+
else
197+
eq = transform(eq)
198+
end
199+
return eq::Equation
200+
end
201+
181202
# Use the utility function to transform everything in the system!
182203
function transform(sys::AbstractODESystem)
183-
eqs = map(transform, get_eqs(sys))
204+
systems = map(transform, get_systems(sys)) # recurse through subsystems
205+
# transform equations and connections
206+
systems_map = Dict(get_name(s) => s for s in systems)
207+
eqs = map(eq -> transform(eq, systems_map)::Equation, get_eqs(sys))
184208
unknowns = map(transform, get_unknowns(sys))
185209
unknowns = filter(var -> !isequal(var, iv2), unknowns) # remove e.g. u
186210
ps = map(transform, get_ps(sys))
@@ -191,19 +215,19 @@ function change_independent_variable(
191215
defaults = Dict(transform(var) => transform(val)
192216
for (var, val) in get_defaults(sys))
193217
guesses = Dict(transform(var) => transform(val) for (var, val) in get_guesses(sys))
218+
connector_type = get_connector_type(sys)
194219
assertions = Dict(transform(ass) => msg for (ass, msg) in get_assertions(sys))
195-
systems = get_systems(sys) # save before reconstructing system
196220
wascomplete = iscomplete(sys) # save before reconstructing system
197221
sys = typeof(sys)( # recreate system with transformed fields
198222
eqs, iv2, unknowns, ps; observed, initialization_eqs,
199-
parameter_dependencies, defaults, guesses,
223+
parameter_dependencies, defaults, guesses, connector_type,
200224
assertions, name = nameof(sys), description = description(sys)
201225
)
202-
systems = map(transform, systems) # recurse through subsystems
203226
sys = compose(sys, systems) # rebuild hierarchical system
204227
if wascomplete
205228
wasflat = isempty(systems)
206-
sys = complete(sys; flatten = wasflat) # complete output if input was complete
229+
wassplit = is_split(sys)
230+
sys = complete(sys; split = wassplit, flatten = wasflat) # complete output if input was complete
207231
end
208232
return sys
209233
end

test/basic_transformations.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit, OrdinaryDiffEq, DataInterpolations, DynamicQuantities, Test
2+
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput
23

34
@independent_variables t
45
D = Differential(t)
@@ -259,3 +260,30 @@ end
259260
nested_input_sys = complete(nested_input_sys; flatten = false)
260261
@test change_independent_variable(nested_input_sys, nested_input_sys.x) isa ODESystem
261262
end
263+
264+
@testset "Change of variables, connections" begin
265+
@mtkmodel ConnectSys begin
266+
@components begin
267+
in = RealInput()
268+
out = RealOutput()
269+
end
270+
@variables begin
271+
x(t)
272+
y(t)
273+
end
274+
@equations begin
275+
connect(in, out)
276+
in.u ~ x
277+
D(x) ~ -out.u
278+
D(y) ~ 1
279+
end
280+
end
281+
@named sys = ConnectSys()
282+
sys = complete(sys; flatten = false)
283+
new_sys = change_independent_variable(sys, sys.y; add_old_diff = true)
284+
ss = structural_simplify(new_sys; allow_symbolic = true)
285+
prob = ODEProblem(ss, [ss.t => 0.0, ss.x => 1.0], (0.0, 1.0))
286+
sol = solve(prob, Tsit5(); reltol = 1e-5)
287+
@test all(isapprox.(sol[ss.t], sol[ss.y]; atol = 1e-10))
288+
@test all(sol[ss.x][2:end] .< sol[ss.x][1])
289+
end

0 commit comments

Comments
 (0)