Skip to content

Commit 41e801e

Browse files
refactor: more efficient discrete portion, better handling of callback params
1 parent 874985c commit 41e801e

File tree

10 files changed

+266
-264
lines changed

10 files changed

+266
-264
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "9.33.1"
66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
910
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1011
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1112
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -67,6 +68,7 @@ MTKDeepDiffsExt = "DeepDiffs"
6768
AbstractTrees = "0.3, 0.4"
6869
ArrayInterface = "6, 7"
6970
BifurcationKit = "0.3"
71+
BlockArrays = "1.1"
7072
Combinatorics = "1"
7173
Compat = "3.42, 4"
7274
ConstructionBase = "1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ using NonlinearSolve
5252
using Reexport
5353
using RecursiveArrayTools
5454
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
55+
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
5556

5657
using RuntimeGeneratedFunctions
5758
using RuntimeGeneratedFunctions: drop_expr

src/systems/abstractsystem.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -601,17 +601,10 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
601601
return if sym isa ParameterIndex
602602
sym
603603
elseif (idx = parameter_index(ic, sym)) !== nothing
604-
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
605-
return nothing
606-
else
607-
idx
608-
end
604+
idx
609605
elseif iscall(sym) && operation(sym) === getindex &&
610606
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
611-
if idx.portion isa SciMLStructures.Discrete &&
612-
idx.idx[2] == idx.idx[3] == nothing
613-
return nothing
614-
elseif idx.portion isa SciMLStructures.Tunable
607+
if idx.portion isa SciMLStructures.Tunable
615608
return ParameterIndex(
616609
idx.portion, idx.idx[arguments(sym)[(begin + 1):end]...])
617610
else

src/systems/callbacks.jl

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,26 @@ function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrato
387387
end
388388
end
389389

390+
function callback_save_header(sys::AbstractSystem, cb)
391+
if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
392+
return (identity, identity)
393+
end
394+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
395+
isempty(save_idxs) && return (identity, identity)
396+
397+
wrapper = function(expr)
398+
return Func(expr.args, [], LiteralExpr(quote
399+
$(expr.body)
400+
save_idxs = $(save_idxs)
401+
for idx in save_idxs
402+
$(SciMLBase.save_discretes!)($(expr.args[1]), idx)
403+
end
404+
end))
405+
end
406+
407+
return wrapper, wrapper
408+
end
409+
390410
"""
391411
compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
392412
@@ -421,7 +441,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
421441
end
422442

423443
function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
424-
compile_affect(affects(cb), args...; kwargs...)
444+
compile_affect(affects(cb), cb, args...; kwargs...)
425445
end
426446

427447
"""
@@ -441,7 +461,7 @@ Notes
441461
well-formed.
442462
- `kwargs` are passed through to `Symbolics.build_function`.
443463
"""
444-
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
464+
function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = nothing,
445465
expression = Val{true}, checkvars = true, eval_expression = false,
446466
eval_module = @__MODULE__,
447467
postprocess_affect_expr! = nothing, kwargs...)
@@ -497,7 +517,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497517
integ = gensym(:MTKIntegrator)
498518
pre = get_preprocess_constants(rhss)
499519
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
500-
wrap_code = add_integrator_header(sys, integ, outvar) .∘
520+
wrap_code = callback_save_header(sys, cb) .∘
521+
add_integrator_header(sys, integ, outvar) .∘
501522
wrap_array_vars(sys, rhss; dvs, ps = _ps) .∘
502523
wrap_parameter_dependencies(sys, false),
503524
outputidxs = update_inds,
@@ -606,14 +627,14 @@ Compile a single continuous callback affect function(s).
606627
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
607628
eq_aff = affects(cb)
608629
eq_neg_aff = affect_negs(cb)
609-
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
630+
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
610631
if eq_neg_aff === eq_aff
611632
affect_neg = affect
612633
elseif isnothing(eq_neg_aff)
613634
affect_neg = nothing
614635
else
615636
affect_neg = compile_affect(
616-
eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
637+
eq_neg_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
617638
end
618639
(affect = affect, affect_neg = affect_neg)
619640
end
@@ -657,7 +678,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
657678
end
658679
end
659680

660-
function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
681+
function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
661682
dvs_ind = Dict(reverse(en) for en in enumerate(dvs))
662683
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))
663684

@@ -679,21 +700,29 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
679700
p = filter(x -> !isnothing(x[2]), collect(zip(parameters_syms(affect), p_inds))) |>
680701
NamedTuple
681702

682-
let u = u, p = p, user_affect = func(affect), ctx = context(affect)
703+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
704+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
705+
else
706+
save_idxs = Int[]
707+
end
708+
let u = u, p = p, user_affect = func(affect), ctx = context(affect), save_idxs = save_idxs
683709
function (integ)
684710
user_affect(integ, u, p, ctx)
711+
for idx in save_idxs
712+
SciMLBase.save_discretes!(integ, idx)
713+
end
685714
end
686715
end
687716
end
688717

689-
function compile_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
690-
compile_user_affect(affect, sys, dvs, ps; kwargs...)
718+
function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
719+
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
691720
end
692721

693722
function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = nothing,
694723
kwargs...)
695724
cond = condition(cb)
696-
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
725+
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
697726
postprocess_affect_expr!, kwargs...)
698727
if cond isa AbstractVector
699728
# Preset Time
@@ -711,7 +740,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711740
kwargs...)
712741
else
713742
c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...)
714-
as = compile_affect(affects(cb), sys, dvs, ps; expression = Val{false},
743+
as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false},
715744
postprocess_affect_expr!, kwargs...)
716745
return DiscreteCallback(c, as)
717746
end

0 commit comments

Comments
 (0)