Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/systems/diffeqs/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,25 @@ function change_independent_variable(
@set! sys.eqs = [get_eqs(sys); eqs] # add extra equations we derived
@set! sys.unknowns = [get_unknowns(sys); [iv1, div2_of_iv1]] # add new variables, will be transformed to e.g. t(u) and uˍt(u)

# A utility function that returns whether var (e.g. f(t)) is a function of iv (e.g. t)
function is_function_of(var, iv)
# Peel off outer calls to find the argument of the function of
if iscall(var) && operation(var) === getindex # handle array variables
var = arguments(var)[1] # (f(t))[1] -> f(t)
end
if iscall(var)
var = only(arguments(var)) # e.g. f(t) -> t
return isequal(var, iv)
end
return false
end

# Create a utility that performs the chain rule on an expression, followed by insertion of the new independent variable:
# e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u)
function transform(ex::T) where {T}
# 1) Replace the argument of every function; e.g. f(t) -> f(u(t))
for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable)
is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)?
if is_function_of_iv1 && !isequal(var, iv2_of_iv1) # prevent e.g. u(t) -> u(u(t))
if is_function_of(var, iv1) && !isequal(var, iv2_of_iv1) # of the form f(t)? but prevent e.g. u(t) -> u(u(t))
var_of_iv1 = var # e.g. f(t)
var_of_iv2_of_iv1 = substitute(var_of_iv1, iv1 => iv2_of_iv1) # e.g. f(u(t))
ex = substitute(ex, var_of_iv1 => var_of_iv2_of_iv1; fold)
Expand Down Expand Up @@ -207,15 +219,15 @@ function change_independent_variable(
connector_type = get_connector_type(sys)
assertions = Dict(transform(ass) => msg for (ass, msg) in get_assertions(sys))
wascomplete = iscomplete(sys) # save before reconstructing system
wassplit = is_split(sys)
wasflat = isempty(systems)
sys = typeof(sys)( # recreate system with transformed fields
eqs, iv2, unknowns, ps; observed, initialization_eqs,
parameter_dependencies, defaults, guesses, connector_type,
assertions, name = nameof(sys), description = description(sys)
)
sys = compose(sys, systems) # rebuild hierarchical system
if wascomplete
wasflat = isempty(systems)
wassplit = is_split(sys)
sys = complete(sys; split = wassplit, flatten = wasflat) # complete output if input was complete
end
return sys
Expand Down
17 changes: 17 additions & 0 deletions test/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,20 @@ end
@test all(isapprox.(sol[ss.t], sol[ss.y]; atol = 1e-10))
@test all(sol[ss.x][2:end] .< sol[ss.x][1])
end

@testset "Change independent variable with array variables" begin
@variables x(t) y(t) z(t)[1:2]
eqs = [
D(x) ~ 2,
z ~ ModelingToolkit.scalarize.([sin(y), cos(y)]),
D(y) ~ z[1]^2 + z[2]^2
]
@named sys = ODESystem(eqs, t)
sys = complete(sys)
new_sys = change_independent_variable(sys, sys.x; add_old_diff = true)
ss_new_sys = structural_simplify(new_sys; allow_symbolic = true)
u0 = [new_sys.y => 0.5, new_sys.t => 0.0]
prob = ODEProblem(ss_new_sys, u0, (0.0, 0.5), [])
sol = solve(prob, Tsit5(); reltol = 1e-5)
@test sol[new_sys.y][end] ≈ 0.75
end
Loading