@@ -388,6 +388,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
388388 return (args... ) -> () # We don't do anything in the callback, we're just after the event
389389 end
390390 else
391+ eqs = flatten_equations (eqs)
391392 rhss = map (x -> x. rhs, eqs)
392393 outvar = :u
393394 if outputidxs === nothing
457458
458459function generate_rootfinding_callback (cbs, sys:: AbstractODESystem , dvs = unknowns (sys),
459460 ps = full_parameters (sys); kwargs... )
460- eqs = map (cb -> cb. eqs, cbs)
461+ eqs = map (cb -> flatten_equations ( cb. eqs) , cbs)
461462 num_eqs = length .(eqs)
462463 (isempty (eqs) || sum (num_eqs) == 0 ) && return nothing
463464 # fuse equations to create VectorContinuousCallback
@@ -471,12 +472,8 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
471472 rhss = map (x -> x. rhs, eqs)
472473 root_eq_vars = unique (collect (Iterators. flatten (map (ModelingToolkit. vars, rhss))))
473474
474- u = map (x -> time_varying_as_func (value (x), sys), dvs)
475- p = map .(x -> time_varying_as_func (value (x), sys), reorder_parameters (sys, ps))
476- t = get_iv (sys)
477- pre = get_preprocess_constants (rhss)
478- rf_oop, rf_ip = build_function (rhss, u, p... , t; expression = Val{false },
479- postprocess_fbody = pre, kwargs... )
475+ rf_oop, rf_ip = generate_custom_function (sys, rhss, dvs, ps; expression = Val{false },
476+ kwargs... )
480477
481478 affect_functions = map (cbs) do cb # Keep affect function separate
482479 eq_aff = affects (cb)
@@ -487,16 +484,16 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
487484 cond = function (u, t, integ)
488485 if DiffEqBase. isinplace (integ. sol. prob)
489486 tmp, = DiffEqBase. get_tmp_cache (integ)
490- rf_ip (tmp, u, parameter_values (integ)... , t)
487+ rf_ip (tmp, u, parameter_values (integ), t)
491488 tmp[1 ]
492489 else
493- rf_oop (u, parameter_values (integ)... , t)
490+ rf_oop (u, parameter_values (integ), t)
494491 end
495492 end
496493 ContinuousCallback (cond, affect_functions[])
497494 else
498495 cond = function (out, u, t, integ)
499- rf_ip (out, u, parameter_values (integ)... , t)
496+ rf_ip (out, u, parameter_values (integ), t)
500497 end
501498
502499 # since there may be different number of conditions and affects,
0 commit comments