From 85ca1a9f8cb8a426cab4302e0517e484951d19df Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Nov 2024 14:04:14 +0530 Subject: [PATCH] feat: add symbolic tstops support to `ODESystem` --- src/systems/abstractsystem.jl | 17 ++++++++ src/systems/diffeqs/abstractodesystem.jl | 49 +++++++++++++++++++++++- src/systems/diffeqs/odesystem.jl | 15 ++++++-- test/odesystem.jl | 28 ++++++++++++++ 4 files changed, 103 insertions(+), 6 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index db2fcb4f10..8f021c9a1d 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1051,6 +1051,7 @@ for prop in [:eqs :split_idxs :parent :is_dde + :tstops :index_cache :is_scalar_noise :isscheduled] @@ -1377,6 +1378,14 @@ function namespace_initialization_equations( map(eq -> namespace_equation(eq, sys; ivs), eqs) end +function namespace_tstops(sys::AbstractSystem) + tstops = symbolic_tstops(sys) + isempty(tstops) && return tstops + map(tstops) do val + namespace_expr(val, sys) + end +end + function namespace_equation(eq::Equation, sys, n = nameof(sys); @@ -1632,6 +1641,14 @@ function initialization_equations(sys::AbstractSystem) end end +function symbolic_tstops(sys::AbstractSystem) + tstops = get_tstops(sys) + systems = get_systems(sys) + isempty(systems) && return tstops + tstops = [tstops; reduce(vcat, namespace_tstops.(get_systems(sys)); init = [])] + return tstops +end + function preface(sys::AbstractSystem) has_preface(sys) || return nothing pre = get_preface(sys) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index ca69c038d9..64a24b8ba3 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -758,6 +758,39 @@ function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...) DAEFunctionExpr{true}(sys, args...; kwargs...) end +struct SymbolicTstops{F} + fn::F +end + +function (st::SymbolicTstops)(p, tspan) + unique!(sort!(reduce(vcat, st.fn(p..., tspan...)))) +end + +function SymbolicTstops( + sys::AbstractSystem; eval_expression = false, eval_module = @__MODULE__) + tstops = symbolic_tstops(sys) + isempty(tstops) && return nothing + t0 = gensym(:t0) + t1 = gensym(:t1) + tstops = map(tstops) do val + if is_array_of_symbolics(val) || val isa AbstractArray + collect(val) + else + term(:, t0, unwrap(val), t1; type = AbstractArray{Real}) + end + end + rps = reorder_parameters(sys, parameters(sys)) + tstops, _ = build_function(tstops, + rps..., + t0, + t1; + expression = Val{true}, + wrap_code = wrap_array_vars(sys, tstops; dvs = nothing) .∘ + wrap_parameter_dependencies(sys, false)) + tstops = eval_or_rgf(tstops; eval_expression, eval_module) + return SymbolicTstops(tstops) +end + """ ```julia DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan, @@ -817,6 +850,11 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = kwargs1 = merge(kwargs1, (callback = cbs,)) end + tstops = SymbolicTstops(sys; eval_expression, eval_module) + if tstops !== nothing + kwargs1 = merge(kwargs1, (; tstops)) + end + return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] @@ -843,7 +881,7 @@ end function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan, parammap = DiffEqBase.NullParameters(); warn_initialize_determined = true, - check_length = true, kwargs...) where {iip} + check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`") end @@ -856,8 +894,15 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan differential_vars = map(Base.Fix2(in, diffvars), sts) kwargs = filter_kwargs(kwargs) + kwargs1 = (;) + + tstops = SymbolicTstops(sys; eval_expression, eval_module) + if tstops !== nothing + kwargs1 = merge(kwargs1, (; tstops)) + end + DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars, - kwargs...) + kwargs..., kwargs1...) end function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index b8dee2bac7..2b0bd8c8d7 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -149,6 +149,11 @@ struct ODESystem <: AbstractODESystem """ is_dde::Bool """ + A list of points to provide to the solver as tstops. Uses the same syntax as discrete + events. + """ + tstops::Vector{Any} + """ Cache for intermediate tearing state. """ tearing_state::Any @@ -187,7 +192,7 @@ struct ODESystem <: AbstractODESystem connector_type, preface, cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, is_dde = false, - tearing_state = nothing, + tstops = [], tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, discrete_subsystems = nothing, solved_unknowns = nothing, split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true) @@ -206,7 +211,7 @@ struct ODESystem <: AbstractODESystem ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching, initializesystem, initialization_eqs, schedule, connector_type, preface, cevents, devents, parameter_dependencies, metadata, - gui_metadata, is_dde, tearing_state, substitutions, complete, index_cache, + gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache, discrete_subsystems, solved_unknowns, split_idxs, parent) end end @@ -233,7 +238,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; checks = true, metadata = nothing, gui_metadata = nothing, - is_dde = nothing) + is_dde = nothing, + tstops = []) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) @assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters." @@ -299,7 +305,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; defaults, guesses, nothing, initializesystem, initialization_eqs, schedule, connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies, - metadata, gui_metadata, is_dde, checks = checks) + metadata, gui_metadata, is_dde, tstops, checks = checks) end function ODESystem(eqs, iv; kwargs...) @@ -402,6 +408,7 @@ function flatten(sys::ODESystem, noeqs = false) description = description(sys), initialization_eqs = initialization_equations(sys), is_dde = is_dde(sys), + tstops = symbolic_tstops(sys), checks = false) end end diff --git a/test/odesystem.jl b/test/odesystem.jl index 90d1c4c578..7f292798df 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1524,3 +1524,31 @@ end sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8) @test sol[x]≈sol[y^2 - sum(p)] atol=1e-5 end + +@testset "Symbolic tstops" begin + @variables x(t) = 1.0 + @parameters p=0.15 q=0.25 r[1:2]=[0.35, 0.45] + @mtkbuild sys = ODESystem( + [D(x) ~ p * x + q * t + sum(r)], t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r]) + prob = ODEProblem(sys, [], (0.0, 5.0)) + sol = solve(prob) + expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45))) + @test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops) + prob2 = remake(prob; tspan = (0.0, 10.0)) + sol2 = solve(prob2) + expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45))) + @test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops) + + @variables y(t) [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ p * x + q * t + sum(r), y^3 ~ 2x + 1], + t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r]) + prob = DAEProblem( + sys, [D(y) => 2D(x) / 3y^2, D(x) => p * x + q * t + sum(r)], [], (0.0, 5.0)) + sol = solve(prob, DImplicitEuler()) + expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45))) + @test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops) + prob2 = remake(prob; tspan = (0.0, 10.0)) + sol2 = solve(prob2, DImplicitEuler()) + expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45))) + @test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops) +end