@@ -256,8 +256,12 @@ function wrap_component_callbacks(nw)
256256 batchcbs = callbacks[typeidx]
257257 if first (batchcbs) isa Union{ContinousComponentCallback, VectorContinousComponentCallback}
258258 cb = ContinousCallbackWrapper (nw, batchcomps, batchcbs)
259- elseif only (batchcbs) isa Union{DiscreteComponentCallback, PresetTimeComponentCallback}
260- cb = DiscreteCallbackWrapper (nw, only (batchcomps), only (batchcbs))
259+ elseif first (batchcbs) isa DiscreteComponentCallback
260+ cb = DiscreteCallbackWrapper (nw, batchcomps, batchcbs)
261+ elseif first (batchcbs) isa PresetTimeComponentCallback
262+ # PresetTimeCallbacks cannot be batched - must be single component
263+ @assert length (batchcbs) == 1 " PresetTimeComponentCallback cannot be batched"
264+ cb = PresetTimeCallbackWrapper (nw, only (batchcomps), only (batchcbs))
261265 else
262266 error (" Unknown callback type, should never be reached. Please report this issue." )
263267 end
@@ -277,6 +281,11 @@ function _batchequal(a::VectorContinousComponentCallback, b::VectorContinousComp
277281 a. len == b. len || return false
278282 return true
279283end
284+ function _batchequal (a:: DiscreteComponentCallback , b:: DiscreteComponentCallback )
285+ _batchequal (a. condition, b. condition) || return false
286+ _batchequal (a. kwargs, b. kwargs) || return false
287+ return true
288+ end
280289function _batchequal (a:: T , b:: T ) where {T <: Union{ComponentCondition, ComponentAffect} }
281290 typeof (a) == typeof (b)
282291end
294303# each callback belongs to
295304abstract type CallbackWrapper end
296305
306+ # Generic functions for all CallbackWrappers with components and callbacks fields
307+ Base. length (cw:: CallbackWrapper ) = length (cw. callbacks)
308+ cbtype (cw:: CallbackWrapper ) = eltype (cw. callbacks)
309+
310+ condition_dim (cw:: CallbackWrapper ) = first (cw. callbacks). condition. sym |> length
311+ condition_pdim (cw:: CallbackWrapper ) = first (cw. callbacks). condition. psym |> length
312+ affect_dim (cw:: CallbackWrapper , i) = cw. callbacks[i]. affect. sym |> length
313+ affect_pdim (cw:: CallbackWrapper , i) = cw. callbacks[i]. affect. psym |> length
314+
315+ condition_urange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* condition_dim (cw)) : i* condition_dim (cw)
316+ condition_prange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* condition_pdim (cw)) : i* condition_pdim (cw)
317+ affect_urange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* affect_dim (cw,i) ) : i* affect_dim (cw,i)
318+ affect_prange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* affect_pdim (cw,i)) : i* affect_pdim (cw,i)
319+
320+ function collect_c_or_a_indices (cw:: CallbackWrapper , c_or_a, u_or_p)
321+ sidxs = SymbolicIndex[]
322+ for (component, cb) in zip (cw. components, cw. callbacks)
323+ syms = getproperty (getproperty (cb, c_or_a), u_or_p)
324+ symidxtype = if component isa VIndex
325+ u_or_p == :sym ? VIndex : VPIndex
326+ else
327+ u_or_p == :sym ? EIndex : EPIndex
328+ end
329+ sidx = collect (symidxtype (component. compidx, syms))
330+ append! (sidxs, sidx)
331+ end
332+ sidxs
333+ end
334+
297335# ###
298336# ### wrapping of continous callbacks
299337# ###
@@ -316,38 +354,11 @@ function ContinousCallbackWrapper(nw, components, callbacks)
316354 ContinousCallbackWrapper (nw, components, callbacks, sublen, condition)
317355end
318356
319- Base. length (ccw:: ContinousCallbackWrapper ) = length (ccw. callbacks)
320- cbtype (ccw:: ContinousCallbackWrapper{T} ) where {T} = T
321-
322- condition_dim (ccw) = first (ccw. callbacks). condition. sym |> length
323- condition_pdim (ccw) = first (ccw. callbacks). condition. psym |> length
324- affect_dim (ccw,i) = ccw. callbacks[i]. affect. sym |> length
325- affect_pdim (ccw,i) = ccw. callbacks[i]. affect. psym |> length
326-
327- condition_urange (ccw, i) = (1 + (i- 1 )* condition_dim (ccw)) : i* condition_dim (ccw)
328- condition_prange (ccw, i) = (1 + (i- 1 )* condition_pdim (ccw)) : i* condition_pdim (ccw)
329- affect_urange (ccw, i) = (1 + (i- 1 )* affect_dim (ccw,i) ) : i* affect_dim (ccw,i)
330- affect_prange (ccw, i) = (1 + (i- 1 )* affect_pdim (ccw,i)) : i* affect_pdim (ccw,i)
331-
332- condition_outrange (ccw, i) = (1 + (i- 1 )* ccw. sublen) : i* ccw. sublen
357+ # Continuous-specific functions (for vector callbacks)
358+ condition_outrange (ccw:: ContinousCallbackWrapper , i) = (1 + (i- 1 )* ccw. sublen) : i* ccw. sublen
333359
334- cbidx_from_outidx (ccw, outidx) = div (outidx- 1 , ccw. sublen) + 1
335- subidx_from_outidx (ccw, outidx) = mod (outidx, 1 : ccw. sublen)
336-
337- function collect_c_or_a_indices (ccw:: ContinousCallbackWrapper , c_or_a, u_or_p)
338- sidxs = SymbolicIndex[]
339- for (component, cb) in zip (ccw. components, ccw. callbacks)
340- syms = getproperty (getproperty (cb, c_or_a), u_or_p)
341- symidxtype = if component isa VIndex
342- u_or_p == :sym ? VIndex : VPIndex
343- else
344- u_or_p == :sym ? EIndex : EPIndex
345- end
346- sidx = collect (symidxtype (component. compidx, syms))
347- append! (sidxs, sidx)
348- end
349- sidxs
350- end
360+ cbidx_from_outidx (ccw:: ContinousCallbackWrapper , outidx) = div (outidx- 1 , ccw. sublen) + 1
361+ subidx_from_outidx (ccw:: ContinousCallbackWrapper , outidx) = mod (outidx, 1 : ccw. sublen)
351362
352363# generate VectorContinuousCallback from a ContinousCallbackWrapper
353364function to_callback (ccw:: ContinousCallbackWrapper )
@@ -457,37 +468,37 @@ end
457468# ###
458469# ### wrapping of discrete callbacks
459470# ###
460- struct DiscreteCallbackWrapper{ST,T} <: CallbackWrapper
471+ struct DiscreteCallbackWrapper{ST,T,C } <: CallbackWrapper
461472 nw:: Network
462- component:: ST
463- callback:: T
464- function DiscreteCallbackWrapper (nw, component, callback)
465- @assert nw isa Network
466- @assert component isa SymbolicIndex
467- @assert callback isa Union{DiscreteComponentCallback, PresetTimeComponentCallback}
468- new {typeof(component),typeof(callback)} (nw, component, callback)
473+ components:: Vector{ST} # Changed to support batching
474+ callbacks:: Vector{T} # Changed to support batching
475+ condition:: C # Store condition function to avoid dynamic dispatch
476+ end
477+ function DiscreteCallbackWrapper (nw, components, callbacks)
478+ @assert nw isa Network
479+ @assert all (c -> c isa SymbolicIndex, components)
480+ @assert all (cb -> cb isa DiscreteComponentCallback, callbacks) # Only DiscreteComponentCallback
481+ if ! isconcretetype (eltype (components))
482+ components = [c for c in components]
469483 end
484+ if ! isconcretetype (eltype (callbacks))
485+ callbacks = [cb for cb in callbacks]
486+ end
487+ # Extract condition function - all callbacks in batch have identical conditions
488+ condition = first (callbacks). condition. f
489+ DiscreteCallbackWrapper {eltype(components),eltype(callbacks),typeof(condition)} (nw, components, callbacks, condition)
470490end
471491
472492# generate a DiscreteCallback from a DiscreteCallbackWrapper
473493function to_callback (dcw:: DiscreteCallbackWrapper )
474- kwargs = dcw. callback . kwargs
494+ kwargs = first ( dcw. callbacks) . kwargs
475495 cond = _batch_condition (dcw)
476496 affect = _batch_affect (dcw)
477497 DiscreteCallback (cond, affect; kwargs... )
478498end
479- # generate a PresetTimeCallback from a DiscreteCallbackWrapper
480- function to_callback (dcw:: DiscreteCallbackWrapper{<:Any,<:PresetTimeComponentCallback} )
481- kwargs = dcw. callback. kwargs
482- affect = _batch_affect (dcw)
483- ts = dcw. callback. ts
484- DiffEqCallbacks. PresetTimeCallback (ts, affect; kwargs... )
485- end
486499function _batch_condition (dcw:: DiscreteCallbackWrapper )
487- uidxtype = dcw. component isa EIndex ? EIndex : VIndex
488- pidxtype = dcw. component isa EIndex ? EPIndex : VPIndex
489- usymidxs = uidxtype (dcw. component. compidx, dcw. callback. condition. sym)
490- psymidxs = pidxtype (dcw. component. compidx, dcw. callback. condition. psym)
500+ usymidxs = collect_c_or_a_indices (dcw, :condition , :sym )
501+ psymidxs = collect_c_or_a_indices (dcw, :condition , :psym )
491502 ucache = DiffCache (zeros (length (usymidxs)), 12 )
492503
493504 obsf = SII. observed (dcw. nw, usymidxs)
@@ -502,37 +513,154 @@ function _batch_condition(dcw::DiscreteCallbackWrapper)
502513 (u, t, integrator) -> begin
503514 us = PreallocationTools. get_tmp (ucache, u)
504515 obsf (u, integrator. p, t, us) # fills us inplace
505- _u = SymbolicView (us, dcw. callback. condition. sym)
506- pv = view (integrator. p, pidxs)
507- _p = SymbolicView (pv, dcw. callback. condition. psym)
508- dcw. callback. condition. f (_u, _p, t)
516+
517+ # OR logic: return true if ANY component condition is true
518+ for i in 1 : length (dcw)
519+ # symbolic view into u
520+ uv = view (us, condition_urange (dcw, i))
521+ _u = SymbolicView (uv, dcw. callbacks[i]. condition. sym)
522+
523+ # symbolic view into p
524+ pidxsv = view (pidxs, condition_prange (dcw, i))
525+ pv = view (integrator. p, pidxsv)
526+ _p = SymbolicView (pv, dcw. callbacks[i]. condition. psym)
527+
528+ # If any condition is true, trigger the callback
529+ if dcw. condition (_u, _p, t)
530+ return true
531+ end
532+ end
533+ return false
509534 end
510535end
511536function _batch_affect (dcw:: DiscreteCallbackWrapper )
512- uidxtype = dcw. component isa EIndex ? EIndex : VIndex
513- pidxtype = dcw. component isa EIndex ? EPIndex : VPIndex
514- usymidxs = uidxtype (dcw. component. compidx, dcw. callback. affect. sym)
515- psymidxs = pidxtype (dcw. component. compidx, dcw. callback. affect. psym)
537+ # Setup for condition re-evaluation
538+ cusymidxs = collect_c_or_a_indices (dcw, :condition , :sym )
539+ cpsymidxs = collect_c_or_a_indices (dcw, :condition , :psym )
540+ cucache = DiffCache (zeros (length (cusymidxs)), 12 )
541+ cobsf = SII. observed (dcw. nw, cusymidxs)
542+ cpidxs = SII. parameter_index .(Ref (dcw. nw), cpsymidxs)
516543
517- uidxs = SII. variable_index .(Ref (dcw. nw), usymidxs)
518- pidxs = SII. parameter_index .(Ref (dcw. nw), psymidxs)
544+ # Setup for affect execution
545+ ausymidxs = collect_c_or_a_indices (dcw, :affect , :sym )
546+ apsymidxs = collect_c_or_a_indices (dcw, :affect , :psym )
547+
548+ auidxs = SII. variable_index .(Ref (dcw. nw), ausymidxs)
549+ apidxs = SII. parameter_index .(Ref (dcw. nw), apsymidxs)
550+
551+ if any (isnothing, auidxs) || any (isnothing, apidxs)
552+ missing_u = []
553+ if any (isnothing, auidxs)
554+ nidxs = findall (isnothing, auidxs)
555+ append! (missing_u, ausymidxs[nidxs])
556+ end
557+ missing_p = []
558+ if any (isnothing, apidxs)
559+ nidxs = findall (isnothing, apidxs)
560+ append! (missing_p, apsymidxs[nidxs])
561+ end
562+ throw (ArgumentError (
563+ " Cannot build callback as it contains refrences to undefined symbols:\n " *
564+ (isempty (missing_u) ? " " : " Missing state symbols: $(missing_u) \n " )*
565+ (isempty (missing_p) ? " " : " Missing param symbols: $(missing_p) \n " )
566+ ))
567+ end
519568
520569 (integrator) -> begin
570+ # Re-evaluate all conditions to determine which affects to execute
571+ cus = PreallocationTools. get_tmp (cucache, integrator. u)
572+ cobsf (integrator. u, integrator. p, integrator. t, cus)
573+
574+ any_uchanged = false
575+ any_pchanged = false
576+
577+ for i in 1 : length (dcw)
578+ # Re-evaluate condition for component i
579+ cuv = view (cus, condition_urange (dcw, i))
580+ c_u = SymbolicView (cuv, dcw. callbacks[i]. condition. sym)
581+ cpidxsv = view (cpidxs, condition_prange (dcw, i))
582+ cpv = view (integrator. p, cpidxsv)
583+ c_p = SymbolicView (cpv, dcw. callbacks[i]. condition. psym)
584+
585+ # Only execute affect if condition is true
586+ if dcw. condition (c_u, c_p, integrator. t)
587+ # Execute affect for component i
588+ auidxsv = view (auidxs, affect_urange (dcw, i))
589+ auv = view (integrator. u, auidxsv)
590+ a_u = SymbolicView (auv, dcw. callbacks[i]. affect. sym)
591+
592+ apidxsv = view (apidxs, affect_prange (dcw, i))
593+ apv = view (integrator. p, apidxsv)
594+ a_p = SymbolicView (apv, dcw. callbacks[i]. affect. psym)
595+
596+ ctx = get_ctx (integrator, dcw. components[i])
597+
598+ uhash = hash (auv)
599+ phash = hash (apv)
600+ dcw. callbacks[i]. affect. f (a_u, a_p, ctx)
601+ pchanged = hash (apv) != phash
602+ uchanged = hash (auv) != uhash
603+
604+ any_uchanged = any_uchanged || uchanged
605+ any_pchanged = any_pchanged || pchanged
606+ end
607+ end
608+
609+ (any_uchanged || any_pchanged) && SciMLBase. auto_dt_reset! (integrator)
610+ any_pchanged && save_parameters! (integrator)
611+ end
612+ end
613+
614+ # ###
615+ # ### wrapping of preset time callbacks
616+ # ###
617+ struct PresetTimeCallbackWrapper{ST,T}
618+ nw:: Network
619+ component:: ST # Single component - PresetTime callbacks cannot be batched
620+ callback:: T # Single callback - PresetTime callbacks cannot be batched
621+ function PresetTimeCallbackWrapper (nw, component:: SymbolicIndex , callback:: PresetTimeComponentCallback )
622+ @assert nw isa Network
623+ @assert component isa SymbolicIndex
624+ @assert callback isa PresetTimeComponentCallback
625+ # PresetTimeCallbacks cannot be batched, so always single component/callback
626+ new {typeof(component), typeof(callback)} (nw, component, callback)
627+ end
628+ end
629+
630+ # generate a PresetTimeCallback from a PresetTimeCallbackWrapper
631+ function to_callback (ptcw:: PresetTimeCallbackWrapper )
632+ callback = ptcw. callback
633+ component = ptcw. component
634+ kwargs = callback. kwargs
635+ ts = callback. ts
636+
637+ # Create affect function for the single component
638+ uidxtype = component isa EIndex ? EIndex : VIndex
639+ pidxtype = component isa EIndex ? EPIndex : VPIndex
640+ usymidxs = uidxtype (component. compidx, callback. affect. sym)
641+ psymidxs = pidxtype (component. compidx, callback. affect. psym)
642+
643+ uidxs = SII. variable_index .(Ref (ptcw. nw), usymidxs)
644+ pidxs = SII. parameter_index .(Ref (ptcw. nw), psymidxs)
645+
646+ affect = (integrator) -> begin
521647 uv = view (integrator. u, uidxs)
522- _u = SymbolicView (uv, dcw . callback. affect. sym)
648+ _u = SymbolicView (uv, callback. affect. sym)
523649 pv = view (integrator. p, pidxs)
524- _p = SymbolicView (pv, dcw . callback. affect. psym)
525- ctx = get_ctx (integrator, dcw . component)
650+ _p = SymbolicView (pv, callback. affect. psym)
651+ ctx = get_ctx (integrator, component)
526652
527653 uhash = hash (uv)
528654 phash = hash (pv)
529- dcw . callback. affect. f (_u, _p, ctx)
655+ callback. affect. f (_u, _p, ctx)
530656 pchanged = hash (pv) != phash
531657 uchanged = hash (uv) != uhash
532658
533659 (pchanged || uchanged) && SciMLBase. auto_dt_reset! (integrator)
534660 pchanged && save_parameters! (integrator)
535661 end
662+
663+ DiffEqCallbacks. PresetTimeCallback (ts, affect; kwargs... )
536664end
537665
538666
0 commit comments