Skip to content

Commit 91b20db

Browse files
MarcBerlinerisaacbasil
authored andcommitted
Fix bug with out-of-bounds discontinuities in time (#5205)
* filter constant-time discontinuities Co-Authored-By: isaacbasil <[email protected]> * Update CHANGELOG.md --------- Co-authored-by: isaacbasil <[email protected]>
1 parent 13ebe0c commit 91b20db

File tree

3 files changed

+72
-46
lines changed

3 files changed

+72
-46
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
## Bug fixes
1212

13+
- Fixed a bug where time-based Heaviside or modulo discontinuities could trigger out-of-bounds errors in time arrays. ([#5205](https://github.com/pybamm-team/PyBaMM/pull/5205))
1314
- Fixed a bug using a time-varying input with heaviside or modulo functions using the `IDAKLUSolver`. ([#4994](https://github.com/pybamm-team/PyBaMM/pull/4994))
1415
- Fix a bug in setting initial stoichiometries where the reference temperature was used instead of the initial temperature. ([#5189](https://github.com/pybamm-team/PyBaMM/pull/5189))
1516
- Fix a bug in the calculation of "Bulk" OCP terms in hysteresis models ([#5169](https://github.com/pybamm-team/PyBaMM/pull/5169))

src/pybamm/solvers/base_solver.py

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
224224
casadi_switch_events,
225225
terminate_events,
226226
interpolant_extrapolation_events,
227+
t_discon_constant,
227228
discontinuity_events,
228229
) = self._set_up_events(model, t_eval, inputs, vars_for_processing)
229230

@@ -233,8 +234,9 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
233234
model.rhs_algebraic_eval = rhs_algebraic
234235

235236
model.terminate_events_eval = terminate_events
236-
model.discontinuity_events_eval = discontinuity_events
237237
model.interpolant_extrapolation_events_eval = interpolant_extrapolation_events
238+
model.discontinuity_events_eval = discontinuity_events
239+
model.t_discon_constant = t_discon_constant
238240

239241
model.jac_rhs_eval = jac_rhs
240242
model.jac_rhs_action_eval = jac_rhs_action
@@ -482,23 +484,17 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing):
482484
# discontinuity events if these exist.
483485
# Note: only checks for the case of t < X, t <= X, X < t, or X <= t,
484486
# but also accounts for the fact that t might be dimensional
485-
486-
t0 = np.min(t_eval)
487487
tf = np.max(t_eval)
488488

489489
def supports_t_eval_discontinuities(expr):
490490
# Only IDAKLUSolver supports discontinuities represented by t_eval
491-
return (
492-
self.supports_t_eval_discontinuities
493-
and (t_eval is not None)
494-
and expr.is_constant()
495-
)
491+
return self.supports_t_eval_discontinuities and expr.is_constant()
492+
493+
# Find all the constant time-based discontinuities
494+
t_discon = []
496495

497-
def append_t_eval(t):
498-
if t0 <= t <= tf and t not in t_eval:
499-
# Insert t in the correct position to maintain sorted order
500-
idx = np.searchsorted(t_eval, t)
501-
t_eval.insert(idx, t)
496+
def append_t_discon(t):
497+
t_discon.append(t)
502498

503499
def heaviside_event(symbol, expr):
504500
model.events.append(
@@ -509,28 +505,28 @@ def heaviside_event(symbol, expr):
509505
)
510506
)
511507

512-
def heaviside_t_eval(symbol, expr):
508+
def heaviside_t_discon(symbol, expr):
513509
value = expr.evaluate(0, model.y0.full(), inputs=inputs)
514-
append_t_eval(value)
510+
append_t_discon(value)
515511

516512
if isinstance(symbol, pybamm.EqualHeaviside):
517513
if symbol.left == pybamm.t:
518514
# t <= x
519515
# Stop at t = x and right after t = x
520-
append_t_eval(np.nextafter(value, np.inf))
516+
append_t_discon(np.nextafter(value, np.inf))
521517
else:
522518
# t >= x
523519
# Stop at t = x and right before t = x
524-
append_t_eval(np.nextafter(value, -np.inf))
520+
append_t_discon(np.nextafter(value, -np.inf))
525521
elif isinstance(symbol, pybamm.NotEqualHeaviside):
526522
if symbol.left == pybamm.t:
527523
# t < x
528524
# Stop at t = x and right before t = x
529-
append_t_eval(np.nextafter(value, -np.inf))
525+
append_t_discon(np.nextafter(value, -np.inf))
530526
else:
531527
# t > x
532528
# Stop at t = x and right after t = x
533-
append_t_eval(np.nextafter(value, np.inf))
529+
append_t_discon(np.nextafter(value, np.inf))
534530
else:
535531
raise ValueError(
536532
f"Unknown heaviside function: {symbol}"
@@ -546,13 +542,13 @@ def modulo_event(symbol, expr, num_events):
546542
)
547543
)
548544

549-
def modulo_t_eval(symbol, expr, num_events):
545+
def modulo_t_discon(symbol, expr, num_events):
550546
value = expr.evaluate(0, model.y0.full(), inputs=inputs)
551547
for i in np.arange(num_events):
552548
t = value * (i + 1)
553549
# Stop right before t and at t
554-
append_t_eval(np.nextafter(t, -np.inf))
555-
append_t_eval(t)
550+
append_t_discon(np.nextafter(t, -np.inf))
551+
append_t_discon(t)
556552

557553
for symbol in itertools.chain(
558554
model.concatenated_rhs.pre_order(),
@@ -569,7 +565,7 @@ def modulo_t_eval(symbol, expr, num_events):
569565
continue # pragma: no cover
570566

571567
if supports_t_eval_discontinuities(expr):
572-
heaviside_t_eval(symbol, expr)
568+
heaviside_t_discon(symbol, expr)
573569
else:
574570
heaviside_event(symbol, expr)
575571

@@ -578,7 +574,7 @@ def modulo_t_eval(symbol, expr, num_events):
578574
num_events = 200 if (t_eval is None) else (tf // expr.value)
579575

580576
if supports_t_eval_discontinuities(expr):
581-
modulo_t_eval(symbol, expr, num_events)
577+
modulo_t_discon(symbol, expr, num_events)
582578
else:
583579
modulo_event(symbol, expr, num_events)
584580
else:
@@ -641,6 +637,7 @@ def modulo_t_eval(symbol, expr, num_events):
641637
casadi_switch_events,
642638
terminate_events,
643639
interpolant_extrapolation_events,
640+
t_discon,
644641
discontinuity_events,
645642
)
646643

@@ -1053,39 +1050,53 @@ def solve(
10531050
return solutions
10541051

10551052
@staticmethod
1056-
def _get_discontinuity_start_end_indices(model, inputs, t_eval):
1053+
def filter_discontinuities(t_discon: list, t_eval: list) -> np.ndarray:
1054+
"""
1055+
Filter the discontinuities to only include the unique and sorted
1056+
values within the t_eval range (non-exclusive of end points).
1057+
1058+
Parameters
1059+
----------
1060+
t_discon : list
1061+
The list of all possible discontinuity times.
1062+
t_eval : list
1063+
The integration time points.
1064+
1065+
Returns
1066+
-------
1067+
np.ndarray
1068+
The filtered list of discontinuities within the range of t_eval.
1069+
"""
1070+
t_discon_unique = np.unique(t_discon)
1071+
1072+
# Find the indices within t_eval (non-exclusive of end points)
1073+
idx_start = np.searchsorted(t_discon_unique, t_eval[0], side="right")
1074+
idx_end = np.searchsorted(t_discon_unique, t_eval[-1], side="left")
1075+
return t_discon_unique[idx_start:idx_end]
1076+
1077+
def _get_discontinuity_start_end_indices(self, model, inputs, t_eval):
1078+
if self.supports_t_eval_discontinuities:
1079+
t_discon_constant = self.filter_discontinuities(
1080+
model.t_discon_constant, t_eval
1081+
)
1082+
t_eval = np.union1d(t_eval, t_discon_constant)
1083+
10571084
if not model.discontinuity_events_eval:
10581085
pybamm.logger.verbose("No discontinuity events found")
10591086
return [0], [len(t_eval)], t_eval
10601087

1061-
# Calculate discontinuities
1062-
discontinuities = [
1088+
# Calculate all possible discontinuities
1089+
_t_discon_full = [
10631090
# Assuming that discontinuities do not depend on
10641091
# input parameters when len(input_list) > 1, only
10651092
# `inputs` is passed to `evaluate`.
10661093
# See https://github.com/pybamm-team/PyBaMM/pull/1261
10671094
event.expression.evaluate(inputs=inputs)
10681095
for event in model.discontinuity_events_eval
10691096
]
1097+
t_discon = self.filter_discontinuities(_t_discon_full, t_eval)
10701098

1071-
# make sure they are increasing in time
1072-
discontinuities = sorted(discontinuities)
1073-
1074-
# remove any identical discontinuities
1075-
discontinuities = [
1076-
v
1077-
for i, v in enumerate(discontinuities)
1078-
if (
1079-
i == len(discontinuities) - 1
1080-
or discontinuities[i] < discontinuities[i + 1]
1081-
)
1082-
and v > 0
1083-
]
1084-
1085-
# remove any discontinuities after end of t_eval
1086-
discontinuities = [v for v in discontinuities if v < t_eval[-1]]
1087-
1088-
pybamm.logger.verbose(f"Discontinuity events found at t = {discontinuities}")
1099+
pybamm.logger.verbose(f"Discontinuity events found at t = {t_discon}")
10891100
if isinstance(inputs, list):
10901101
raise pybamm.SolverError(
10911102
"Cannot solve for a list of input parameters sets with discontinuities"
@@ -1096,7 +1107,7 @@ def _get_discontinuity_start_end_indices(model, inputs, t_eval):
10961107
start_indices = [0]
10971108
end_indices = []
10981109
eps = sys.float_info.epsilon
1099-
for dtime in discontinuities:
1110+
for dtime in t_discon:
11001111
dindex = np.searchsorted(t_eval, dtime, side="left")
11011112
end_indices.append(dindex + 1)
11021113
start_indices.append(dindex + 1)

tests/unit/test_simulation.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from scipy.integrate import trapezoid
77

88
import pybamm
9+
from pybamm.solvers.base_solver import BaseSolver
910
from tests import no_internet_connection
1011

1112

@@ -895,6 +896,19 @@ def prevfloat(t):
895896
for t_node in t_nodes:
896897
assert current(t_node) == pytest.approx(sawtooth_current(t_node))
897898

899+
def test_filter_discontinuities_simple(self):
900+
t_eval = [0.0, 3.0, 10.0]
901+
t_discon = [-5.0, 0.0, 1.0, 3.0, 3.0, 5.0, 10.0, 12.0]
902+
903+
result = BaseSolver.filter_discontinuities(t_discon, t_eval)
904+
expected = np.array([1.0, 3.0, 5.0])
905+
906+
# Exclusive of endpoints
907+
t_eval_endpoints = [t_eval[0], t_eval[-1]]
908+
assert all(t not in result for t in t_eval_endpoints)
909+
910+
np.testing.assert_array_equal(result, expected)
911+
898912
def test_t_eval(self):
899913
model = pybamm.lithium_ion.SPM()
900914
sim = pybamm.Simulation(model)

0 commit comments

Comments
 (0)