Skip to content

Commit 54bb95a

Browse files
committed
Fix tests
1 parent e6ce6ab commit 54bb95a

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

src/systems/callbacks.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -781,21 +781,24 @@ function generate_single_rootfinding_callback(
781781
end
782782
end
783783

784+
user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i)
784785
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing &&
785786
(save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing
786787
initfn = let save_idxs = save_idxs
787788
function (cb, u, t, integrator)
789+
user_initfun(cb, u, t, integrator)
788790
for idx in save_idxs
789791
SciMLBase.save_discretes!(integrator, idx)
790792
end
791793
end
792794
end
793795
else
794-
initfn = SciMLBase.INITIALIZE_DEFAULT
796+
initfn = user_initfun
795797
end
798+
796799
return ContinuousCallback(
797800
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
798-
initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i),
801+
initialize = initfn,
799802
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i),
800803
initializealg = reinitialization_alg(cb))
801804
end
@@ -878,8 +881,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
878881
eq_aff = affects(cb)
879882
eq_neg_aff = affect_negs(cb)
880883
affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
881-
function compile_optional_affect(aff)
882-
if isnothing(aff)
884+
function compile_optional_affect(aff, default=nothing)
885+
if isnothing(aff) || aff==default
883886
return nothing
884887
else
885888
return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...)
@@ -890,8 +893,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
890893
else
891894
affect_neg = compile_optional_affect(eq_neg_aff)
892895
end
893-
initialize = compile_optional_affect(initialize_affects(cb))
894-
finalize = compile_optional_affect(finalize_affects(cb))
896+
initialize = compile_optional_affect(initialize_affects(cb), NULL_AFFECT)
897+
finalize = compile_optional_affect(finalize_affects(cb), NULL_AFFECT)
895898
(affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize)
896899
end
897900

@@ -1097,11 +1100,11 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
10971100
let user_affect = func(affect), ctx = context(affect)
10981101
function (integ)
10991102
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
1100-
modvals = mod_og_val_fun(integ.u, integ.p..., integ.t)
1103+
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
11011104
upd_component_array = NamedTuple{mod_names}(modvals)
11021105

11031106
# update the observed values
1104-
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p..., integ.t))
1107+
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p, integ.t))
11051108

11061109
# let the user do their thing
11071110
modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ)

test/symbolic_events.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,3 +1219,17 @@ end
12191219
sol = solve(prob, Tsit5(); dtmax = 0.01)
12201220
@test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state
12211221
end
1222+
1223+
1224+
1225+
import RuntimeGeneratedFunctions
1226+
function (f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id})(args::Vararg{Any, N}) where {N, argnames, cache_tag, context_tag, id}
1227+
try
1228+
RuntimeGeneratedFunctions.generated_callfunc(f, args...)
1229+
catch e
1230+
@error "Caught error in RuntimeGeneratedFunction; source code follows"
1231+
func_expr = Expr(:->, Expr(:tuple, argnames...), RuntimeGeneratedFunctions._lookup_body(cache_tag, id))
1232+
@show func_expr
1233+
rethrow(e)
1234+
end
1235+
end

0 commit comments

Comments
 (0)