Skip to content

Commit 06445fa

Browse files
Merge pull request #2945 from AayushSabharwal/as/callback-split
refactor: store discrete portion in BlockedArray for efficiency, better handle event variables
2 parents 2e35294 + 996907e commit 06445fa

File tree

11 files changed

+313
-264
lines changed

11 files changed

+313
-264
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "9.33.1"
66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
910
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1011
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1112
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -69,6 +70,7 @@ MTKDeepDiffsExt = "DeepDiffs"
6970
AbstractTrees = "0.3, 0.4"
7071
ArrayInterface = "6, 7"
7172
BifurcationKit = "0.3"
73+
BlockArrays = "1.1"
7274
Combinatorics = "1"
7375
Compat = "3.42, 4"
7476
ConstructionBase = "1"

docs/src/basics/Events.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,45 @@ one must still use a vector
336336
```julia
337337
discrete_events = [[2.0] => [v ~ -v]]
338338
```
339+
340+
## Saving discrete values
341+
342+
Time-dependent parameters which are updated in callbacks are termed as discrete variables.
343+
ModelingToolkit enables automatically saving the timeseries of these discrete variables,
344+
and indexing the solution object to obtain the saved timeseries. Consider the following
345+
example:
346+
347+
```@example events
348+
@variables x(t)
349+
@parameters c(t)
350+
351+
@mtkbuild sys = ODESystem(
352+
D(x) ~ c * cos(x), t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
353+
354+
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
355+
sol = solve(prob, Tsit5())
356+
sol[c]
357+
```
358+
359+
The solution object can also be interpolated with the discrete variables
360+
361+
```@example events
362+
sol([1.0, 2.0], idxs = [c, c * cos(x)])
363+
```
364+
365+
Note that only time-dependent parameters will be saved. If we repeat the above example with
366+
this change:
367+
368+
```@example events
369+
@variables x(t)
370+
@parameters c
371+
372+
@mtkbuild sys = ODESystem(
373+
D(x) ~ c * cos(x), t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
374+
375+
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
376+
sol = solve(prob, Tsit5())
377+
sol.ps[c] # sol[c] will error, since `c` is not a timeseries value
378+
```
379+
380+
It can be seen that the timeseries for `c` is not saved.

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ using NonlinearSolve
5252
using Reexport
5353
using RecursiveArrayTools
5454
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
55+
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
5556

5657
using RuntimeGeneratedFunctions
5758
using RuntimeGeneratedFunctions: drop_expr

src/systems/abstractsystem.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -601,17 +601,10 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
601601
return if sym isa ParameterIndex
602602
sym
603603
elseif (idx = parameter_index(ic, sym)) !== nothing
604-
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
605-
return nothing
606-
else
607-
idx
608-
end
604+
idx
609605
elseif iscall(sym) && operation(sym) === getindex &&
610606
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
611-
if idx.portion isa SciMLStructures.Discrete &&
612-
idx.idx[2] == idx.idx[3] == nothing
613-
return nothing
614-
elseif idx.portion isa SciMLStructures.Tunable
607+
if idx.portion isa SciMLStructures.Tunable
615608
return ParameterIndex(
616609
idx.portion, idx.idx[arguments(sym)[(begin + 1):end]...])
617610
else

src/systems/callbacks.jl

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,27 @@ function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrato
387387
end
388388
end
389389

390+
function callback_save_header(sys::AbstractSystem, cb)
391+
if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
392+
return (identity, identity)
393+
end
394+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
395+
isempty(save_idxs) && return (identity, identity)
396+
397+
wrapper = function (expr)
398+
return Func(expr.args, [],
399+
LiteralExpr(quote
400+
$(expr.body)
401+
save_idxs = $(save_idxs)
402+
for idx in save_idxs
403+
$(SciMLBase.save_discretes!)($(expr.args[1]), idx)
404+
end
405+
end))
406+
end
407+
408+
return wrapper, wrapper
409+
end
410+
390411
"""
391412
compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
392413
@@ -421,7 +442,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
421442
end
422443

423444
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
424-
compile_affect(affects(cb), args...; kwargs...)
445+
compile_affect(affects(cb), cb, args...; kwargs...)
425446
end
426447

427448
"""
@@ -441,7 +462,7 @@ Notes
441462
well-formed.
442463
- `kwargs` are passed through to `Symbolics.build_function`.
443464
"""
444-
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
465+
function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = nothing,
445466
expression = Val{true}, checkvars = true, eval_expression = false,
446467
eval_module = @__MODULE__,
447468
postprocess_affect_expr! = nothing, kwargs...)
@@ -497,7 +518,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497518
integ = gensym(:MTKIntegrator)
498519
pre = get_preprocess_constants(rhss)
499520
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
500-
wrap_code = add_integrator_header(sys, integ, outvar) .∘
521+
wrap_code = callback_save_header(sys, cb) .∘
522+
add_integrator_header(sys, integ, outvar) .∘
501523
wrap_array_vars(sys, rhss; dvs, ps = _ps) .∘
502524
wrap_parameter_dependencies(sys, false),
503525
outputidxs = update_inds,
@@ -606,14 +628,14 @@ Compile a single continuous callback affect function(s).
606628
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
607629
eq_aff = affects(cb)
608630
eq_neg_aff = affect_negs(cb)
609-
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
631+
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
610632
if eq_neg_aff === eq_aff
611633
affect_neg = affect
612634
elseif isnothing(eq_neg_aff)
613635
affect_neg = nothing
614636
else
615637
affect_neg = compile_affect(
616-
eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
638+
eq_neg_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
617639
end
618640
(affect = affect, affect_neg = affect_neg)
619641
end
@@ -657,7 +679,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
657679
end
658680
end
659681

660-
function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
682+
function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
661683
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
662684
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
663685

@@ -679,21 +701,31 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
679701
p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds))) |>
680702
NamedTuple
681703

682-
let u = u, p = p, user_affect = func(affect), ctx = context(affect)
704+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
705+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
706+
else
707+
save_idxs = Int[]
708+
end
709+
let u = u, p = p, user_affect = func(affect), ctx = context(affect),
710+
save_idxs = save_idxs
711+
683712
function (integ)
684713
user_affect(integ, u, p, ctx)
714+
for idx in save_idxs
715+
SciMLBase.save_discretes!(integ, idx)
716+
end
685717
end
686718
end
687719
end
688720

689-
function compile_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
690-
compile_user_affect(affect, sys, dvs, ps; kwargs...)
721+
function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
722+
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
691723
end
692724

693725
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
694726
kwargs...)
695727
cond = condition(cb)
696-
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
728+
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
697729
postprocess_affect_expr!, kwargs...)
698730
if cond isa AbstractVector
699731
# Preset Time
@@ -711,7 +743,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711743
kwargs...)
712744
else
713745
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
714-
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
746+
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
715747
postprocess_affect_expr!, kwargs...)
716748
return DiscreteCallback(c, as)
717749
end

0 commit comments

Comments
 (0)