Skip to content

Commit 3864e5d

Browse files
authored
Allow using IDAKLU(output_variables=...) with Experiments (#4534)
* Add test for idaklu+output_variables+experiment * edit Solution.last_state to pull y_event if all_ys is empty * Ensure ProcessedVariableComputed variables are passed through Solution copies during an Experiment Don't compute 'Change in x' summary variables if output_variables are specified * populate first_state using the initial condition if output_variables used Remove warnings about 'Change in x' summary variables * Add to computed processed variable tests * Add test for solution::add with computed variables * add test for solution::copy with computed variables * add check for idaklu on copy test * Add 'variables_returned' attribute to Solution Indicates if 'output_variables' are specified in solver and therefore empty state vector * Use `variables_returned` in `_update_variable()`, update test * Update CHANGELOG * Add test
1 parent 9560875 commit 3864e5d

File tree

9 files changed

+311
-14
lines changed

9 files changed

+311
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Features
44

5+
- Adds support to `pybamm.Experiment` for the `output_variables` option in the `IDAKLUSolver`. ([#4534](https://github.com/pybamm-team/PyBaMM/pull/4534))
56
- Adds an option "voltage as a state" that can be "false" (default) or "true". If "true" adds an explicit algebraic equation for the voltage. ([#4507](https://github.com/pybamm-team/PyBaMM/pull/4507))
67
- Improved `QuickPlot` accuracy for simulations with Hermite interpolation. ([#4483](https://github.com/pybamm-team/PyBaMM/pull/4483))
78
- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464))

src/pybamm/solvers/base_solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,7 @@ def get_termination_reason(solution, events):
14521452
solution.t_event,
14531453
solution.y_event,
14541454
solution.termination,
1455+
variables_returned=solution.variables_returned,
14551456
)
14561457
event_sol.solve_time = 0
14571458
event_sol.integration_time = 0

src/pybamm/solvers/idaklu_solver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,7 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict):
863863
termination,
864864
all_sensitivities=yS_out,
865865
all_yps=yp,
866+
variables_returned=bool(save_outputs_only),
866867
)
867868

868869
newsol.integration_time = integration_time

src/pybamm/solvers/processed_variable_computed.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
2-
# Processed Variable class
2+
# Processed Variable Computed class
33
#
4+
from __future__ import annotations
45
import casadi
56
import numpy as np
67
import pybamm
@@ -450,3 +451,27 @@ def sensitivities(self):
450451
if len(self.all_inputs[0]) == 0:
451452
return {}
452453
return self._sensitivities
454+
455+
def _update(
456+
self, other: pybamm.ProcessedVariableComputed, new_sol: pybamm.Solution
457+
) -> pybamm.ProcessedVariableComputed:
458+
"""
459+
Returns a new ProcessedVariableComputed object that is the result of appending
460+
the data from other to this object. Used exclusively in running experiments, to
461+
append the data from one cycle to the next.
462+
463+
Parameters
464+
----------
465+
other : :class:`pybamm.ProcessedVariableComputed`
466+
The other ProcessedVariableComputed object to append to this one
467+
new_sol : :class:`pybamm.Solution`
468+
The new solution object to be used to create the processed variables
469+
"""
470+
471+
bv = self.base_variables + other.base_variables
472+
bvc = self.base_variables_casadi + other.base_variables_casadi
473+
bvd = self.base_variables_data + other.base_variables_data
474+
475+
new_var = self.__class__(bv, bvc, bvd, new_sol)
476+
477+
return new_var

src/pybamm/solvers/solution.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ class Solution:
6262
True if sensitivities included as the solution of the explicit forwards
6363
equations. False if no sensitivities included/wanted. Dict if sensitivities are
6464
provided as a dict of {parameter: [sensitivities]} pairs.
65+
variables_returned: bool
66+
Bool to indicate if `all_ys` contains the full state vector, or is empty because
67+
only requested variables have been returned. True if `output_variables` is used
68+
with a solver, otherwise False.
6569
6670
"""
6771

@@ -76,6 +80,7 @@ def __init__(
7680
termination="final time",
7781
all_sensitivities=False,
7882
all_yps=None,
83+
variables_returned=False,
7984
check_solution=True,
8085
):
8186
if not isinstance(all_ts, list):
@@ -93,6 +98,8 @@ def __init__(
9398
all_yps = [all_yps]
9499
self._all_yps = all_yps
95100

101+
self.variables_returned = variables_returned
102+
96103
# Set up inputs
97104
if not isinstance(all_inputs, list):
98105
all_inputs_copy = dict(all_inputs)
@@ -460,9 +467,15 @@ def first_state(self):
460467
else:
461468
all_yps = self.all_yps[0][:, :1]
462469

470+
if not self.variables_returned:
471+
all_ys = self.all_ys[0][:, :1]
472+
else:
473+
# Get first state from initial conditions as all_ys is empty
474+
all_ys = self.all_models[0].y0full.reshape(-1, 1)
475+
463476
new_sol = Solution(
464477
self.all_ts[0][:1],
465-
self.all_ys[0][:, :1],
478+
all_ys,
466479
self.all_models[:1],
467480
self.all_inputs[:1],
468481
None,
@@ -500,9 +513,15 @@ def last_state(self):
500513
else:
501514
all_yps = self.all_yps[-1][:, -1:]
502515

516+
if not self.variables_returned:
517+
all_ys = self.all_ys[-1][:, -1:]
518+
else:
519+
# Get last state from y_event as all_ys is empty
520+
all_ys = self.y_event.reshape(len(self.y_event), 1)
521+
503522
new_sol = Solution(
504523
self.all_ts[-1][-1:],
505-
self.all_ys[-1][:, -1:],
524+
all_ys,
506525
self.all_models[-1:],
507526
self.all_inputs[-1:],
508527
self.t_event,
@@ -580,15 +599,11 @@ def _update_variable(self, variable):
580599
# Iterate through all models, some may be in the list several times and
581600
# therefore only get set up once
582601
vars_casadi = []
583-
for i, (model, ts, ys, inputs, var_pybamm) in enumerate(
584-
zip(self.all_models, self.all_ts, self.all_ys, self.all_inputs, vars_pybamm)
602+
for i, (model, ys, inputs, var_pybamm) in enumerate(
603+
zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm)
585604
):
586-
if (
587-
ys.size == 0
588-
and var_pybamm.has_symbol_of_classes(
589-
pybamm.expression_tree.state_vector.StateVector
590-
)
591-
and not ts.size == 0
605+
if self.variables_returned and var_pybamm.has_symbol_of_classes(
606+
pybamm.expression_tree.state_vector.StateVector
592607
):
593608
raise KeyError(
594609
f"Cannot process variable '{variable}' as it was not part of the "
@@ -682,7 +697,7 @@ def __getitem__(self, key):
682697
683698
Returns
684699
-------
685-
:class:`pybamm.ProcessedVariable`
700+
:class:`pybamm.ProcessedVariable` or :class:`pybamm.ProcessedVariableComputed`
686701
A variable that can be evaluated at any time or spatial point. The
687702
underlying data for this variable is available in its attribute ".data"
688703
"""
@@ -950,6 +965,7 @@ def __add__(self, other):
950965
other.termination,
951966
all_sensitivities=all_sensitivities,
952967
all_yps=all_yps,
968+
variables_returned=other.variables_returned,
953969
)
954970

955971
new_sol.closest_event_idx = other.closest_event_idx
@@ -966,6 +982,19 @@ def __add__(self, other):
966982
# Set sub_solutions
967983
new_sol._sub_solutions = self.sub_solutions + other.sub_solutions
968984

985+
# update variables which were derived at the solver stage
986+
if other._variables and all(
987+
isinstance(v, pybamm.ProcessedVariableComputed)
988+
for v in other._variables.values()
989+
):
990+
if not self._variables:
991+
new_sol._variables = other._variables.copy()
992+
else:
993+
new_sol._variables = {
994+
v: self._variables[v]._update(other._variables[v], new_sol)
995+
for v in self._variables.keys()
996+
}
997+
969998
return new_sol
970999

9711000
def __radd__(self, other):
@@ -983,6 +1012,7 @@ def copy(self):
9831012
self.termination,
9841013
self._all_sensitivities,
9851014
self.all_yps,
1015+
self.variables_returned,
9861016
)
9871017
new_sol._all_inputs_casadi = self.all_inputs_casadi
9881018
new_sol._sub_solutions = self.sub_solutions
@@ -992,6 +1022,13 @@ def copy(self):
9921022
new_sol.integration_time = self.integration_time
9931023
new_sol.set_up_time = self.set_up_time
9941024

1025+
# copy over variables which were derived at the solver stage
1026+
if self._variables and all(
1027+
isinstance(v, pybamm.ProcessedVariableComputed)
1028+
for v in self._variables.values()
1029+
):
1030+
new_sol._variables = self._variables.copy()
1031+
9951032
return new_sol
9961033

9971034
def plot_voltage_components(
@@ -1094,6 +1131,7 @@ def make_cycle_solution(
10941131
sum_sols.termination,
10951132
sum_sols._all_sensitivities,
10961133
sum_sols.all_yps,
1134+
sum_sols.variables_returned,
10971135
)
10981136
cycle_solution._all_inputs_casadi = sum_sols.all_inputs_casadi
10991137
cycle_solution._sub_solutions = sum_sols.sub_solutions

tests/integration/test_solvers/test_idaklu.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,55 @@ def test_interpolation(self):
152152
# test that y[1:3] = to true solution
153153
true_solution = b_value * sol.t
154154
np.testing.assert_array_almost_equal(sol.y[1:3], true_solution)
155+
156+
def test_with_experiments(self):
157+
summary_vars = []
158+
sols = []
159+
for out_vars in [True, False]:
160+
model = pybamm.lithium_ion.SPM()
161+
162+
if out_vars:
163+
output_variables = [
164+
"Discharge capacity [A.h]", # 0D variables
165+
"Time [s]",
166+
"Current [A]",
167+
"Voltage [V]",
168+
"Pressure [Pa]", # 1D variable
169+
"Positive particle effective diffusivity [m2.s-1]", # 2D variable
170+
]
171+
else:
172+
output_variables = None
173+
174+
solver = pybamm.IDAKLUSolver(output_variables=output_variables)
175+
176+
experiment = pybamm.Experiment(
177+
[
178+
(
179+
"Charge at 1C until 4.2 V",
180+
"Hold at 4.2 V until C/50",
181+
"Rest for 1 hour",
182+
)
183+
]
184+
)
185+
186+
sim = pybamm.Simulation(
187+
model,
188+
experiment=experiment,
189+
solver=solver,
190+
)
191+
192+
sol = sim.solve()
193+
sols.append(sol)
194+
summary_vars.append(sol.summary_variables)
195+
196+
# check computed variables are propegated sucessfully
197+
np.testing.assert_array_equal(
198+
sols[0]["Pressure [Pa]"].data, sols[1]["Pressure [Pa]"].data
199+
)
200+
np.testing.assert_array_almost_equal(
201+
sols[0]["Voltage [V]"].data, sols[1]["Voltage [V]"].data
202+
)
203+
204+
# check summary variables are the same if output variables are specified
205+
for var in summary_vars[0].keys():
206+
assert summary_vars[0][var] == summary_vars[1][var]

tests/unit/test_solvers/test_idaklu_solver.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,9 @@ def construct_model():
939939
with pytest.raises(KeyError):
940940
sol[varname].data
941941

942+
# Check Solution is marked
943+
assert sol.variables_returned is True
944+
942945
# Mock a 1D current collector and initialise (none in the model)
943946
sol["x_s [m]"].domain = ["current collector"]
944947
sol["x_s [m]"].entries

tests/unit/test_solvers/test_processed_variable_computed.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#
22
# Tests for the Processed Variable Computed class
33
#
4-
# This class forms a container for variables (and sensitivities) calculted
4+
# This class forms a container for variables (and sensitivities) calculated
55
# by the idaklu solver, and does not possesses any capability to calculate
66
# values itself since it does not have access to the full state vector
77
#
@@ -76,11 +76,12 @@ def test_processed_variable_0D(self):
7676
t_sol = np.array([0])
7777
y_sol = np.array([1])[:, np.newaxis]
7878
var_casadi = to_casadi(var, y_sol)
79+
sol = pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {})
7980
processed_var = pybamm.ProcessedVariableComputed(
8081
[var],
8182
[var_casadi],
8283
[y_sol],
83-
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
84+
sol,
8485
)
8586
# Assert that the processed variable is the same as the solution
8687
np.testing.assert_array_equal(processed_var.entries, y_sol[0])
@@ -94,6 +95,22 @@ def test_processed_variable_0D(self):
9495
processed_var.cumtrapz_ic = 1
9596
processed_var.entries
9697

98+
# check _update
99+
t_sol2 = np.array([1])
100+
y_sol2 = np.array([2])[:, np.newaxis]
101+
var_casadi = to_casadi(var, y_sol2)
102+
sol_2 = pybamm.Solution(t_sol2, y_sol2, pybamm.BaseModel(), {})
103+
processed_var2 = pybamm.ProcessedVariableComputed(
104+
[var],
105+
[var_casadi],
106+
[y_sol2],
107+
sol_2,
108+
)
109+
110+
comb_sol = sol + sol_2
111+
comb_var = processed_var._update(processed_var2, comb_sol)
112+
np.testing.assert_array_equal(comb_var.entries, np.append(y_sol, y_sol2))
113+
97114
# check empty sensitivity works
98115
def test_processed_variable_0D_no_sensitivity(self):
99116
# without space
@@ -217,6 +234,60 @@ def test_processed_variable_1D_unknown_domain(self):
217234
c_casadi = to_casadi(c, y_sol)
218235
pybamm.ProcessedVariableComputed([c], [c_casadi], [y_sol], solution)
219236

237+
def test_processed_variable_1D_update(self):
238+
# variable 1
239+
var = pybamm.Variable("var", domain=["negative electrode", "separator"])
240+
x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"])
241+
242+
disc = tests.get_discretisation_for_testing()
243+
disc.set_variable_slices([var])
244+
x_sol1 = disc.process_symbol(x).entries[:, 0]
245+
var_sol1 = disc.process_symbol(var)
246+
t_sol1 = np.linspace(0, 1)
247+
y_sol1 = np.ones_like(x_sol1)[:, np.newaxis] * np.linspace(0, 5)
248+
249+
var_casadi1 = to_casadi(var_sol1, y_sol1)
250+
sol1 = pybamm.Solution(t_sol1, y_sol1, pybamm.BaseModel(), {})
251+
processed_var1 = pybamm.ProcessedVariableComputed(
252+
[var_sol1],
253+
[var_casadi1],
254+
[y_sol1],
255+
sol1,
256+
)
257+
258+
# variable 2 -------------------
259+
var2 = pybamm.Variable("var2", domain=["negative electrode", "separator"])
260+
z = pybamm.SpatialVariable("z", domain=["negative electrode", "separator"])
261+
262+
disc = tests.get_discretisation_for_testing()
263+
disc.set_variable_slices([var2])
264+
z_sol2 = disc.process_symbol(z).entries[:, 0]
265+
var_sol2 = disc.process_symbol(var2)
266+
t_sol2 = np.linspace(2, 3)
267+
y_sol2 = np.ones_like(z_sol2)[:, np.newaxis] * np.linspace(5, 1)
268+
269+
var_casadi2 = to_casadi(var_sol2, y_sol2)
270+
sol2 = pybamm.Solution(t_sol2, y_sol2, pybamm.BaseModel(), {})
271+
var_2 = pybamm.ProcessedVariableComputed(
272+
[var_sol2],
273+
[var_casadi2],
274+
[y_sol2],
275+
sol2,
276+
)
277+
278+
comb_sol = sol1 + sol2
279+
comb_var = processed_var1._update(var_2, comb_sol)
280+
281+
# Ordering from idaklu with output_variables set is different to
282+
# the full solver
283+
y_sol1 = y_sol1.reshape((y_sol1.shape[1], y_sol1.shape[0])).transpose()
284+
y_sol2 = y_sol2.reshape((y_sol2.shape[1], y_sol2.shape[0])).transpose()
285+
286+
np.testing.assert_array_equal(
287+
comb_var.entries, np.concatenate((y_sol1, y_sol2), axis=1)
288+
)
289+
np.testing.assert_array_equal(comb_var.entries, comb_var.data)
290+
220291
def test_processed_variable_2D_x_r(self):
221292
var = pybamm.Variable(
222293
"var",

0 commit comments

Comments
 (0)