@@ -387,6 +387,26 @@ function condition_header(sys::AbstractSystem, integrator = gensym(:MTKIntegrato
387387 end
388388end
389389
390+ function callback_save_header (sys:: AbstractSystem , cb)
391+ if ! (has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing )
392+ return (identity, identity)
393+ end
394+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
395+ isempty (save_idxs) && return (identity, identity)
396+
397+ wrapper = function (expr)
398+ return Func (expr. args, [], LiteralExpr (quote
399+ $ (expr. body)
400+ save_idxs = $ (save_idxs)
401+ for idx in save_idxs
402+ $ (SciMLBase. save_discretes!)($ (expr. args[1 ]), idx)
403+ end
404+ end ))
405+ end
406+
407+ return wrapper, wrapper
408+ end
409+
390410"""
391411 compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps; expression, kwargs...)
392412
@@ -421,7 +441,7 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
421441end
422442
423443function compile_affect (cb:: SymbolicContinuousCallback , args... ; kwargs... )
424- compile_affect (affects (cb), args... ; kwargs... )
444+ compile_affect (affects (cb), cb, args... ; kwargs... )
425445end
426446
427447"""
@@ -441,7 +461,7 @@ Notes
441461 well-formed.
442462 - `kwargs` are passed through to `Symbolics.build_function`.
443463"""
444- function compile_affect (eqs:: Vector{Equation} , sys, dvs, ps; outputidxs = nothing ,
464+ function compile_affect (eqs:: Vector{Equation} , cb, sys, dvs, ps; outputidxs = nothing ,
445465 expression = Val{true }, checkvars = true , eval_expression = false ,
446466 eval_module = @__MODULE__ ,
447467 postprocess_affect_expr! = nothing , kwargs... )
@@ -497,7 +517,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
497517 integ = gensym (:MTKIntegrator )
498518 pre = get_preprocess_constants (rhss)
499519 rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = Val{true },
500- wrap_code = add_integrator_header (sys, integ, outvar) .∘
520+ wrap_code = callback_save_header (sys, cb) .∘
521+ add_integrator_header (sys, integ, outvar) .∘
501522 wrap_array_vars (sys, rhss; dvs, ps = _ps) .∘
502523 wrap_parameter_dependencies (sys, false ),
503524 outputidxs = update_inds,
@@ -606,14 +627,14 @@ Compile a single continuous callback affect function(s).
606627function compile_affect_fn (cb, sys:: AbstractODESystem , dvs, ps, kwargs)
607628 eq_aff = affects (cb)
608629 eq_neg_aff = affect_negs (cb)
609- affect = compile_affect (eq_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
630+ affect = compile_affect (eq_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
610631 if eq_neg_aff === eq_aff
611632 affect_neg = affect
612633 elseif isnothing (eq_neg_aff)
613634 affect_neg = nothing
614635 else
615636 affect_neg = compile_affect (
616- eq_neg_aff, sys, dvs, ps; expression = Val{false }, kwargs... )
637+ eq_neg_aff, cb, sys, dvs, ps; expression = Val{false }, kwargs... )
617638 end
618639 (affect = affect, affect_neg = affect_neg)
619640end
@@ -657,7 +678,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
657678 end
658679end
659680
660- function compile_user_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
681+ function compile_user_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
661682 dvs_ind = Dict (reverse (en) for en in enumerate (dvs))
662683 v_inds = map (sym -> dvs_ind[sym], unknowns (affect))
663684
@@ -679,21 +700,29 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
679700 p = filter (x -> ! isnothing (x[2 ]), collect (zip (parameters_syms (affect), p_inds))) |>
680701 NamedTuple
681702
682- let u = u, p = p, user_affect = func (affect), ctx = context (affect)
703+ if has_index_cache (sys) && (ic = get_index_cache (sys)) != = nothing
704+ save_idxs = get (ic. callback_to_clocks, cb, Int[])
705+ else
706+ save_idxs = Int[]
707+ end
708+ let u = u, p = p, user_affect = func (affect), ctx = context (affect), save_idxs = save_idxs
683709 function (integ)
684710 user_affect (integ, u, p, ctx)
711+ for idx in save_idxs
712+ SciMLBase. save_discretes! (integ, idx)
713+ end
685714 end
686715 end
687716end
688717
689- function compile_affect (affect:: FunctionalAffect , sys, dvs, ps; kwargs... )
690- compile_user_affect (affect, sys, dvs, ps; kwargs... )
718+ function compile_affect (affect:: FunctionalAffect , cb, sys, dvs, ps; kwargs... )
719+ compile_user_affect (affect, cb, sys, dvs, ps; kwargs... )
691720end
692721
693722function generate_timed_callback (cb, sys, dvs, ps; postprocess_affect_expr! = nothing ,
694723 kwargs... )
695724 cond = condition (cb)
696- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
725+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
697726 postprocess_affect_expr!, kwargs... )
698727 if cond isa AbstractVector
699728 # Preset Time
@@ -711,7 +740,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
711740 kwargs... )
712741 else
713742 c = compile_condition (cb, sys, dvs, ps; expression = Val{false }, kwargs... )
714- as = compile_affect (affects (cb), sys, dvs, ps; expression = Val{false },
743+ as = compile_affect (affects (cb), cb, sys, dvs, ps; expression = Val{false },
715744 postprocess_affect_expr!, kwargs... )
716745 return DiscreteCallback (c, as)
717746 end
0 commit comments