@@ -387,6 +387,26 @@ function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrato
387
387
end
388
388
end
389
389
390
+ function callback_save_header (sys:: AbstractSystem , cb)
391
+ if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
392
+ return (identity, identity)
393
+ end
394
+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
395
+ isempty (save_idxs) && return (identity, identity)
396
+
397
+ wrapper = function (expr)
398
+ return Func (expr. args, [], LiteralExpr (quote
399
+ $ (expr. body)
400
+ save_idxs = $ (save_idxs)
401
+ for idx in save_idxs
402
+ $ (SciMLBase. save_discretes!)($ (expr. args[1 ]), idx)
403
+ end
404
+ end ))
405
+ end
406
+
407
+ return wrapper, wrapper
408
+ end
409
+
390
410
"""
391
411
compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
392
412
@@ -421,7 +441,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
421
441
end
422
442
423
443
function compile_affect (cb:: SymbolicContinuousCallback , args... ; kwargs... )
424
- compile_affect (affects (cb), args... ; kwargs... )
444
+ compile_affect (affects (cb), cb, args... ; kwargs... )
425
445
end
426
446
427
447
"""
@@ -441,7 +461,7 @@ Notes
441
461
well-formed.
442
462
- `kwargs` are passed through to `Symbolics.build_function`.
443
463
"""
444
- function compile_affect (eqs:: Vector{Equation} , sys, dvs, ps; outputidxs = nothing ,
464
+ function compile_affect (eqs:: Vector{Equation} , cb, sys, dvs, ps; outputidxs = nothing ,
445
465
expression = Val{true }, checkvars = true , eval_expression = false ,
446
466
eval_module = @__MODULE__ ,
447
467
postprocess_affect_expr! = nothing , kwargs... )
@@ -497,7 +517,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497
517
integ = gensym (:MTKIntegrator )
498
518
pre = get_preprocess_constants (rhss)
499
519
rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = Val{true },
500
- wrap_code = add_integrator_header (sys, integ, outvar) .∘
520
+ wrap_code = callback_save_header (sys, cb) .∘
521
+ add_integrator_header (sys, integ, outvar) .∘
501
522
wrap_array_vars (sys, rhss; dvs, ps = _ps) .∘
502
523
wrap_parameter_dependencies (sys, false ),
503
524
outputidxs = update_inds,
@@ -606,14 +627,14 @@ Compile a single continuous callback affect function(s).
606
627
function compile_affect_fn (cb, sys:: AbstractODESystem , dvs, ps, kwargs)
607
628
eq_aff = affects (cb)
608
629
eq_neg_aff = affect_negs (cb)
609
- affect = compile_affect (eq_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
630
+ affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
610
631
if eq_neg_aff === eq_aff
611
632
affect_neg = affect
612
633
elseif isnothing (eq_neg_aff)
613
634
affect_neg = nothing
614
635
else
615
636
affect_neg = compile_affect (
616
- eq_neg_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
637
+ eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
617
638
end
618
639
(affect = affect, affect_neg = affect_neg)
619
640
end
@@ -657,7 +678,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
657
678
end
658
679
end
659
680
660
- function compile_user_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
681
+ function compile_user_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
661
682
dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
662
683
v_inds = map (sym -> dvs_ind[sym], unknowns (affect))
663
684
@@ -679,21 +700,29 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
679
700
p = filter (x -> ! isnothing (x[2 ]), collect (zip (parameters_syms (affect), p_inds))) |>
680
701
NamedTuple
681
702
682
- let u = u, p = p, user_affect = func (affect), ctx = context (affect)
703
+ if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
704
+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
705
+ else
706
+ save_idxs = Int[]
707
+ end
708
+ let u = u, p = p, user_affect = func (affect), ctx = context (affect), save_idxs = save_idxs
683
709
function (integ)
684
710
user_affect (integ, u, p, ctx)
711
+ for idx in save_idxs
712
+ SciMLBase. save_discretes! (integ, idx)
713
+ end
685
714
end
686
715
end
687
716
end
688
717
689
- function compile_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
690
- compile_user_affect (affect, sys, dvs, ps; kwargs... )
718
+ function compile_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
719
+ compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
691
720
end
692
721
693
722
function generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! = nothing ,
694
723
kwargs... )
695
724
cond = condition (cb)
696
- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
725
+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
697
726
postprocess_affect_expr!, kwargs... )
698
727
if cond isa AbstractVector
699
728
# Preset Time
@@ -711,7 +740,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711
740
kwargs... )
712
741
else
713
742
c = compile_condition (cb, sys, dvs, ps; expression = Val{false }, kwargs... )
714
- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
743
+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
715
744
postprocess_affect_expr!, kwargs... )
716
745
return DiscreteCallback (c, as)
717
746
end
0 commit comments