Skip to content

Commit f5a6f1e

Browse files
Rishab87martinjrobinspre-commit-ci[bot]kratman
authored
feat: added an option for multiple initial conditions in IDAKLU solver (#4981)
* initial conditions options and few tests * added to changelog * fixed tests, coverage, changelog and other minor changes * changes in changelog * made changes according to comments * minor changes * fixed tests and tolerances * minor change * style: pre-commit fixes * Update CHANGELOG.md * fixed docs * minor change in test * style: pre-commit fixes * Update src/pybamm/solvers/idaklu_solver.py Co-authored-by: Eric G. Kratz <[email protected]> --------- Co-authored-by: Martin Robinson <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric G. Kratz <[email protected]>
1 parent 43f8f43 commit f5a6f1e

File tree

7 files changed

+390
-9
lines changed

7 files changed

+390
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
- Creates `BaseProcessedVariable` to enable object combination when adding solutions together ([#5076](https://github.com/pybamm-team/PyBaMM/pull/5076))
2727
- Added a `Constant` symbol for named constants. This is a subclass of `Scalar` and is used to represent named constants such as the gas constant. This avoids constants being simplified out when constructing expressions. ([#5070](https://github.com/pybamm-team/PyBaMM/pull/5070))
2828
- Generalise `pybamm.DiscreteTimeSum` to allow it to be embedded in other expressions ([#5044](https://github.com/pybamm-team/PyBaMM/pull/5044))
29+
- Added an option for multiple initial conditions in IDAKLU solver ([#4981](https://github.com/pybamm-team/PyBaMM/pull/4981))
2930
- Adds `all` key-value pair to `output_variables` sensitivity dictionaries, accessible through `solution[var].sensitivities['all']`. Aligns shape with conventional solution sensitivities object. ([#5067](https://github.com/pybamm-team/PyBaMM/pull/5067))
3031
- Added a new `BaseHysteresisOpenCircuitPotential` class that sets variables for the lithiation and delithiation OCP and the hysteresis voltage (`H = U_lith - U_delith`). Allow the initial hysteresis state to be a function of position through the electrode. Allow the hysteresis decay rates of the Axen and Wycisk models to be a function of stoichiometry and temperature. Added a heat source term in each active material phase `Q_hys = i_vol * (U - U_eq)` where `i_vol` is the volumetric interfacial current density, `U` is the OCP (i.e. includes hysteresis), and `U_eq` is the "equilibrium OCP". Renamed the open-circuit potential models to be more descriptive. The options "Axen" and "Wycisk" are now "one-state hysteresis" and "one-state differential capacity hysteresis". The old option names still work but will raise a warning. ([#4893](https://github.com/pybamm-team/PyBaMM/pull/4893))
3132
- Add support for `output_variables` to `pybamm.DiscreteTimeSum` and `pybamm.ExplicitTimeIntegral` expressions. ([#5071](https://github.com/pybamm-team/PyBaMM/pull/5071))

src/pybamm/simulation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def solve(
368368
showprogress=False,
369369
inputs=None,
370370
t_interp=None,
371+
initial_conditions=None,
371372
**kwargs,
372373
):
373374
"""
@@ -521,9 +522,13 @@ def solve(
521522
pybamm.SolverWarning,
522523
stacklevel=2,
523524
)
524-
525525
self._solution = solver.solve(
526-
self._built_model, t_eval, inputs=inputs, t_interp=t_interp, **kwargs
526+
self._built_model,
527+
t_eval,
528+
inputs=inputs,
529+
t_interp=t_interp,
530+
**kwargs,
531+
initial_conditions=initial_conditions,
527532
)
528533

529534
elif self.operating_mode == "with experiment":

src/pybamm/solvers/base_solver.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ def solve(
676676
nproc=None,
677677
calculate_sensitivities=False,
678678
t_interp=None,
679+
initial_conditions=None,
679680
):
680681
"""
681682
Execute the solver setup and calculate the solution of the model at
@@ -706,7 +707,14 @@ def solve(
706707
t_interp : None, list or ndarray, optional
707708
The times (in seconds) at which to interpolate the solution. Defaults to None.
708709
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
710+
initial_conditions : dict, numpy.ndarray, or list, optional
711+
Override the model’s default `y0`. Can be:
709712
713+
- a dict mapping variable names → values
714+
- a 1D array of length `n_states`
715+
- a list of such overrides (one per parallel solve)
716+
717+
Only valid for IDAKLU solver.
710718
Returns
711719
-------
712720
:class:`pybamm.Solution` or list of :class:`pybamm.Solution` objects.
@@ -878,6 +886,7 @@ def solve(
878886
t_eval[start_index:end_index],
879887
model_inputs_list,
880888
t_interp,
889+
initial_conditions,
881890
)
882891
else:
883892
ninputs = len(model_inputs_list)

src/pybamm/solvers/idaklu_solver.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,44 @@ def __setstate__(self, d):
532532
def supports_parallel_solve(self):
533533
return True
534534

535-
def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
535+
def _apply_solver_initial_conditions(self, model, initial_conditions):
536+
"""
537+
Apply custom initial conditions to a model by overriding model.y0.
538+
539+
Parameters
540+
----------
541+
model : pybamm.BaseModel
542+
A model with a precomputed y0 vector.
543+
initial_conditions : dict or numpy.ndarray
544+
Either a mapping from variable names to values (scalar or array),
545+
or a flat numpy array matching the length of model.y0.
546+
"""
547+
if isinstance(initial_conditions, dict):
548+
y0_np = (
549+
model.y0.full() if isinstance(model.y0, casadi.DM) else model.y0.copy()
550+
)
551+
552+
for var_name, value in initial_conditions.items():
553+
found = False
554+
for symbol, slice_info in model.y_slices.items():
555+
if symbol.name == var_name:
556+
var_slice = slice_info[0]
557+
y0_np[var_slice] = value
558+
found = True
559+
break
560+
if not found:
561+
raise ValueError(f"Variable '{var_name}' not found in model")
562+
563+
model.y0 = casadi.DM(y0_np)
564+
565+
elif isinstance(initial_conditions, np.ndarray):
566+
model.y0 = casadi.DM(initial_conditions)
567+
else:
568+
raise TypeError("Initial conditions must be dict or numpy array")
569+
570+
def _integrate(
571+
self, model, t_eval, inputs_list=None, t_interp=None, initial_conditions=None
572+
):
536573
"""
537574
Solve a DAE model defined by residuals with initial conditions y0.
538575
@@ -547,6 +584,13 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
547584
t_interp : None, list or ndarray, optional
548585
The times (in seconds) at which to interpolate the solution. Defaults to `None`,
549586
which returns the adaptive time-stepping times.
587+
initial_conditions : dict, numpy.ndarray, or list, optional
588+
Override the model’s default `y0`. Can be:
589+
590+
- a dict mapping variable names → values
591+
- a 1D array of length `n_states`
592+
- a list of such overrides (one per parallel solve)
593+
550594
"""
551595
if model.convert_to_format != "casadi": # pragma: no cover
552596
# Shouldn't ever reach this point
@@ -565,11 +609,36 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
565609
else:
566610
inputs = np.array([[]] * len(inputs_list))
567611

568-
# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states)
569-
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support
570-
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input
571-
y0full = np.vstack([model.y0full] * len(inputs_list))
572-
ydot0full = np.vstack([model.ydot0full] * len(inputs_list))
612+
if initial_conditions is not None:
613+
if isinstance(initial_conditions, list):
614+
if len(initial_conditions) != len(inputs_list):
615+
raise ValueError(
616+
"Number of initial conditions must match number of input sets"
617+
)
618+
619+
y0_list = []
620+
621+
model_copy = model.new_copy()
622+
for ic in initial_conditions:
623+
self._apply_solver_initial_conditions(model_copy, ic)
624+
y0_list.append(model_copy.y0.full().flatten())
625+
626+
y0full = np.vstack(y0_list)
627+
ydot0full = np.zeros_like(y0full)
628+
629+
else:
630+
self._apply_solver_initial_conditions(model, initial_conditions)
631+
632+
y0_np = model.y0.full()
633+
634+
y0full = np.vstack([y0_np for _ in range(len(inputs_list))])
635+
ydot0full = np.zeros_like(y0full)
636+
else:
637+
# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states)
638+
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support
639+
# different initial conditions for different inputs. For now we just repeat the same initial conditions for each input
640+
y0full = np.vstack([model.y0full] * len(inputs_list))
641+
ydot0full = np.vstack([model.ydot0full] * len(inputs_list))
573642

574643
atol = getattr(model, "atol", self.atol)
575644
atol = self._check_atol_type(atol, y0full.size)

src/pybamm/solvers/jax_solver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ def supports_parallel_solve(self):
200200
def requires_explicit_sensitivities(self):
201201
return False
202202

203-
def _integrate(self, model, t_eval, inputs=None, t_interp=None):
203+
def _integrate(
204+
self, model, t_eval, inputs=None, t_interp=None, intial_conditions=None
205+
):
204206
"""
205207
Solve a model defined by dydt with initial conditions y0.
206208
@@ -220,6 +222,10 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None):
220222
various diagnostic messages.
221223
222224
"""
225+
if intial_conditions is not None: # pragma: no cover
226+
raise NotImplementedError(
227+
"Setting initial conditions is not yet implemented for the JAX IDAKLU solver"
228+
)
223229
if isinstance(inputs, dict):
224230
inputs = [inputs]
225231
timer = pybamm.Timer()

tests/integration/test_solvers/test_idaklu.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import pybamm
45

@@ -210,6 +211,71 @@ def test_with_experiments(self):
210211
sols[1].cycles[-1]["Current [A]"].data,
211212
)
212213

214+
@pytest.mark.parametrize(
215+
"model_cls, make_ics",
216+
[
217+
(pybamm.lithium_ion.SPM, lambda y0: [y0, 2 * y0]),
218+
(
219+
pybamm.lithium_ion.DFN,
220+
lambda y0: [y0, y0 * (1 + 0.01 * np.ones_like(y0))],
221+
),
222+
],
223+
)
224+
def test_multiple_initial_conditions_against_independent_solves(
225+
self, model_cls, make_ics
226+
):
227+
model = model_cls()
228+
geom = model.default_geometry
229+
pv = model.default_parameter_values
230+
pv.process_model(model)
231+
pv.process_geometry(geom)
232+
mesh = pybamm.Mesh(geom, model.default_submesh_types, model.default_var_pts)
233+
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
234+
disc.process_model(model)
235+
236+
t_eval = np.array([0, 1])
237+
solver = pybamm.IDAKLUSolver()
238+
239+
base_sol = solver.solve(model, t_eval)
240+
y0_base = base_sol.y[:, 0]
241+
242+
ics = make_ics(y0_base)
243+
inputs = [{}] * len(ics)
244+
245+
multi_sols = solver.solve(
246+
model,
247+
t_eval,
248+
inputs=inputs,
249+
initial_conditions=ics,
250+
)
251+
assert isinstance(multi_sols, list) and len(multi_sols) == 2
252+
253+
indep_sols = []
254+
for ic in ics:
255+
sol_indep = solver.solve(
256+
model, t_eval, inputs=[{}], initial_conditions=[ic]
257+
)
258+
if isinstance(sol_indep, list):
259+
sol_indep = sol_indep[0]
260+
indep_sols.append(sol_indep)
261+
262+
if model_cls is pybamm.lithium_ion.SPM:
263+
rtol, atol = 1e-8, 1e-10
264+
else:
265+
rtol, atol = 1e-6, 1e-8
266+
267+
for idx in (0, 1):
268+
sol_vec = multi_sols[idx]
269+
sol_ind = indep_sols[idx]
270+
271+
np.testing.assert_allclose(sol_vec.t, sol_ind.t, rtol=1e-12, atol=0)
272+
np.testing.assert_allclose(sol_vec.y, sol_ind.y, rtol=rtol, atol=atol)
273+
274+
if model_cls is pybamm.lithium_ion.SPM:
275+
np.testing.assert_allclose(
276+
sol_vec.y[:, 0], ics[idx], rtol=1e-8, atol=1e-10
277+
)
278+
213279
def test_outvars_with_experiments_multi_simulation(self):
214280
model = pybamm.lithium_ion.SPM()
215281

0 commit comments

Comments
 (0)