Skip to content

Commit 354c436

Browse files
Merge pull request #3219 from AayushSabharwal/as/symbolic-tstops
feat: add symbolic tstops support to `ODESystem`
2 parents 8211ddc + 85ca1a9 commit 354c436

File tree

4 files changed

+103
-6
lines changed

4 files changed

+103
-6
lines changed

src/systems/abstractsystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ for prop in [:eqs
10511051
:split_idxs
10521052
:parent
10531053
:is_dde
1054+
:tstops
10541055
:index_cache
10551056
:is_scalar_noise
10561057
:isscheduled]
@@ -1377,6 +1378,14 @@ function namespace_initialization_equations(
13771378
map(eq -> namespace_equation(eq, sys; ivs), eqs)
13781379
end
13791380

1381+
function namespace_tstops(sys::AbstractSystem)
1382+
tstops = symbolic_tstops(sys)
1383+
isempty(tstops) && return tstops
1384+
map(tstops) do val
1385+
namespace_expr(val, sys)
1386+
end
1387+
end
1388+
13801389
function namespace_equation(eq::Equation,
13811390
sys,
13821391
n = nameof(sys);
@@ -1632,6 +1641,14 @@ function initialization_equations(sys::AbstractSystem)
16321641
end
16331642
end
16341643

1644+
function symbolic_tstops(sys::AbstractSystem)
1645+
tstops = get_tstops(sys)
1646+
systems = get_systems(sys)
1647+
isempty(systems) && return tstops
1648+
tstops = [tstops; reduce(vcat, namespace_tstops.(get_systems(sys)); init = [])]
1649+
return tstops
1650+
end
1651+
16351652
function preface(sys::AbstractSystem)
16361653
has_preface(sys) || return nothing
16371654
pre = get_preface(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,39 @@ function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
754754
DAEFunctionExpr{true}(sys, args...; kwargs...)
755755
end
756756

757+
struct SymbolicTstops{F}
758+
fn::F
759+
end
760+
761+
function (st::SymbolicTstops)(p, tspan)
762+
unique!(sort!(reduce(vcat, st.fn(p..., tspan...))))
763+
end
764+
765+
function SymbolicTstops(
766+
sys::AbstractSystem; eval_expression = false, eval_module = @__MODULE__)
767+
tstops = symbolic_tstops(sys)
768+
isempty(tstops) && return nothing
769+
t0 = gensym(:t0)
770+
t1 = gensym(:t1)
771+
tstops = map(tstops) do val
772+
if is_array_of_symbolics(val) || val isa AbstractArray
773+
collect(val)
774+
else
775+
term(:, t0, unwrap(val), t1; type = AbstractArray{Real})
776+
end
777+
end
778+
rps = reorder_parameters(sys, parameters(sys))
779+
tstops, _ = build_function(tstops,
780+
rps...,
781+
t0,
782+
t1;
783+
expression = Val{true},
784+
wrap_code = wrap_array_vars(sys, tstops; dvs = nothing) .∘
785+
wrap_parameter_dependencies(sys, false))
786+
tstops = eval_or_rgf(tstops; eval_expression, eval_module)
787+
return SymbolicTstops(tstops)
788+
end
789+
757790
"""
758791
```julia
759792
DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
@@ -813,6 +846,11 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
813846
kwargs1 = merge(kwargs1, (callback = cbs,))
814847
end
815848

849+
tstops = SymbolicTstops(sys; eval_expression, eval_module)
850+
if tstops !== nothing
851+
kwargs1 = merge(kwargs1, (; tstops))
852+
end
853+
816854
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
817855
end
818856
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
@@ -839,7 +877,7 @@ end
839877
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
840878
parammap = DiffEqBase.NullParameters();
841879
warn_initialize_determined = true,
842-
check_length = true, kwargs...) where {iip}
880+
check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip}
843881
if !iscomplete(sys)
844882
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
845883
end
@@ -852,8 +890,15 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
852890
differential_vars = map(Base.Fix2(in, diffvars), sts)
853891
kwargs = filter_kwargs(kwargs)
854892

893+
kwargs1 = (;)
894+
895+
tstops = SymbolicTstops(sys; eval_expression, eval_module)
896+
if tstops !== nothing
897+
kwargs1 = merge(kwargs1, (; tstops))
898+
end
899+
855900
DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
856-
kwargs...)
901+
kwargs..., kwargs1...)
857902
end
858903

859904
function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ struct ODESystem <: AbstractODESystem
149149
"""
150150
is_dde::Bool
151151
"""
152+
A list of points to provide to the solver as tstops. Uses the same syntax as discrete
153+
events.
154+
"""
155+
tstops::Vector{Any}
156+
"""
152157
Cache for intermediate tearing state.
153158
"""
154159
tearing_state::Any
@@ -187,7 +192,7 @@ struct ODESystem <: AbstractODESystem
187192
connector_type, preface, cevents,
188193
devents, parameter_dependencies,
189194
metadata = nothing, gui_metadata = nothing, is_dde = false,
190-
tearing_state = nothing,
195+
tstops = [], tearing_state = nothing,
191196
substitutions = nothing, complete = false, index_cache = nothing,
192197
discrete_subsystems = nothing, solved_unknowns = nothing,
193198
split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
@@ -206,7 +211,7 @@ struct ODESystem <: AbstractODESystem
206211
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
207212
initializesystem, initialization_eqs, schedule, connector_type, preface,
208213
cevents, devents, parameter_dependencies, metadata,
209-
gui_metadata, is_dde, tearing_state, substitutions, complete, index_cache,
214+
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
210215
discrete_subsystems, solved_unknowns, split_idxs, parent)
211216
end
212217
end
@@ -233,7 +238,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
233238
checks = true,
234239
metadata = nothing,
235240
gui_metadata = nothing,
236-
is_dde = nothing)
241+
is_dde = nothing,
242+
tstops = [])
237243
name === nothing &&
238244
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
239245
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
@@ -299,7 +305,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
299305
defaults, guesses, nothing, initializesystem,
300306
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
301307
disc_callbacks, parameter_dependencies,
302-
metadata, gui_metadata, is_dde, checks = checks)
308+
metadata, gui_metadata, is_dde, tstops, checks = checks)
303309
end
304310

305311
function ODESystem(eqs, iv; kwargs...)
@@ -402,6 +408,7 @@ function flatten(sys::ODESystem, noeqs = false)
402408
description = description(sys),
403409
initialization_eqs = initialization_equations(sys),
404410
is_dde = is_dde(sys),
411+
tstops = symbolic_tstops(sys),
405412
checks = false)
406413
end
407414
end

test/odesystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,3 +1524,31 @@ end
15241524
sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8)
15251525
@test sol[x]sol[y^2 - sum(p)] atol=1e-5
15261526
end
1527+
1528+
@testset "Symbolic tstops" begin
1529+
@variables x(t) = 1.0
1530+
@parameters p=0.15 q=0.25 r[1:2]=[0.35, 0.45]
1531+
@mtkbuild sys = ODESystem(
1532+
[D(x) ~ p * x + q * t + sum(r)], t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
1533+
prob = ODEProblem(sys, [], (0.0, 5.0))
1534+
sol = solve(prob)
1535+
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1536+
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
1537+
prob2 = remake(prob; tspan = (0.0, 10.0))
1538+
sol2 = solve(prob2)
1539+
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1540+
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
1541+
1542+
@variables y(t) [guess = 1.0]
1543+
@mtkbuild sys = ODESystem([D(x) ~ p * x + q * t + sum(r), y^3 ~ 2x + 1],
1544+
t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
1545+
prob = DAEProblem(
1546+
sys, [D(y) => 2D(x) / 3y^2, D(x) => p * x + q * t + sum(r)], [], (0.0, 5.0))
1547+
sol = solve(prob, DImplicitEuler())
1548+
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1549+
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
1550+
prob2 = remake(prob; tspan = (0.0, 10.0))
1551+
sol2 = solve(prob2, DImplicitEuler())
1552+
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1553+
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
1554+
end

0 commit comments

Comments
 (0)