Skip to content

Commit 7e5b26e

Browse files
committed
batch DiscreteCallbacks for better performance
- no need to evaluate obsf multiple times - no need to have huge tuples of discrete callbacks in CallbackSet
1 parent 34c9152 commit 7e5b26e

File tree

1 file changed

+195
-67
lines changed

1 file changed

+195
-67
lines changed

src/callbacks.jl

Lines changed: 195 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,12 @@ function wrap_component_callbacks(nw)
256256
batchcbs = callbacks[typeidx]
257257
if first(batchcbs) isa Union{ContinousComponentCallback, VectorContinousComponentCallback}
258258
cb = ContinousCallbackWrapper(nw, batchcomps, batchcbs)
259-
elseif only(batchcbs) isa Union{DiscreteComponentCallback, PresetTimeComponentCallback}
260-
cb = DiscreteCallbackWrapper(nw, only(batchcomps), only(batchcbs))
259+
elseif first(batchcbs) isa DiscreteComponentCallback
260+
cb = DiscreteCallbackWrapper(nw, batchcomps, batchcbs)
261+
elseif first(batchcbs) isa PresetTimeComponentCallback
262+
# PresetTimeCallbacks cannot be batched - must be single component
263+
@assert length(batchcbs) == 1 "PresetTimeComponentCallback cannot be batched"
264+
cb = PresetTimeCallbackWrapper(nw, only(batchcomps), only(batchcbs))
261265
else
262266
error("Unknown callback type, should never be reached. Please report this issue.")
263267
end
@@ -277,6 +281,11 @@ function _batchequal(a::VectorContinousComponentCallback, b::VectorContinousComp
277281
a.len == b.len || return false
278282
return true
279283
end
284+
function _batchequal(a::DiscreteComponentCallback, b::DiscreteComponentCallback)
285+
_batchequal(a.condition, b.condition) || return false
286+
_batchequal(a.kwargs, b.kwargs) || return false
287+
return true
288+
end
280289
function _batchequal(a::T, b::T) where {T <: Union{ComponentCondition, ComponentAffect}}
281290
typeof(a) == typeof(b)
282291
end
@@ -294,6 +303,35 @@ end
294303
# each callback belongs to
295304
abstract type CallbackWrapper end
296305

306+
# Generic functions for all CallbackWrappers with components and callbacks fields
307+
Base.length(cw::CallbackWrapper) = length(cw.callbacks)
308+
cbtype(cw::CallbackWrapper) = eltype(cw.callbacks)
309+
310+
condition_dim(cw::CallbackWrapper) = first(cw.callbacks).condition.sym |> length
311+
condition_pdim(cw::CallbackWrapper) = first(cw.callbacks).condition.psym |> length
312+
affect_dim(cw::CallbackWrapper, i) = cw.callbacks[i].affect.sym |> length
313+
affect_pdim(cw::CallbackWrapper, i) = cw.callbacks[i].affect.psym |> length
314+
315+
condition_urange(cw::CallbackWrapper, i) = (1 + (i-1)*condition_dim(cw)) : i*condition_dim(cw)
316+
condition_prange(cw::CallbackWrapper, i) = (1 + (i-1)*condition_pdim(cw)) : i*condition_pdim(cw)
317+
affect_urange(cw::CallbackWrapper, i) = (1 + (i-1)*affect_dim(cw,i) ) : i*affect_dim(cw,i)
318+
affect_prange(cw::CallbackWrapper, i) = (1 + (i-1)*affect_pdim(cw,i)) : i*affect_pdim(cw,i)
319+
320+
function collect_c_or_a_indices(cw::CallbackWrapper, c_or_a, u_or_p)
321+
sidxs = SymbolicIndex[]
322+
for (component, cb) in zip(cw.components, cw.callbacks)
323+
syms = getproperty(getproperty(cb, c_or_a), u_or_p)
324+
symidxtype = if component isa VIndex
325+
u_or_p == :sym ? VIndex : VPIndex
326+
else
327+
u_or_p ==:sym ? EIndex : EPIndex
328+
end
329+
sidx = collect(symidxtype(component.compidx, syms))
330+
append!(sidxs, sidx)
331+
end
332+
sidxs
333+
end
334+
297335
####
298336
#### wrapping of continous callbacks
299337
####
@@ -316,38 +354,11 @@ function ContinousCallbackWrapper(nw, components, callbacks)
316354
ContinousCallbackWrapper(nw, components, callbacks, sublen, condition)
317355
end
318356

319-
Base.length(ccw::ContinousCallbackWrapper) = length(ccw.callbacks)
320-
cbtype(ccw::ContinousCallbackWrapper{T}) where {T} = T
321-
322-
condition_dim(ccw) = first(ccw.callbacks).condition.sym |> length
323-
condition_pdim(ccw) = first(ccw.callbacks).condition.psym |> length
324-
affect_dim(ccw,i) = ccw.callbacks[i].affect.sym |> length
325-
affect_pdim(ccw,i) = ccw.callbacks[i].affect.psym |> length
326-
327-
condition_urange(ccw, i) = (1 + (i-1)*condition_dim(ccw)) : i*condition_dim(ccw)
328-
condition_prange(ccw, i) = (1 + (i-1)*condition_pdim(ccw)) : i*condition_pdim(ccw)
329-
affect_urange(ccw, i) = (1 + (i-1)*affect_dim(ccw,i) ) : i*affect_dim(ccw,i)
330-
affect_prange(ccw, i) = (1 + (i-1)*affect_pdim(ccw,i)) : i*affect_pdim(ccw,i)
331-
332-
condition_outrange(ccw, i) = (1 + (i-1)*ccw.sublen) : i*ccw.sublen
357+
# Continuous-specific functions (for vector callbacks)
358+
condition_outrange(ccw::ContinousCallbackWrapper, i) = (1 + (i-1)*ccw.sublen) : i*ccw.sublen
333359

334-
cbidx_from_outidx(ccw, outidx) = div(outidx-1, ccw.sublen) + 1
335-
subidx_from_outidx(ccw, outidx) = mod(outidx, 1:ccw.sublen)
336-
337-
function collect_c_or_a_indices(ccw::ContinousCallbackWrapper, c_or_a, u_or_p)
338-
sidxs = SymbolicIndex[]
339-
for (component, cb) in zip(ccw.components, ccw.callbacks)
340-
syms = getproperty(getproperty(cb, c_or_a), u_or_p)
341-
symidxtype = if component isa VIndex
342-
u_or_p == :sym ? VIndex : VPIndex
343-
else
344-
u_or_p ==:sym ? EIndex : EPIndex
345-
end
346-
sidx = collect(symidxtype(component.compidx, syms))
347-
append!(sidxs, sidx)
348-
end
349-
sidxs
350-
end
360+
cbidx_from_outidx(ccw::ContinousCallbackWrapper, outidx) = div(outidx-1, ccw.sublen) + 1
361+
subidx_from_outidx(ccw::ContinousCallbackWrapper, outidx) = mod(outidx, 1:ccw.sublen)
351362

352363
# generate VectorContinuousCallback from a ContinousCallbackWrapper
353364
function to_callback(ccw::ContinousCallbackWrapper)
@@ -457,37 +468,37 @@ end
457468
####
458469
#### wrapping of discrete callbacks
459470
####
460-
struct DiscreteCallbackWrapper{ST,T} <: CallbackWrapper
471+
struct DiscreteCallbackWrapper{ST,T,C} <: CallbackWrapper
461472
nw::Network
462-
component::ST
463-
callback::T
464-
function DiscreteCallbackWrapper(nw, component, callback)
465-
@assert nw isa Network
466-
@assert component isa SymbolicIndex
467-
@assert callback isa Union{DiscreteComponentCallback, PresetTimeComponentCallback}
468-
new{typeof(component),typeof(callback)}(nw, component, callback)
473+
components::Vector{ST} # Changed to support batching
474+
callbacks::Vector{T} # Changed to support batching
475+
condition::C # Store condition function to avoid dynamic dispatch
476+
end
477+
function DiscreteCallbackWrapper(nw, components, callbacks)
478+
@assert nw isa Network
479+
@assert all(c -> c isa SymbolicIndex, components)
480+
@assert all(cb -> cb isa DiscreteComponentCallback, callbacks) # Only DiscreteComponentCallback
481+
if !isconcretetype(eltype(components))
482+
components = [c for c in components]
469483
end
484+
if !isconcretetype(eltype(callbacks))
485+
callbacks = [cb for cb in callbacks]
486+
end
487+
# Extract condition function - all callbacks in batch have identical conditions
488+
condition = first(callbacks).condition.f
489+
DiscreteCallbackWrapper{eltype(components),eltype(callbacks),typeof(condition)}(nw, components, callbacks, condition)
470490
end
471491

472492
# generate a DiscreteCallback from a DiscreteCallbackWrapper
473493
function to_callback(dcw::DiscreteCallbackWrapper)
474-
kwargs = dcw.callback.kwargs
494+
kwargs = first(dcw.callbacks).kwargs
475495
cond = _batch_condition(dcw)
476496
affect = _batch_affect(dcw)
477497
DiscreteCallback(cond, affect; kwargs...)
478498
end
479-
# generate a PresetTimeCallback from a DiscreteCallbackWrapper
480-
function to_callback(dcw::DiscreteCallbackWrapper{<:Any,<:PresetTimeComponentCallback})
481-
kwargs = dcw.callback.kwargs
482-
affect = _batch_affect(dcw)
483-
ts = dcw.callback.ts
484-
DiffEqCallbacks.PresetTimeCallback(ts, affect; kwargs...)
485-
end
486499
function _batch_condition(dcw::DiscreteCallbackWrapper)
487-
uidxtype = dcw.component isa EIndex ? EIndex : VIndex
488-
pidxtype = dcw.component isa EIndex ? EPIndex : VPIndex
489-
usymidxs = uidxtype(dcw.component.compidx, dcw.callback.condition.sym)
490-
psymidxs = pidxtype(dcw.component.compidx, dcw.callback.condition.psym)
500+
usymidxs = collect_c_or_a_indices(dcw, :condition, :sym)
501+
psymidxs = collect_c_or_a_indices(dcw, :condition, :psym)
491502
ucache = DiffCache(zeros(length(usymidxs)), 12)
492503

493504
obsf = SII.observed(dcw.nw, usymidxs)
@@ -502,37 +513,154 @@ function _batch_condition(dcw::DiscreteCallbackWrapper)
502513
(u, t, integrator) -> begin
503514
us = PreallocationTools.get_tmp(ucache, u)
504515
obsf(u, integrator.p, t, us) # fills us inplace
505-
_u = SymbolicView(us, dcw.callback.condition.sym)
506-
pv = view(integrator.p, pidxs)
507-
_p = SymbolicView(pv, dcw.callback.condition.psym)
508-
dcw.callback.condition.f(_u, _p, t)
516+
517+
# OR logic: return true if ANY component condition is true
518+
for i in 1:length(dcw)
519+
# symbolic view into u
520+
uv = view(us, condition_urange(dcw, i))
521+
_u = SymbolicView(uv, dcw.callbacks[i].condition.sym)
522+
523+
# symbolic view into p
524+
pidxsv = view(pidxs, condition_prange(dcw, i))
525+
pv = view(integrator.p, pidxsv)
526+
_p = SymbolicView(pv, dcw.callbacks[i].condition.psym)
527+
528+
# If any condition is true, trigger the callback
529+
if dcw.condition(_u, _p, t)
530+
return true
531+
end
532+
end
533+
return false
509534
end
510535
end
511536
function _batch_affect(dcw::DiscreteCallbackWrapper)
512-
uidxtype = dcw.component isa EIndex ? EIndex : VIndex
513-
pidxtype = dcw.component isa EIndex ? EPIndex : VPIndex
514-
usymidxs = uidxtype(dcw.component.compidx, dcw.callback.affect.sym)
515-
psymidxs = pidxtype(dcw.component.compidx, dcw.callback.affect.psym)
537+
# Setup for condition re-evaluation
538+
cusymidxs = collect_c_or_a_indices(dcw, :condition, :sym)
539+
cpsymidxs = collect_c_or_a_indices(dcw, :condition, :psym)
540+
cucache = DiffCache(zeros(length(cusymidxs)), 12)
541+
cobsf = SII.observed(dcw.nw, cusymidxs)
542+
cpidxs = SII.parameter_index.(Ref(dcw.nw), cpsymidxs)
516543

517-
uidxs = SII.variable_index.(Ref(dcw.nw), usymidxs)
518-
pidxs = SII.parameter_index.(Ref(dcw.nw), psymidxs)
544+
# Setup for affect execution
545+
ausymidxs = collect_c_or_a_indices(dcw, :affect, :sym)
546+
apsymidxs = collect_c_or_a_indices(dcw, :affect, :psym)
547+
548+
auidxs = SII.variable_index.(Ref(dcw.nw), ausymidxs)
549+
apidxs = SII.parameter_index.(Ref(dcw.nw), apsymidxs)
550+
551+
if any(isnothing, auidxs) || any(isnothing, apidxs)
552+
missing_u = []
553+
if any(isnothing, auidxs)
554+
nidxs = findall(isnothing, auidxs)
555+
append!(missing_u, ausymidxs[nidxs])
556+
end
557+
missing_p = []
558+
if any(isnothing, apidxs)
559+
nidxs = findall(isnothing, apidxs)
560+
append!(missing_p, apsymidxs[nidxs])
561+
end
562+
throw(ArgumentError(
563+
"Cannot build callback as it contains refrences to undefined symbols:\n"*
564+
(isempty(missing_u) ? "" : "Missing state symbols: $(missing_u)\n")*
565+
(isempty(missing_p) ? "" : "Missing param symbols: $(missing_p)\n")
566+
))
567+
end
519568

520569
(integrator) -> begin
570+
# Re-evaluate all conditions to determine which affects to execute
571+
cus = PreallocationTools.get_tmp(cucache, integrator.u)
572+
cobsf(integrator.u, integrator.p, integrator.t, cus)
573+
574+
any_uchanged = false
575+
any_pchanged = false
576+
577+
for i in 1:length(dcw)
578+
# Re-evaluate condition for component i
579+
cuv = view(cus, condition_urange(dcw, i))
580+
c_u = SymbolicView(cuv, dcw.callbacks[i].condition.sym)
581+
cpidxsv = view(cpidxs, condition_prange(dcw, i))
582+
cpv = view(integrator.p, cpidxsv)
583+
c_p = SymbolicView(cpv, dcw.callbacks[i].condition.psym)
584+
585+
# Only execute affect if condition is true
586+
if dcw.condition(c_u, c_p, integrator.t)
587+
# Execute affect for component i
588+
auidxsv = view(auidxs, affect_urange(dcw, i))
589+
auv = view(integrator.u, auidxsv)
590+
a_u = SymbolicView(auv, dcw.callbacks[i].affect.sym)
591+
592+
apidxsv = view(apidxs, affect_prange(dcw, i))
593+
apv = view(integrator.p, apidxsv)
594+
a_p = SymbolicView(apv, dcw.callbacks[i].affect.psym)
595+
596+
ctx = get_ctx(integrator, dcw.components[i])
597+
598+
uhash = hash(auv)
599+
phash = hash(apv)
600+
dcw.callbacks[i].affect.f(a_u, a_p, ctx)
601+
pchanged = hash(apv) != phash
602+
uchanged = hash(auv) != uhash
603+
604+
any_uchanged = any_uchanged || uchanged
605+
any_pchanged = any_pchanged || pchanged
606+
end
607+
end
608+
609+
(any_uchanged || any_pchanged) && SciMLBase.auto_dt_reset!(integrator)
610+
any_pchanged && save_parameters!(integrator)
611+
end
612+
end
613+
614+
####
615+
#### wrapping of preset time callbacks
616+
####
617+
struct PresetTimeCallbackWrapper{ST,T}
618+
nw::Network
619+
component::ST # Single component - PresetTime callbacks cannot be batched
620+
callback::T # Single callback - PresetTime callbacks cannot be batched
621+
function PresetTimeCallbackWrapper(nw, component::SymbolicIndex, callback::PresetTimeComponentCallback)
622+
@assert nw isa Network
623+
@assert component isa SymbolicIndex
624+
@assert callback isa PresetTimeComponentCallback
625+
# PresetTimeCallbacks cannot be batched, so always single component/callback
626+
new{typeof(component), typeof(callback)}(nw, component, callback)
627+
end
628+
end
629+
630+
# generate a PresetTimeCallback from a PresetTimeCallbackWrapper
631+
function to_callback(ptcw::PresetTimeCallbackWrapper)
632+
callback = ptcw.callback
633+
component = ptcw.component
634+
kwargs = callback.kwargs
635+
ts = callback.ts
636+
637+
# Create affect function for the single component
638+
uidxtype = component isa EIndex ? EIndex : VIndex
639+
pidxtype = component isa EIndex ? EPIndex : VPIndex
640+
usymidxs = uidxtype(component.compidx, callback.affect.sym)
641+
psymidxs = pidxtype(component.compidx, callback.affect.psym)
642+
643+
uidxs = SII.variable_index.(Ref(ptcw.nw), usymidxs)
644+
pidxs = SII.parameter_index.(Ref(ptcw.nw), psymidxs)
645+
646+
affect = (integrator) -> begin
521647
uv = view(integrator.u, uidxs)
522-
_u = SymbolicView(uv, dcw.callback.affect.sym)
648+
_u = SymbolicView(uv, callback.affect.sym)
523649
pv = view(integrator.p, pidxs)
524-
_p = SymbolicView(pv, dcw.callback.affect.psym)
525-
ctx = get_ctx(integrator, dcw.component)
650+
_p = SymbolicView(pv, callback.affect.psym)
651+
ctx = get_ctx(integrator, component)
526652

527653
uhash = hash(uv)
528654
phash = hash(pv)
529-
dcw.callback.affect.f(_u, _p, ctx)
655+
callback.affect.f(_u, _p, ctx)
530656
pchanged = hash(pv) != phash
531657
uchanged = hash(uv) != uhash
532658

533659
(pchanged || uchanged) && SciMLBase.auto_dt_reset!(integrator)
534660
pchanged && save_parameters!(integrator)
535661
end
662+
663+
DiffEqCallbacks.PresetTimeCallback(ts, affect; kwargs...)
536664
end
537665

538666

0 commit comments

Comments
 (0)