Skip to content

Commit 85ca1a9

Browse files
feat: add symbolic tstops support to ODESystem
1 parent 57e1a43 commit 85ca1a9

File tree

4 files changed

+103
-6
lines changed

4 files changed

+103
-6
lines changed

src/systems/abstractsystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ for prop in [:eqs
10511051
:split_idxs
10521052
:parent
10531053
:is_dde
1054+
:tstops
10541055
:index_cache
10551056
:is_scalar_noise
10561057
:isscheduled]
@@ -1377,6 +1378,14 @@ function namespace_initialization_equations(
13771378
map(eq -> namespace_equation(eq, sys; ivs), eqs)
13781379
end
13791380

1381+
function namespace_tstops(sys::AbstractSystem)
1382+
tstops = symbolic_tstops(sys)
1383+
isempty(tstops) && return tstops
1384+
map(tstops) do val
1385+
namespace_expr(val, sys)
1386+
end
1387+
end
1388+
13801389
function namespace_equation(eq::Equation,
13811390
sys,
13821391
n = nameof(sys);
@@ -1632,6 +1641,14 @@ function initialization_equations(sys::AbstractSystem)
16321641
end
16331642
end
16341643

1644+
function symbolic_tstops(sys::AbstractSystem)
1645+
tstops = get_tstops(sys)
1646+
systems = get_systems(sys)
1647+
isempty(systems) && return tstops
1648+
tstops = [tstops; reduce(vcat, namespace_tstops.(get_systems(sys)); init = [])]
1649+
return tstops
1650+
end
1651+
16351652
function preface(sys::AbstractSystem)
16361653
has_preface(sys) || return nothing
16371654
pre = get_preface(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,39 @@ function DAEFunctionExpr(sys::AbstractODESystem, args...; kwargs...)
758758
DAEFunctionExpr{true}(sys, args...; kwargs...)
759759
end
760760

761+
struct SymbolicTstops{F}
762+
fn::F
763+
end
764+
765+
function (st::SymbolicTstops)(p, tspan)
766+
unique!(sort!(reduce(vcat, st.fn(p..., tspan...))))
767+
end
768+
769+
function SymbolicTstops(
770+
sys::AbstractSystem; eval_expression = false, eval_module = @__MODULE__)
771+
tstops = symbolic_tstops(sys)
772+
isempty(tstops) && return nothing
773+
t0 = gensym(:t0)
774+
t1 = gensym(:t1)
775+
tstops = map(tstops) do val
776+
if is_array_of_symbolics(val) || val isa AbstractArray
777+
collect(val)
778+
else
779+
term(:, t0, unwrap(val), t1; type = AbstractArray{Real})
780+
end
781+
end
782+
rps = reorder_parameters(sys, parameters(sys))
783+
tstops, _ = build_function(tstops,
784+
rps...,
785+
t0,
786+
t1;
787+
expression = Val{true},
788+
wrap_code = wrap_array_vars(sys, tstops; dvs = nothing) .∘
789+
wrap_parameter_dependencies(sys, false))
790+
tstops = eval_or_rgf(tstops; eval_expression, eval_module)
791+
return SymbolicTstops(tstops)
792+
end
793+
761794
"""
762795
```julia
763796
DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
@@ -817,6 +850,11 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
817850
kwargs1 = merge(kwargs1, (callback = cbs,))
818851
end
819852

853+
tstops = SymbolicTstops(sys; eval_expression, eval_module)
854+
if tstops !== nothing
855+
kwargs1 = merge(kwargs1, (; tstops))
856+
end
857+
820858
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
821859
end
822860
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
@@ -843,7 +881,7 @@ end
843881
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
844882
parammap = DiffEqBase.NullParameters();
845883
warn_initialize_determined = true,
846-
check_length = true, kwargs...) where {iip}
884+
check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip}
847885
if !iscomplete(sys)
848886
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
849887
end
@@ -856,8 +894,15 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
856894
differential_vars = map(Base.Fix2(in, diffvars), sts)
857895
kwargs = filter_kwargs(kwargs)
858896

897+
kwargs1 = (;)
898+
899+
tstops = SymbolicTstops(sys; eval_expression, eval_module)
900+
if tstops !== nothing
901+
kwargs1 = merge(kwargs1, (; tstops))
902+
end
903+
859904
DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
860-
kwargs...)
905+
kwargs..., kwargs1...)
861906
end
862907

863908
function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ struct ODESystem <: AbstractODESystem
149149
"""
150150
is_dde::Bool
151151
"""
152+
A list of points to provide to the solver as tstops. Uses the same syntax as discrete
153+
events.
154+
"""
155+
tstops::Vector{Any}
156+
"""
152157
Cache for intermediate tearing state.
153158
"""
154159
tearing_state::Any
@@ -187,7 +192,7 @@ struct ODESystem <: AbstractODESystem
187192
connector_type, preface, cevents,
188193
devents, parameter_dependencies,
189194
metadata = nothing, gui_metadata = nothing, is_dde = false,
190-
tearing_state = nothing,
195+
tstops = [], tearing_state = nothing,
191196
substitutions = nothing, complete = false, index_cache = nothing,
192197
discrete_subsystems = nothing, solved_unknowns = nothing,
193198
split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
@@ -206,7 +211,7 @@ struct ODESystem <: AbstractODESystem
206211
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
207212
initializesystem, initialization_eqs, schedule, connector_type, preface,
208213
cevents, devents, parameter_dependencies, metadata,
209-
gui_metadata, is_dde, tearing_state, substitutions, complete, index_cache,
214+
gui_metadata, is_dde, tstops, tearing_state, substitutions, complete, index_cache,
210215
discrete_subsystems, solved_unknowns, split_idxs, parent)
211216
end
212217
end
@@ -233,7 +238,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
233238
checks = true,
234239
metadata = nothing,
235240
gui_metadata = nothing,
236-
is_dde = nothing)
241+
is_dde = nothing,
242+
tstops = [])
237243
name === nothing &&
238244
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
239245
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
@@ -299,7 +305,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
299305
defaults, guesses, nothing, initializesystem,
300306
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
301307
disc_callbacks, parameter_dependencies,
302-
metadata, gui_metadata, is_dde, checks = checks)
308+
metadata, gui_metadata, is_dde, tstops, checks = checks)
303309
end
304310

305311
function ODESystem(eqs, iv; kwargs...)
@@ -402,6 +408,7 @@ function flatten(sys::ODESystem, noeqs = false)
402408
description = description(sys),
403409
initialization_eqs = initialization_equations(sys),
404410
is_dde = is_dde(sys),
411+
tstops = symbolic_tstops(sys),
405412
checks = false)
406413
end
407414
end

test/odesystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,3 +1524,31 @@ end
15241524
sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8)
15251525
@test sol[x]sol[y^2 - sum(p)] atol=1e-5
15261526
end
1527+
1528+
@testset "Symbolic tstops" begin
1529+
@variables x(t) = 1.0
1530+
@parameters p=0.15 q=0.25 r[1:2]=[0.35, 0.45]
1531+
@mtkbuild sys = ODESystem(
1532+
[D(x) ~ p * x + q * t + sum(r)], t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
1533+
prob = ODEProblem(sys, [], (0.0, 5.0))
1534+
sol = solve(prob)
1535+
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1536+
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
1537+
prob2 = remake(prob; tspan = (0.0, 10.0))
1538+
sol2 = solve(prob2)
1539+
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1540+
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
1541+
1542+
@variables y(t) [guess = 1.0]
1543+
@mtkbuild sys = ODESystem([D(x) ~ p * x + q * t + sum(r), y^3 ~ 2x + 1],
1544+
t; tstops = [0.5p, [0.1, 0.2], [p + 2q], r])
1545+
prob = DAEProblem(
1546+
sys, [D(y) => 2D(x) / 3y^2, D(x) => p * x + q * t + sum(r)], [], (0.0, 5.0))
1547+
sol = solve(prob, DImplicitEuler())
1548+
expected_tstops = unique!(sort!(vcat(0.0:0.075:5.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1549+
@test all(x -> any(isapprox(x, atol = 1e-6), sol.t), expected_tstops)
1550+
prob2 = remake(prob; tspan = (0.0, 10.0))
1551+
sol2 = solve(prob2, DImplicitEuler())
1552+
expected_tstops = unique!(sort!(vcat(0.0:0.075:10.0, 0.1, 0.2, 0.65, 0.35, 0.45)))
1553+
@test all(x -> any(isapprox(x, atol = 1e-6), sol2.t), expected_tstops)
1554+
end

0 commit comments

Comments
 (0)