@@ -224,6 +224,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
224224 casadi_switch_events ,
225225 terminate_events ,
226226 interpolant_extrapolation_events ,
227+ t_discon_constant ,
227228 discontinuity_events ,
228229 ) = self ._set_up_events (model , t_eval , inputs , vars_for_processing )
229230
@@ -233,8 +234,9 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
233234 model .rhs_algebraic_eval = rhs_algebraic
234235
235236 model .terminate_events_eval = terminate_events
236- model .discontinuity_events_eval = discontinuity_events
237237 model .interpolant_extrapolation_events_eval = interpolant_extrapolation_events
238+ model .discontinuity_events_eval = discontinuity_events
239+ model .t_discon_constant = t_discon_constant
238240
239241 model .jac_rhs_eval = jac_rhs
240242 model .jac_rhs_action_eval = jac_rhs_action
@@ -482,23 +484,17 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
482484 # discontinuity events if these exist.
483485 # Note: only checks for the case of t < X, t <= X, X < t, or X <= t,
484486 # but also accounts for the fact that t might be dimensional
485-
486- t0 = np .min (t_eval )
487487 tf = np .max (t_eval )
488488
489489 def supports_t_eval_discontinuities (expr ):
490490 # Only IDAKLUSolver supports discontinuities represented by t_eval
491- return (
492- self .supports_t_eval_discontinuities
493- and (t_eval is not None )
494- and expr .is_constant ()
495- )
491+ return self .supports_t_eval_discontinuities and expr .is_constant ()
492+
493+ # Find all the constant time-based discontinuities
494+ t_discon = []
496495
497- def append_t_eval (t ):
498- if t0 <= t <= tf and t not in t_eval :
499- # Insert t in the correct position to maintain sorted order
500- idx = np .searchsorted (t_eval , t )
501- t_eval .insert (idx , t )
496+ def append_t_discon (t ):
497+ t_discon .append (t )
502498
503499 def heaviside_event (symbol , expr ):
504500 model .events .append (
@@ -509,28 +505,28 @@ def heaviside_event(symbol, expr):
509505 )
510506 )
511507
512- def heaviside_t_eval (symbol , expr ):
508+ def heaviside_t_discon (symbol , expr ):
513509 value = expr .evaluate (0 , model .y0 .full (), inputs = inputs )
514- append_t_eval (value )
510+ append_t_discon (value )
515511
516512 if isinstance (symbol , pybamm .EqualHeaviside ):
517513 if symbol .left == pybamm .t :
518514 # t <= x
519515 # Stop at t = x and right after t = x
520- append_t_eval (np .nextafter (value , np .inf ))
516+ append_t_discon (np .nextafter (value , np .inf ))
521517 else :
522518 # t >= x
523519 # Stop at t = x and right before t = x
524- append_t_eval (np .nextafter (value , - np .inf ))
520+ append_t_discon (np .nextafter (value , - np .inf ))
525521 elif isinstance (symbol , pybamm .NotEqualHeaviside ):
526522 if symbol .left == pybamm .t :
527523 # t < x
528524 # Stop at t = x and right before t = x
529- append_t_eval (np .nextafter (value , - np .inf ))
525+ append_t_discon (np .nextafter (value , - np .inf ))
530526 else :
531527 # t > x
532528 # Stop at t = x and right after t = x
533- append_t_eval (np .nextafter (value , np .inf ))
529+ append_t_discon (np .nextafter (value , np .inf ))
534530 else :
535531 raise ValueError (
536532 f"Unknown heaviside function: { symbol } "
@@ -546,13 +542,13 @@ def modulo_event(symbol, expr, num_events):
546542 )
547543 )
548544
549- def modulo_t_eval (symbol , expr , num_events ):
545+ def modulo_t_discon (symbol , expr , num_events ):
550546 value = expr .evaluate (0 , model .y0 .full (), inputs = inputs )
551547 for i in np .arange (num_events ):
552548 t = value * (i + 1 )
553549 # Stop right before t and at t
554- append_t_eval (np .nextafter (t , - np .inf ))
555- append_t_eval (t )
550+ append_t_discon (np .nextafter (t , - np .inf ))
551+ append_t_discon (t )
556552
557553 for symbol in itertools .chain (
558554 model .concatenated_rhs .pre_order (),
@@ -569,7 +565,7 @@ def modulo_t_eval(symbol, expr, num_events):
569565 continue # pragma: no cover
570566
571567 if supports_t_eval_discontinuities (expr ):
572- heaviside_t_eval (symbol , expr )
568+ heaviside_t_discon (symbol , expr )
573569 else :
574570 heaviside_event (symbol , expr )
575571
@@ -578,7 +574,7 @@ def modulo_t_eval(symbol, expr, num_events):
578574 num_events = 200 if (t_eval is None ) else (tf // expr .value )
579575
580576 if supports_t_eval_discontinuities (expr ):
581- modulo_t_eval (symbol , expr , num_events )
577+ modulo_t_discon (symbol , expr , num_events )
582578 else :
583579 modulo_event (symbol , expr , num_events )
584580 else :
@@ -641,6 +637,7 @@ def modulo_t_eval(symbol, expr, num_events):
641637 casadi_switch_events ,
642638 terminate_events ,
643639 interpolant_extrapolation_events ,
640+ t_discon ,
644641 discontinuity_events ,
645642 )
646643
@@ -1053,39 +1050,53 @@ def solve(
10531050 return solutions
10541051
10551052 @staticmethod
1056- def _get_discontinuity_start_end_indices (model , inputs , t_eval ):
1053+ def filter_discontinuities (t_discon : list , t_eval : list ) -> np .ndarray :
1054+ """
1055+ Filter the discontinuities to only include the unique and sorted
1056+ values within the t_eval range (non-exclusive of end points).
1057+
1058+ Parameters
1059+ ----------
1060+ t_discon : list
1061+ The list of all possible discontinuity times.
1062+ t_eval : list
1063+ The integration time points.
1064+
1065+ Returns
1066+ -------
1067+ np.ndarray
1068+ The filtered list of discontinuities within the range of t_eval.
1069+ """
1070+ t_discon_unique = np .unique (t_discon )
1071+
1072+ # Find the indices within t_eval (non-exclusive of end points)
1073+ idx_start = np .searchsorted (t_discon_unique , t_eval [0 ], side = "right" )
1074+ idx_end = np .searchsorted (t_discon_unique , t_eval [- 1 ], side = "left" )
1075+ return t_discon_unique [idx_start :idx_end ]
1076+
1077+ def _get_discontinuity_start_end_indices (self , model , inputs , t_eval ):
1078+ if self .supports_t_eval_discontinuities :
1079+ t_discon_constant = self .filter_discontinuities (
1080+ model .t_discon_constant , t_eval
1081+ )
1082+ t_eval = np .union1d (t_eval , t_discon_constant )
1083+
10571084 if not model .discontinuity_events_eval :
10581085 pybamm .logger .verbose ("No discontinuity events found" )
10591086 return [0 ], [len (t_eval )], t_eval
10601087
1061- # Calculate discontinuities
1062- discontinuities = [
1088+ # Calculate all possible discontinuities
1089+ _t_discon_full = [
10631090 # Assuming that discontinuities do not depend on
10641091 # input parameters when len(input_list) > 1, only
10651092 # `inputs` is passed to `evaluate`.
10661093 # See https://github.com/pybamm-team/PyBaMM/pull/1261
10671094 event .expression .evaluate (inputs = inputs )
10681095 for event in model .discontinuity_events_eval
10691096 ]
1097+ t_discon = self .filter_discontinuities (_t_discon_full , t_eval )
10701098
1071- # make sure they are increasing in time
1072- discontinuities = sorted (discontinuities )
1073-
1074- # remove any identical discontinuities
1075- discontinuities = [
1076- v
1077- for i , v in enumerate (discontinuities )
1078- if (
1079- i == len (discontinuities ) - 1
1080- or discontinuities [i ] < discontinuities [i + 1 ]
1081- )
1082- and v > 0
1083- ]
1084-
1085- # remove any discontinuities after end of t_eval
1086- discontinuities = [v for v in discontinuities if v < t_eval [- 1 ]]
1087-
1088- pybamm .logger .verbose (f"Discontinuity events found at t = { discontinuities } " )
1099+ pybamm .logger .verbose (f"Discontinuity events found at t = { t_discon } " )
10891100 if isinstance (inputs , list ):
10901101 raise pybamm .SolverError (
10911102 "Cannot solve for a list of input parameters sets with discontinuities"
@@ -1096,7 +1107,7 @@ def _get_discontinuity_start_end_indices(model, inputs, t_eval):
10961107 start_indices = [0 ]
10971108 end_indices = []
10981109 eps = sys .float_info .epsilon
1099- for dtime in discontinuities :
1110+ for dtime in t_discon :
11001111 dindex = np .searchsorted (t_eval , dtime , side = "left" )
11011112 end_indices .append (dindex + 1 )
11021113 start_indices .append (dindex + 1 )
0 commit comments