@@ -256,8 +256,12 @@ function wrap_component_callbacks(nw)
256
256
batchcbs = callbacks[typeidx]
257
257
if first (batchcbs) isa Union{ContinousComponentCallback, VectorContinousComponentCallback}
258
258
cb = ContinousCallbackWrapper (nw, batchcomps, batchcbs)
259
- elseif only (batchcbs) isa Union{DiscreteComponentCallback, PresetTimeComponentCallback}
260
- cb = DiscreteCallbackWrapper (nw, only (batchcomps), only (batchcbs))
259
+ elseif first (batchcbs) isa DiscreteComponentCallback
260
+ cb = DiscreteCallbackWrapper (nw, batchcomps, batchcbs)
261
+ elseif first (batchcbs) isa PresetTimeComponentCallback
262
+ # PresetTimeCallbacks cannot be batched - must be single component
263
+ @assert length (batchcbs) == 1 " PresetTimeComponentCallback cannot be batched"
264
+ cb = PresetTimeCallbackWrapper (nw, only (batchcomps), only (batchcbs))
261
265
else
262
266
error (" Unknown callback type, should never be reached. Please report this issue." )
263
267
end
@@ -277,6 +281,11 @@ function _batchequal(a::VectorContinousComponentCallback, b::VectorContinousComp
277
281
a. len == b. len || return false
278
282
return true
279
283
end
284
+ function _batchequal (a:: DiscreteComponentCallback , b:: DiscreteComponentCallback )
285
+ _batchequal (a. condition, b. condition) || return false
286
+ _batchequal (a. kwargs, b. kwargs) || return false
287
+ return true
288
+ end
280
289
function _batchequal (a:: T , b:: T ) where {T <: Union{ComponentCondition, ComponentAffect} }
281
290
typeof (a) == typeof (b)
282
291
end
294
303
# each callback belongs to
295
304
abstract type CallbackWrapper end
296
305
306
+ # Generic functions for all CallbackWrappers with components and callbacks fields
307
+ Base. length (cw:: CallbackWrapper ) = length (cw. callbacks)
308
+ cbtype (cw:: CallbackWrapper ) = eltype (cw. callbacks)
309
+
310
+ condition_dim (cw:: CallbackWrapper ) = first (cw. callbacks). condition. sym |> length
311
+ condition_pdim (cw:: CallbackWrapper ) = first (cw. callbacks). condition. psym |> length
312
+ affect_dim (cw:: CallbackWrapper , i) = cw. callbacks[i]. affect. sym |> length
313
+ affect_pdim (cw:: CallbackWrapper , i) = cw. callbacks[i]. affect. psym |> length
314
+
315
+ condition_urange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* condition_dim (cw)) : i* condition_dim (cw)
316
+ condition_prange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* condition_pdim (cw)) : i* condition_pdim (cw)
317
+ affect_urange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* affect_dim (cw,i) ) : i* affect_dim (cw,i)
318
+ affect_prange (cw:: CallbackWrapper , i) = (1 + (i- 1 )* affect_pdim (cw,i)) : i* affect_pdim (cw,i)
319
+
320
+ function collect_c_or_a_indices (cw:: CallbackWrapper , c_or_a, u_or_p)
321
+ sidxs = SymbolicIndex[]
322
+ for (component, cb) in zip (cw. components, cw. callbacks)
323
+ syms = getproperty (getproperty (cb, c_or_a), u_or_p)
324
+ symidxtype = if component isa VIndex
325
+ u_or_p == :sym ? VIndex : VPIndex
326
+ else
327
+ u_or_p == :sym ? EIndex : EPIndex
328
+ end
329
+ sidx = collect (symidxtype (component. compidx, syms))
330
+ append! (sidxs, sidx)
331
+ end
332
+ sidxs
333
+ end
334
+
297
335
# ###
298
336
# ### wrapping of continous callbacks
299
337
# ###
@@ -316,38 +354,11 @@ function ContinousCallbackWrapper(nw, components, callbacks)
316
354
ContinousCallbackWrapper (nw, components, callbacks, sublen, condition)
317
355
end
318
356
319
- Base. length (ccw:: ContinousCallbackWrapper ) = length (ccw. callbacks)
320
- cbtype (ccw:: ContinousCallbackWrapper{T} ) where {T} = T
321
-
322
- condition_dim (ccw) = first (ccw. callbacks). condition. sym |> length
323
- condition_pdim (ccw) = first (ccw. callbacks). condition. psym |> length
324
- affect_dim (ccw,i) = ccw. callbacks[i]. affect. sym |> length
325
- affect_pdim (ccw,i) = ccw. callbacks[i]. affect. psym |> length
326
-
327
- condition_urange (ccw, i) = (1 + (i- 1 )* condition_dim (ccw)) : i* condition_dim (ccw)
328
- condition_prange (ccw, i) = (1 + (i- 1 )* condition_pdim (ccw)) : i* condition_pdim (ccw)
329
- affect_urange (ccw, i) = (1 + (i- 1 )* affect_dim (ccw,i) ) : i* affect_dim (ccw,i)
330
- affect_prange (ccw, i) = (1 + (i- 1 )* affect_pdim (ccw,i)) : i* affect_pdim (ccw,i)
331
-
332
- condition_outrange (ccw, i) = (1 + (i- 1 )* ccw. sublen) : i* ccw. sublen
357
+ # Continuous-specific functions (for vector callbacks)
358
+ condition_outrange (ccw:: ContinousCallbackWrapper , i) = (1 + (i- 1 )* ccw. sublen) : i* ccw. sublen
333
359
334
- cbidx_from_outidx (ccw, outidx) = div (outidx- 1 , ccw. sublen) + 1
335
- subidx_from_outidx (ccw, outidx) = mod (outidx, 1 : ccw. sublen)
336
-
337
- function collect_c_or_a_indices (ccw:: ContinousCallbackWrapper , c_or_a, u_or_p)
338
- sidxs = SymbolicIndex[]
339
- for (component, cb) in zip (ccw. components, ccw. callbacks)
340
- syms = getproperty (getproperty (cb, c_or_a), u_or_p)
341
- symidxtype = if component isa VIndex
342
- u_or_p == :sym ? VIndex : VPIndex
343
- else
344
- u_or_p == :sym ? EIndex : EPIndex
345
- end
346
- sidx = collect (symidxtype (component. compidx, syms))
347
- append! (sidxs, sidx)
348
- end
349
- sidxs
350
- end
360
+ cbidx_from_outidx (ccw:: ContinousCallbackWrapper , outidx) = div (outidx- 1 , ccw. sublen) + 1
361
+ subidx_from_outidx (ccw:: ContinousCallbackWrapper , outidx) = mod (outidx, 1 : ccw. sublen)
351
362
352
363
# generate VectorContinuousCallback from a ContinousCallbackWrapper
353
364
function to_callback (ccw:: ContinousCallbackWrapper )
@@ -457,37 +468,37 @@ end
457
468
# ###
458
469
# ### wrapping of discrete callbacks
459
470
# ###
460
- struct DiscreteCallbackWrapper{ST,T} <: CallbackWrapper
471
+ struct DiscreteCallbackWrapper{ST,T,C } <: CallbackWrapper
461
472
nw:: Network
462
- component:: ST
463
- callback:: T
464
- function DiscreteCallbackWrapper (nw, component, callback)
465
- @assert nw isa Network
466
- @assert component isa SymbolicIndex
467
- @assert callback isa Union{DiscreteComponentCallback, PresetTimeComponentCallback}
468
- new {typeof(component),typeof(callback)} (nw, component, callback)
473
+ components:: Vector{ST} # Changed to support batching
474
+ callbacks:: Vector{T} # Changed to support batching
475
+ condition:: C # Store condition function to avoid dynamic dispatch
476
+ end
477
+ function DiscreteCallbackWrapper (nw, components, callbacks)
478
+ @assert nw isa Network
479
+ @assert all (c -> c isa SymbolicIndex, components)
480
+ @assert all (cb -> cb isa DiscreteComponentCallback, callbacks) # Only DiscreteComponentCallback
481
+ if ! isconcretetype (eltype (components))
482
+ components = [c for c in components]
469
483
end
484
+ if ! isconcretetype (eltype (callbacks))
485
+ callbacks = [cb for cb in callbacks]
486
+ end
487
+ # Extract condition function - all callbacks in batch have identical conditions
488
+ condition = first (callbacks). condition. f
489
+ DiscreteCallbackWrapper {eltype(components),eltype(callbacks),typeof(condition)} (nw, components, callbacks, condition)
470
490
end
471
491
472
492
# generate a DiscreteCallback from a DiscreteCallbackWrapper
473
493
function to_callback (dcw:: DiscreteCallbackWrapper )
474
- kwargs = dcw. callback . kwargs
494
+ kwargs = first ( dcw. callbacks) . kwargs
475
495
cond = _batch_condition (dcw)
476
496
affect = _batch_affect (dcw)
477
497
DiscreteCallback (cond, affect; kwargs... )
478
498
end
479
- # generate a PresetTimeCallback from a DiscreteCallbackWrapper
480
- function to_callback (dcw:: DiscreteCallbackWrapper{<:Any,<:PresetTimeComponentCallback} )
481
- kwargs = dcw. callback. kwargs
482
- affect = _batch_affect (dcw)
483
- ts = dcw. callback. ts
484
- DiffEqCallbacks. PresetTimeCallback (ts, affect; kwargs... )
485
- end
486
499
function _batch_condition (dcw:: DiscreteCallbackWrapper )
487
- uidxtype = dcw. component isa EIndex ? EIndex : VIndex
488
- pidxtype = dcw. component isa EIndex ? EPIndex : VPIndex
489
- usymidxs = uidxtype (dcw. component. compidx, dcw. callback. condition. sym)
490
- psymidxs = pidxtype (dcw. component. compidx, dcw. callback. condition. psym)
500
+ usymidxs = collect_c_or_a_indices (dcw, :condition , :sym )
501
+ psymidxs = collect_c_or_a_indices (dcw, :condition , :psym )
491
502
ucache = DiffCache (zeros (length (usymidxs)), 12 )
492
503
493
504
obsf = SII. observed (dcw. nw, usymidxs)
@@ -502,37 +513,154 @@ function _batch_condition(dcw::DiscreteCallbackWrapper)
502
513
(u, t, integrator) -> begin
503
514
us = PreallocationTools. get_tmp (ucache, u)
504
515
obsf (u, integrator. p, t, us) # fills us inplace
505
- _u = SymbolicView (us, dcw. callback. condition. sym)
506
- pv = view (integrator. p, pidxs)
507
- _p = SymbolicView (pv, dcw. callback. condition. psym)
508
- dcw. callback. condition. f (_u, _p, t)
516
+
517
+ # OR logic: return true if ANY component condition is true
518
+ for i in 1 : length (dcw)
519
+ # symbolic view into u
520
+ uv = view (us, condition_urange (dcw, i))
521
+ _u = SymbolicView (uv, dcw. callbacks[i]. condition. sym)
522
+
523
+ # symbolic view into p
524
+ pidxsv = view (pidxs, condition_prange (dcw, i))
525
+ pv = view (integrator. p, pidxsv)
526
+ _p = SymbolicView (pv, dcw. callbacks[i]. condition. psym)
527
+
528
+ # If any condition is true, trigger the callback
529
+ if dcw. condition (_u, _p, t)
530
+ return true
531
+ end
532
+ end
533
+ return false
509
534
end
510
535
end
511
536
function _batch_affect (dcw:: DiscreteCallbackWrapper )
512
- uidxtype = dcw. component isa EIndex ? EIndex : VIndex
513
- pidxtype = dcw. component isa EIndex ? EPIndex : VPIndex
514
- usymidxs = uidxtype (dcw. component. compidx, dcw. callback. affect. sym)
515
- psymidxs = pidxtype (dcw. component. compidx, dcw. callback. affect. psym)
537
+ # Setup for condition re-evaluation
538
+ cusymidxs = collect_c_or_a_indices (dcw, :condition , :sym )
539
+ cpsymidxs = collect_c_or_a_indices (dcw, :condition , :psym )
540
+ cucache = DiffCache (zeros (length (cusymidxs)), 12 )
541
+ cobsf = SII. observed (dcw. nw, cusymidxs)
542
+ cpidxs = SII. parameter_index .(Ref (dcw. nw), cpsymidxs)
516
543
517
- uidxs = SII. variable_index .(Ref (dcw. nw), usymidxs)
518
- pidxs = SII. parameter_index .(Ref (dcw. nw), psymidxs)
544
+ # Setup for affect execution
545
+ ausymidxs = collect_c_or_a_indices (dcw, :affect , :sym )
546
+ apsymidxs = collect_c_or_a_indices (dcw, :affect , :psym )
547
+
548
+ auidxs = SII. variable_index .(Ref (dcw. nw), ausymidxs)
549
+ apidxs = SII. parameter_index .(Ref (dcw. nw), apsymidxs)
550
+
551
+ if any (isnothing, auidxs) || any (isnothing, apidxs)
552
+ missing_u = []
553
+ if any (isnothing, auidxs)
554
+ nidxs = findall (isnothing, auidxs)
555
+ append! (missing_u, ausymidxs[nidxs])
556
+ end
557
+ missing_p = []
558
+ if any (isnothing, apidxs)
559
+ nidxs = findall (isnothing, apidxs)
560
+ append! (missing_p, apsymidxs[nidxs])
561
+ end
562
+ throw (ArgumentError (
563
+ " Cannot build callback as it contains refrences to undefined symbols:\n " *
564
+ (isempty (missing_u) ? " " : " Missing state symbols: $(missing_u) \n " )*
565
+ (isempty (missing_p) ? " " : " Missing param symbols: $(missing_p) \n " )
566
+ ))
567
+ end
519
568
520
569
(integrator) -> begin
570
+ # Re-evaluate all conditions to determine which affects to execute
571
+ cus = PreallocationTools. get_tmp (cucache, integrator. u)
572
+ cobsf (integrator. u, integrator. p, integrator. t, cus)
573
+
574
+ any_uchanged = false
575
+ any_pchanged = false
576
+
577
+ for i in 1 : length (dcw)
578
+ # Re-evaluate condition for component i
579
+ cuv = view (cus, condition_urange (dcw, i))
580
+ c_u = SymbolicView (cuv, dcw. callbacks[i]. condition. sym)
581
+ cpidxsv = view (cpidxs, condition_prange (dcw, i))
582
+ cpv = view (integrator. p, cpidxsv)
583
+ c_p = SymbolicView (cpv, dcw. callbacks[i]. condition. psym)
584
+
585
+ # Only execute affect if condition is true
586
+ if dcw. condition (c_u, c_p, integrator. t)
587
+ # Execute affect for component i
588
+ auidxsv = view (auidxs, affect_urange (dcw, i))
589
+ auv = view (integrator. u, auidxsv)
590
+ a_u = SymbolicView (auv, dcw. callbacks[i]. affect. sym)
591
+
592
+ apidxsv = view (apidxs, affect_prange (dcw, i))
593
+ apv = view (integrator. p, apidxsv)
594
+ a_p = SymbolicView (apv, dcw. callbacks[i]. affect. psym)
595
+
596
+ ctx = get_ctx (integrator, dcw. components[i])
597
+
598
+ uhash = hash (auv)
599
+ phash = hash (apv)
600
+ dcw. callbacks[i]. affect. f (a_u, a_p, ctx)
601
+ pchanged = hash (apv) != phash
602
+ uchanged = hash (auv) != uhash
603
+
604
+ any_uchanged = any_uchanged || uchanged
605
+ any_pchanged = any_pchanged || pchanged
606
+ end
607
+ end
608
+
609
+ (any_uchanged || any_pchanged) && SciMLBase. auto_dt_reset! (integrator)
610
+ any_pchanged && save_parameters! (integrator)
611
+ end
612
+ end
613
+
614
+ # ###
615
+ # ### wrapping of preset time callbacks
616
+ # ###
617
+ struct PresetTimeCallbackWrapper{ST,T}
618
+ nw:: Network
619
+ component:: ST # Single component - PresetTime callbacks cannot be batched
620
+ callback:: T # Single callback - PresetTime callbacks cannot be batched
621
+ function PresetTimeCallbackWrapper (nw, component:: SymbolicIndex , callback:: PresetTimeComponentCallback )
622
+ @assert nw isa Network
623
+ @assert component isa SymbolicIndex
624
+ @assert callback isa PresetTimeComponentCallback
625
+ # PresetTimeCallbacks cannot be batched, so always single component/callback
626
+ new {typeof(component), typeof(callback)} (nw, component, callback)
627
+ end
628
+ end
629
+
630
+ # generate a PresetTimeCallback from a PresetTimeCallbackWrapper
631
+ function to_callback (ptcw:: PresetTimeCallbackWrapper )
632
+ callback = ptcw. callback
633
+ component = ptcw. component
634
+ kwargs = callback. kwargs
635
+ ts = callback. ts
636
+
637
+ # Create affect function for the single component
638
+ uidxtype = component isa EIndex ? EIndex : VIndex
639
+ pidxtype = component isa EIndex ? EPIndex : VPIndex
640
+ usymidxs = uidxtype (component. compidx, callback. affect. sym)
641
+ psymidxs = pidxtype (component. compidx, callback. affect. psym)
642
+
643
+ uidxs = SII. variable_index .(Ref (ptcw. nw), usymidxs)
644
+ pidxs = SII. parameter_index .(Ref (ptcw. nw), psymidxs)
645
+
646
+ affect = (integrator) -> begin
521
647
uv = view (integrator. u, uidxs)
522
- _u = SymbolicView (uv, dcw . callback. affect. sym)
648
+ _u = SymbolicView (uv, callback. affect. sym)
523
649
pv = view (integrator. p, pidxs)
524
- _p = SymbolicView (pv, dcw . callback. affect. psym)
525
- ctx = get_ctx (integrator, dcw . component)
650
+ _p = SymbolicView (pv, callback. affect. psym)
651
+ ctx = get_ctx (integrator, component)
526
652
527
653
uhash = hash (uv)
528
654
phash = hash (pv)
529
- dcw . callback. affect. f (_u, _p, ctx)
655
+ callback. affect. f (_u, _p, ctx)
530
656
pchanged = hash (pv) != phash
531
657
uchanged = hash (uv) != uhash
532
658
533
659
(pchanged || uchanged) && SciMLBase. auto_dt_reset! (integrator)
534
660
pchanged && save_parameters! (integrator)
535
661
end
662
+
663
+ DiffEqCallbacks. PresetTimeCallback (ts, affect; kwargs... )
536
664
end
537
665
538
666
0 commit comments