diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index 66571fb302..37c8d8d021 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -153,13 +153,25 @@ function change_independent_variable( @set! sys.eqs = [get_eqs(sys); eqs] # add extra equations we derived @set! sys.unknowns = [get_unknowns(sys); [iv1, div2_of_iv1]] # add new variables, will be transformed to e.g. t(u) and uˍt(u) + # A utility function that returns whether var (e.g. f(t)) is a function of iv (e.g. t) + function is_function_of(var, iv) + # Peel off outer calls to find the argument of the function of + if iscall(var) && operation(var) === getindex # handle array variables + var = arguments(var)[1] # (f(t))[1] -> f(t) + end + if iscall(var) + var = only(arguments(var)) # e.g. f(t) -> t + return isequal(var, iv) + end + return false + end + # Create a utility that performs the chain rule on an expression, followed by insertion of the new independent variable: # e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u) function transform(ex::T) where {T} # 1) Replace the argument of every function; e.g. f(t) -> f(u(t)) for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable) - is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)? - if is_function_of_iv1 && !isequal(var, iv2_of_iv1) # prevent e.g. u(t) -> u(u(t)) + if is_function_of(var, iv1) && !isequal(var, iv2_of_iv1) # of the form f(t)? but prevent e.g. u(t) -> u(u(t)) var_of_iv1 = var # e.g. f(t) var_of_iv2_of_iv1 = substitute(var_of_iv1, iv1 => iv2_of_iv1) # e.g. f(u(t)) ex = substitute(ex, var_of_iv1 => var_of_iv2_of_iv1; fold) @@ -207,6 +219,8 @@ function change_independent_variable( connector_type = get_connector_type(sys) assertions = Dict(transform(ass) => msg for (ass, msg) in get_assertions(sys)) wascomplete = iscomplete(sys) # save before reconstructing system + wassplit = is_split(sys) + wasflat = isempty(systems) sys = typeof(sys)( # recreate system with transformed fields eqs, iv2, unknowns, ps; observed, initialization_eqs, parameter_dependencies, defaults, guesses, connector_type, @@ -214,8 +228,6 @@ function change_independent_variable( ) sys = compose(sys, systems) # rebuild hierarchical system if wascomplete - wasflat = isempty(systems) - wassplit = is_split(sys) sys = complete(sys; split = wassplit, flatten = wasflat) # complete output if input was complete end return sys diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index 41923f973c..bb57e27ea5 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -287,3 +287,20 @@ end @test all(isapprox.(sol[ss.t], sol[ss.y]; atol = 1e-10)) @test all(sol[ss.x][2:end] .< sol[ss.x][1]) end + +@testset "Change independent variable with array variables" begin + @variables x(t) y(t) z(t)[1:2] + eqs = [ + D(x) ~ 2, + z ~ ModelingToolkit.scalarize.([sin(y), cos(y)]), + D(y) ~ z[1]^2 + z[2]^2 + ] + @named sys = ODESystem(eqs, t) + sys = complete(sys) + new_sys = change_independent_variable(sys, sys.x; add_old_diff = true) + ss_new_sys = structural_simplify(new_sys; allow_symbolic = true) + u0 = [new_sys.y => 0.5, new_sys.t => 0.0] + prob = ODEProblem(ss_new_sys, u0, (0.0, 0.5), []) + sol = solve(prob, Tsit5(); reltol = 1e-5) + @test sol[new_sys.y][end] ≈ 0.75 +end