@@ -596,6 +596,33 @@ Base.isempty(cb::AbstractCallback) = isempty(cb.conditions)
596596# ###################################
597597# ###### Compilation functions ######
598598# ###################################
599+
600+ struct CompiledCondition{IsDiscrete, F}
601+ f:: F
602+ end
603+
604+ function CompiledCondition {ID} (f:: F ) where {ID, F}
605+ return CompiledCondition {ID, F} (f)
606+ end
607+
608+ function (cc:: CompiledCondition )(out, u, t, integ)
609+ cc. f (out, u, parameter_values (integ), t)
610+ end
611+
612+ function (cc:: CompiledCondition{false} )(u, t, integ)
613+ if DiffEqBase. isinplace (SciMLBase. get_sol (integ). prob)
614+ tmp, = DiffEqBase. get_tmp_cache (integ)
615+ cc. f (tmp, u, parameter_values (integ), t)
616+ tmp[1 ]
617+ else
618+ cc. f (u, parameter_values (integ), t)
619+ end
620+ end
621+
622+ function (cc:: CompiledCondition{true} )(u, t, integ)
623+ cc. f (u, parameter_values (integ), t)
624+ end
625+
599626"""
600627 compile_condition(cb::AbstractCallback, sys, dvs, ps; expression, kwargs...)
601628
@@ -615,30 +642,19 @@ function compile_condition(
615642 end
616643
617644 if ! is_discrete (cbs)
618- condit = reduce (vcat, flatten_equations (condit))
645+ condit = reduce (vcat, flatten_equations (Vector {Equation} ( condit) ))
619646 condit = condit isa AbstractVector ? [c. lhs - c. rhs for c in condit] :
620647 [condit. lhs - condit. rhs]
621648 end
622649
623650 fs = build_function_wrapper (
624- sys, condit, u, p... , t; kwargs... , expression = Val{false }, cse = false )
625- (f_oop, f_iip) = is_discrete (cbs) ? (fs, nothing ) : fs
626-
627- cond = if cbs isa AbstractVector
628- (out, u, t, integ) -> f_iip (out, u, parameter_values (integ), t)
629- elseif is_discrete (cbs)
630- (u, t, integ) -> f_oop (u, parameter_values (integ), t)
631- else
632- function (u, t, integ)
633- if DiffEqBase. isinplace (SciMLBase. get_sol (integ). prob)
634- tmp, = DiffEqBase. get_tmp_cache (integ)
635- f_iip (tmp, u, parameter_values (integ), t)
636- tmp[1 ]
637- else
638- f_oop (u, parameter_values (integ), t)
639- end
640- end
651+ sys, condit, u, p... , t; kwargs... , cse = false )
652+ if is_discrete (cbs)
653+ fs = (fs, nothing )
641654 end
655+ fs = GeneratedFunctionWrapper {(2, 3, is_split(sys))} (
656+ Val{false }, fs... ; eval_expression, eval_module)
657+ return CompiledCondition {is_discrete(cbs)} (fs)
642658end
643659
644660"""
0 commit comments