Skip to content

Commit 9560875

Browse files
Fix indexing bug with infeasible experiments for IDAKLUSolver (#4541)
* fix interp indexing Co-Authored-By: Pip Liggins <[email protected]> * simplify indexing --------- Co-authored-by: Pip Liggins <[email protected]>
1 parent 164f71e commit 9560875

File tree

4 files changed

+33
-21
lines changed

4 files changed

+33
-21
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
- Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416))
2121

2222
## Bug Fixes
23-
23+
- Fixed bug in post-processing solutions with infeasible experiments using the (`IDAKLUSolver`). ([#4541](https://github.com/pybamm-team/PyBaMM/pull/4541))
2424
- Disabled IREE on MacOS due to compatibility issues and added the CasADI
2525
path to the environment to resolve issues on MacOS and Linux. Windows
2626
users may still experience issues with interpolation. ([#4528](https://github.com/pybamm-team/PyBaMM/pull/4528))

src/pybamm/solvers/c_solvers/idaklu/observe.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ class TimeSeriesInterpolator {
100100
) {
101101
for (size_t i = 0; i < ts_data_np.size(); i++) {
102102
const auto& t_data = ts_data_np[i].unchecked<1>();
103+
// Continue if there is no data
104+
if (t_data.size() == 0) {
105+
continue;
106+
}
107+
103108
const realtype t_data_final = t_data(t_data.size() - 1);
104109
realtype t_interp_next = t_interp(i_interp);
105110
// Continue if the next interpolation point is beyond the final data point
@@ -227,6 +232,10 @@ class TimeSeriesProcessor {
227232
int i_entries = 0;
228233
for (size_t i = 0; i < ts.size(); i++) {
229234
const auto& t = ts[i].unchecked<1>();
235+
// Continue if there is no data
236+
if (t.size() == 0) {
237+
continue;
238+
}
230239
const auto& y = ys[i].unchecked<2>();
231240
const auto input = inputs[i].data();
232241
const auto func = *funcs[i];

src/pybamm/solvers/processed_variable.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,22 @@ def _setup_cpp_inputs(self, t, full_range):
133133
ys = self.all_ys
134134
yps = self.all_yps
135135
inputs = self.all_inputs_casadi
136-
# Find the indices of the time points to observe
137-
if full_range:
138-
idxs = range(len(ts))
139-
else:
140-
idxs = _find_ts_indices(ts, t)
141136

142-
if isinstance(idxs, list):
143-
# Extract the time points and inputs
144-
ts = [ts[idx] for idx in idxs]
145-
ys = [ys[idx] for idx in idxs]
146-
if self.hermite_interpolation:
147-
yps = [yps[idx] for idx in idxs]
148-
inputs = [self.all_inputs_casadi[idx] for idx in idxs]
137+
# Remove all empty ts
138+
idxs = np.where([ti.size > 0 for ti in ts])[0]
139+
140+
# Find the indices of the time points to observe
141+
if not full_range:
142+
ts_nonempty = [ts[idx] for idx in idxs]
143+
idxs_subset = _find_ts_indices(ts_nonempty, t)
144+
idxs = idxs[idxs_subset]
145+
146+
# Extract the time points and inputs
147+
ts = [ts[idx] for idx in idxs]
148+
ys = [ys[idx] for idx in idxs]
149+
if self.hermite_interpolation:
150+
yps = [yps[idx] for idx in idxs]
151+
inputs = [self.all_inputs_casadi[idx] for idx in idxs]
149152

150153
is_f_contiguous = _is_f_contiguous(ys)
151154

@@ -977,8 +980,4 @@ def _find_ts_indices(ts, t):
977980
if (t[-1] > ts[-1][-1]) and (len(indices) == 0 or indices[-1] != len(ts) - 1):
978981
indices.append(len(ts) - 1)
979982

980-
if len(indices) == len(ts):
981-
# All indices are included
982-
return range(len(ts))
983-
984983
return indices

src/pybamm/solvers/solution.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,15 @@ def _update_variable(self, variable):
580580
# Iterate through all models, some may be in the list several times and
581581
# therefore only get set up once
582582
vars_casadi = []
583-
for i, (model, ys, inputs, var_pybamm) in enumerate(
584-
zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm)
583+
for i, (model, ts, ys, inputs, var_pybamm) in enumerate(
584+
zip(self.all_models, self.all_ts, self.all_ys, self.all_inputs, vars_pybamm)
585585
):
586-
if ys.size == 0 and var_pybamm.has_symbol_of_classes(
587-
pybamm.expression_tree.state_vector.StateVector
586+
if (
587+
ys.size == 0
588+
and var_pybamm.has_symbol_of_classes(
589+
pybamm.expression_tree.state_vector.StateVector
590+
)
591+
and not ts.size == 0
588592
):
589593
raise KeyError(
590594
f"Cannot process variable '{variable}' as it was not part of the "

0 commit comments

Comments
 (0)