Skip to content

Commit a299ec6

Browse files
fix: fix compile_condition, respect eval_expression and eval_module
1 parent da19258 commit a299ec6

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

src/systems/callbacks.jl

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,33 @@ Base.isempty(cb::AbstractCallback) = isempty(cb.conditions)
596596
####################################
597597
####### Compilation functions ######
598598
####################################
599+
600+
struct CompiledCondition{IsDiscrete, F}
601+
f::F
602+
end
603+
604+
function CompiledCondition{ID}(f::F) where {ID, F}
605+
return CompiledCondition{ID, F}(f)
606+
end
607+
608+
function (cc::CompiledCondition)(out, u, t, integ)
609+
cc.f(out, u, parameter_values(integ), t)
610+
end
611+
612+
function (cc::CompiledCondition{false})(u, t, integ)
613+
if DiffEqBase.isinplace(SciMLBase.get_sol(integ).prob)
614+
tmp, = DiffEqBase.get_tmp_cache(integ)
615+
cc.f(tmp, u, parameter_values(integ), t)
616+
tmp[1]
617+
else
618+
cc.f(u, parameter_values(integ), t)
619+
end
620+
end
621+
622+
function (cc::CompiledCondition{true})(u, t, integ)
623+
cc.f(u, parameter_values(integ), t)
624+
end
625+
599626
"""
600627
compile_condition(cb::AbstractCallback, sys, dvs, ps; expression, kwargs...)
601628
@@ -615,30 +642,19 @@ function compile_condition(
615642
end
616643

617644
if !is_discrete(cbs)
618-
condit = reduce(vcat, flatten_equations(condit))
645+
condit = reduce(vcat, flatten_equations(Vector{Equation}(condit)))
619646
condit = condit isa AbstractVector ? [c.lhs - c.rhs for c in condit] :
620647
[condit.lhs - condit.rhs]
621648
end
622649

623650
fs = build_function_wrapper(
624-
sys, condit, u, p..., t; kwargs..., expression = Val{false}, cse = false)
625-
(f_oop, f_iip) = is_discrete(cbs) ? (fs, nothing) : fs
626-
627-
cond = if cbs isa AbstractVector
628-
(out, u, t, integ) -> f_iip(out, u, parameter_values(integ), t)
629-
elseif is_discrete(cbs)
630-
(u, t, integ) -> f_oop(u, parameter_values(integ), t)
631-
else
632-
function (u, t, integ)
633-
if DiffEqBase.isinplace(SciMLBase.get_sol(integ).prob)
634-
tmp, = DiffEqBase.get_tmp_cache(integ)
635-
f_iip(tmp, u, parameter_values(integ), t)
636-
tmp[1]
637-
else
638-
f_oop(u, parameter_values(integ), t)
639-
end
640-
end
651+
sys, condit, u, p..., t; kwargs..., cse = false)
652+
if is_discrete(cbs)
653+
fs = (fs, nothing)
641654
end
655+
fs = GeneratedFunctionWrapper{(2, 3, is_split(sys))}(
656+
Val{false}, fs...; eval_expression, eval_module)
657+
return CompiledCondition{is_discrete(cbs)}(fs)
642658
end
643659

644660
"""

0 commit comments

Comments
 (0)