diff --git a/.github/workflows/benchmark_on_push.yml b/.github/workflows/benchmark_on_push.yml index 18a9e7931d..f70556c782 100644 --- a/.github/workflows/benchmark_on_push.yml +++ b/.github/workflows/benchmark_on_push.yml @@ -40,7 +40,7 @@ jobs: enable-cache: true - name: Install python dependencies - run: uv pip install asv[virtualenv] + run: uv pip install --system asv - name: Fetch base branch run: | diff --git a/.github/workflows/periodic_benchmarks.yml b/.github/workflows/periodic_benchmarks.yml index 86cd1991ab..0a7a7d9c47 100644 --- a/.github/workflows/periodic_benchmarks.yml +++ b/.github/workflows/periodic_benchmarks.yml @@ -48,7 +48,7 @@ jobs: enable-cache: true - name: Install python dependencies - run: uv pip install asv[virtualenv] + run: uv pip install --system asv - name: Run benchmarks run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index 29af0fbe72..24d2770211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,14 @@ ## Features +- Adds the ability to observe custom 0D variables from a `Solution` object. ([#5308](https://github.com/pybamm-team/PyBaMM/pull/5308)) +- Adds `silence_sundials_errors` IDAKLU solver option with `default=False` to match historical output. ([#5290](https://github.com/pybamm-team/PyBaMM/pull/5290)) + ## Bug fixes +- Fixed a bug where `IDAKLUSolver` errors were not raised correctly. ([#5291](https://github.com/pybamm-team/PyBaMM/pull/5291)) +- Fix a bug with serialising `InputParameter`s. ([#5289](https://github.com/pybamm-team/PyBaMM/pull/5289)) + # [v25.10.1](https://github.com/pybamm-team/PyBaMM/tree/v25.10.1) - 2025-11-14 ## Features @@ -19,6 +25,7 @@ ## Features +- Added uniform grid sizing across subdomains in the x-dimension, ensuring consistent grid spacing when geometries have varying lengths. ([#5253](https://github.com/pybamm-team/PyBaMM/pull/5253)) - Added the `electrode_phases` kwarg to `plot_voltage_components()` which allows choosing between plotting primary or secondary phase overpotentials. ([#5229](https://github.com/pybamm-team/PyBaMM/pull/5229)) - Added the `num_steps_no_progress` and `t_no_progress` options in the `IDAKLUSolver` to early terminate the simulation if little progress is detected. ([#5201](https://github.com/pybamm-team/PyBaMM/pull/5201)) - EvaluateAt symbol: add support for children evaluated at edges ([#5190](https://github.com/pybamm-team/PyBaMM/pull/5190)) diff --git a/docs/source/api/expression_tree/operations/index.rst b/docs/source/api/expression_tree/operations/index.rst index 67beaca136..8e629d258d 100644 --- a/docs/source/api/expression_tree/operations/index.rst +++ b/docs/source/api/expression_tree/operations/index.rst @@ -10,3 +10,4 @@ Classes and functions that operate on the expression tree convert_to_casadi serialise unpack_symbol + replace_symbols diff --git a/docs/source/api/expression_tree/operations/replace_symbols.rst b/docs/source/api/expression_tree/operations/replace_symbols.rst new file mode 100644 index 0000000000..a2b65b1df6 --- /dev/null +++ b/docs/source/api/expression_tree/operations/replace_symbols.rst @@ -0,0 +1,11 @@ +Symbol Replacer +================ + +.. autoclass:: pybamm.SymbolReplacer + :members: + +Variable Replacement Map +========================= + +.. autoclass:: pybamm.VariableReplacementMap + :members: diff --git a/docs/source/examples/notebooks/creating_models/4-comparing-full-and-reduced-order-models.ipynb b/docs/source/examples/notebooks/creating_models/4-comparing-full-and-reduced-order-models.ipynb index 071fd54f95..19780371c0 100644 --- a/docs/source/examples/notebooks/creating_models/4-comparing-full-and-reduced-order-models.ipynb +++ b/docs/source/examples/notebooks/creating_models/4-comparing-full-and-reduced-order-models.ipynb @@ -24,7 +24,7 @@ "$$\n", "\\left.c\\right\\vert_{t=0} = c_0,\n", "$$\n", - "where $c$$ is the concentration, $r$ the radial coordinate, $t$ time, $R$ the particle radius, $D$ the diffusion coefficient, $j$ the interfacial current density, $F$ Faraday's constant, and $c_0$ the initial concentration. \n", + "where $c$ is the concentration, $r$ the radial coordinate, $t$ time, $R$ the particle radius, $D$ the diffusion coefficient, $j$ the interfacial current density, $F$ Faraday's constant, and $c_0$ the initial concentration. \n", "\n", "As in the previous example we use the following parameters:\n", "\n", diff --git a/pyproject.toml b/pyproject.toml index 5f214ef642..364b31a52e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Topic :: Scientific/Engineering", ] dependencies = [ - "pybammsolvers>=0.3.0,<0.4.0", + "pybammsolvers>=0.3.3,<0.4.0", "black", "numpy", "scipy>=1.11.4", diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index b80ae2db5e..07ea4e4154 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -59,6 +59,10 @@ from .expression_tree.operations.convert_to_casadi import CasadiConverter from .expression_tree.operations.unpack_symbols import SymbolUnpacker from .expression_tree.operations.serialise import Serialise,ExpressionFunctionParameter +from .expression_tree.operations.replace_symbols import ( + SymbolReplacer, + VariableReplacementMap, +) # Model classes from .models.base_model import BaseModel diff --git a/src/pybamm/discretisations/discretisation.py b/src/pybamm/discretisations/discretisation.py index f1c0a7459b..40cd45c518 100644 --- a/src/pybamm/discretisations/discretisation.py +++ b/src/pybamm/discretisations/discretisation.py @@ -800,34 +800,32 @@ def process_symbol(self, symbol): Discretised symbol """ - try: - return self._discretised_symbols[symbol] - except KeyError: - discretised_symbol = self._process_symbol(symbol) - self._discretised_symbols[symbol] = discretised_symbol - discretised_symbol.test_shape() - - # Assign mesh as an attribute to the processed variable - if symbol.domain != []: - discretised_symbol.mesh = self.mesh[symbol.domain] - else: - discretised_symbol.mesh = None + _discretised_symbol = self._discretised_symbols.get(symbol) + if _discretised_symbol is not None: + return _discretised_symbol + discretised_symbol = self._process_symbol(symbol) + self._discretised_symbols[symbol] = discretised_symbol + discretised_symbol.test_shape() + + # Assign mesh as an attribute to the processed variable + if symbol.domain != []: + discretised_symbol.mesh = self.mesh[symbol.domain] + else: + discretised_symbol.mesh = None - # Assign secondary mesh - if symbol.domains["secondary"] != []: - discretised_symbol.secondary_mesh = self.mesh[ - symbol.domains["secondary"] - ] - else: - discretised_symbol.secondary_mesh = None + # Assign secondary mesh + if symbol.domains["secondary"] != []: + discretised_symbol.secondary_mesh = self.mesh[symbol.domains["secondary"]] + else: + discretised_symbol.secondary_mesh = None - # Assign tertiary mesh - if symbol.domains["tertiary"] != []: - discretised_symbol.tertiary_mesh = self.mesh[symbol.domains["tertiary"]] - else: - discretised_symbol.tertiary_mesh = None + # Assign tertiary mesh + if symbol.domains["tertiary"] != []: + discretised_symbol.tertiary_mesh = self.mesh[symbol.domains["tertiary"]] + else: + discretised_symbol.tertiary_mesh = None - return discretised_symbol + return discretised_symbol def _process_symbol(self, symbol): """See :meth:`Discretisation.process_symbol()`.""" @@ -841,15 +839,13 @@ def _process_symbol(self, symbol): self.bcs[key_id] = self.check_tab_conditions( symbol, self.bcs[key_id] ) + else: + spatial_method = None if isinstance(symbol, pybamm.BinaryOperator): # Pre-process children left, right = symbol.children # Catch case where diffusion is a scalar and turn it into an identity matrix vector field. - if len(symbol.domain) != 0: - spatial_method = self.spatial_methods[symbol.domain[0]] - else: - spatial_method = None if isinstance(spatial_method, pybamm.FiniteVolume2D): if isinstance(left, pybamm.Scalar) and ( isinstance(right, pybamm.VectorField) diff --git a/src/pybamm/expression_tree/operations/convert_to_casadi.py b/src/pybamm/expression_tree/operations/convert_to_casadi.py index 2192bf7b9a..52d87c7004 100644 --- a/src/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/src/pybamm/expression_tree/operations/convert_to_casadi.py @@ -46,15 +46,15 @@ def convert( :class:`casadi.MX` The converted symbol """ - try: - return self._casadi_symbols[symbol] - except KeyError: - # Change inputs to empty dictionary if it's None - inputs = inputs or {} - casadi_symbol = self._convert(symbol, t, y, y_dot, inputs) - self._casadi_symbols[symbol] = casadi_symbol + _casadi_symbol = self._casadi_symbols.get(symbol) + if _casadi_symbol is not None: + return _casadi_symbol + # Change inputs to empty dictionary if it's None + inputs = inputs or {} + casadi_symbol = self._convert(symbol, t, y, y_dot, inputs) + self._casadi_symbols[symbol] = casadi_symbol - return casadi_symbol + return casadi_symbol def _convert(self, symbol, t, y, y_dot, inputs): """See :meth:`CasadiConverter.convert()`.""" diff --git a/src/pybamm/expression_tree/operations/replace_symbols.py b/src/pybamm/expression_tree/operations/replace_symbols.py new file mode 100644 index 0000000000..bbda550f5b --- /dev/null +++ b/src/pybamm/expression_tree/operations/replace_symbols.py @@ -0,0 +1,219 @@ +import pybamm + + +class SymbolReplacer: + """ + Helper class to replace all instances of one or more symbols in an expression tree + with another symbol, as defined by the dictionary `symbol_replacement_map` + + Parameters + ---------- + symbol_replacement_map : dict {:class:`pybamm.Symbol` -> :class:`pybamm.Symbol`} + Map of which symbols should be replaced by which. + processed_symbols: dict {:class:`pybamm.Symbol` -> :class:`pybamm.Symbol`}, optional + cached replaced symbols + process_initial_conditions: bool, optional + Whether to process initial conditions, default is True + """ + + def __init__( + self, + symbol_replacement_map, + processed_symbols=None, + process_initial_conditions=True, + ): + self._symbol_replacement_map = symbol_replacement_map + self._processed_symbols = processed_symbols or {} + self.process_initial_conditions = process_initial_conditions + + def process_model(self, unprocessed_model, inplace=True): + """Replace all instances of a symbol in a model. + + Parameters + ---------- + unprocessed_model : :class:`pybamm.BaseModel` + Model to assign parameter values for + inplace: bool, optional + If True, replace the parameters in the model in place. Otherwise, return a + new model with parameter values set. Default is True. + """ + pybamm.logger.info(f"Start replacing symbols in {unprocessed_model.name}") + + # set up inplace vs not inplace + if inplace: + # any changes to unprocessed_model attributes will change model attributes + # since they point to the same object + model = unprocessed_model + else: + # create a copy of the model + model = unprocessed_model.new_copy() + + new_rhs = {} + for variable, equation in unprocessed_model.rhs.items(): + pybamm.logger.verbose(f"Replacing symbols in {variable!r} (rhs)") + new_rhs[self.process_symbol(variable)] = self.process_symbol(equation) + model.rhs = new_rhs + + new_algebraic = {} + for variable, equation in unprocessed_model.algebraic.items(): + pybamm.logger.verbose(f"Replacing symbols in {variable!r} (algebraic)") + new_algebraic[self.process_symbol(variable)] = self.process_symbol(equation) + model.algebraic = new_algebraic + + new_initial_conditions = {} + for variable, equation in unprocessed_model.initial_conditions.items(): + pybamm.logger.verbose( + f"Replacing symbols in {variable!r} (initial conditions)" + ) + if self.process_initial_conditions: + new_initial_conditions[self.process_symbol(variable)] = ( + self.process_symbol(equation) + ) + else: + new_initial_conditions[self.process_symbol(variable)] = equation + model.initial_conditions = new_initial_conditions + + model.boundary_conditions = self.process_boundary_conditions(unprocessed_model) + + new_variables = {} + for variable, equation in unprocessed_model.variables.items(): + pybamm.logger.verbose(f"Replacing symbols in {variable!r} (variables)") + new_variables[variable] = self.process_symbol(equation) + model.variables = new_variables + + new_events = [] + for event in unprocessed_model.events: + pybamm.logger.verbose(f"Replacing symbols in event'{event.name}''") + new_events.append( + pybamm.Event( + event.name, self.process_symbol(event.expression), event.event_type + ) + ) + model.events = new_events + + pybamm.logger.info(f"Finish replacing symbols in {model.name}") + + return model + + def process_boundary_conditions(self, model): + """ + Process boundary conditions for a model + Boundary conditions are dictionaries {"left": left bc, "right": right bc} + in general, but may be imposed on the tabs (or *not* on the tab) for a + small number of variables, e.g. {"negative tab": neg. tab bc, + "positive tab": pos. tab bc "no tab": no tab bc}. + """ + new_boundary_conditions = {} + sides = ["left", "right", "negative tab", "positive tab", "no tab"] + for variable, bcs in model.boundary_conditions.items(): + processed_variable = self.process_symbol(variable) + new_boundary_conditions[processed_variable] = {} + for side in sides: + try: + bc, typ = bcs[side] + pybamm.logger.verbose( + f"Replacing symbols in {variable!r} ({side} bc)" + ) + processed_bc = (self.process_symbol(bc), typ) + new_boundary_conditions[processed_variable][side] = processed_bc + except KeyError as err: + # don't raise error if the key error comes from the side not being + # found + if err.args[0] in side: + pass + # do raise error otherwise (e.g. can't process symbol) + else: # pragma: no cover + raise KeyError(err) from err + + return new_boundary_conditions + + def process_symbol(self, symbol): + """ + This function recurses down the tree, replacing any symbols in + self._symbol_replacement_map with their corresponding value + + Parameters + ---------- + symbol : :class:`pybamm.Symbol` + The symbol to replace + + Returns + ------- + :class:`pybamm.Symbol` + Symbol with all replacements performed + """ + + _processed_symbol = self._processed_symbols.get(symbol) + if _processed_symbol is not None: + return _processed_symbol + + replaced_symbol = self._process_symbol(symbol) + self._processed_symbols[symbol] = replaced_symbol + return replaced_symbol + + def _process_symbol(self, symbol): + """See :meth:`Simplification.process_symbol()`.""" + _processed_symbol = self._processed_symbols.get(symbol) + if _processed_symbol is not None: + return _processed_symbol + + if symbol in self._symbol_replacement_map: + return self._symbol_replacement_map[symbol] + + if isinstance(symbol, pybamm.BinaryOperator): + left, right = symbol.children + # process children + new_left = self.process_symbol(left) + new_right = self.process_symbol(right) + # Return a new copy with the replaced symbols + return symbol._binary_new_copy(new_left, new_right) + + elif isinstance(symbol, pybamm.UnaryOperator): + new_child = self.process_symbol(symbol.child) + # Return a new copy with the replaced symbols + return symbol._unary_new_copy(new_child) + + elif isinstance(symbol, pybamm.Function): + new_children = [self.process_symbol(child) for child in symbol.children] + # Return a new copy with the replaced symbols + return symbol._function_new_copy(new_children) + + elif isinstance(symbol, pybamm.Concatenation): + new_children = [self.process_symbol(child) for child in symbol.children] + # Return a new copy with the replaced symbols + return symbol._concatenation_new_copy(new_children) + + else: + # Only other option is that the symbol is a leaf (doesn't have children) + # In this case, since we have already ruled out that the symbol is one of + # the symbols that needs to be replaced, we can just return the symbol + return symbol + + +class VariableReplacementMap: + """ + A simple dict-like object that efficiently resolves pybamm symbols by name. + """ + + __slots__ = ["_symbol_replacement_map"] + + def __init__(self, symbol_replacement_map: dict[str, pybamm.Symbol]): + self._symbol_replacement_map = symbol_replacement_map + + def __getitem__(self, symbol): + return self._symbol_replacement_map[symbol.name] + + def __contains__(self, symbol): + return self.get(symbol) is not None + + def get(self, symbol, default=None): + if not isinstance(symbol, pybamm.Variable): + return default + + name = symbol.name + value = self._symbol_replacement_map.get(name) + + # Check exact variable match + if value is not None and pybamm.Variable(name) == symbol: + return value + return default diff --git a/src/pybamm/expression_tree/operations/serialise.py b/src/pybamm/expression_tree/operations/serialise.py index 6b733e0fad..98e44e2410 100644 --- a/src/pybamm/expression_tree/operations/serialise.py +++ b/src/pybamm/expression_tree/operations/serialise.py @@ -435,6 +435,13 @@ def serialise_custom_model(model: pybamm.BaseModel) -> dict: str(variable_name): convert_symbol_to_json(expression) for variable_name, expression in getattr(model, "variables", {}).items() }, + "parameter_values": ( + Serialise._serialise_parameter_values( + getattr(model, "_parameter_values", None) + ) + if getattr(model, "_parameter_values", None) is not None + else None + ), } SCHEMA_VERSION = "1.0" @@ -1318,6 +1325,21 @@ def load_custom_model(filename: str | dict) -> pybamm.BaseModel: f"Failed to convert variable '{variable_name}': {e!s}" ) from e + # Load parameter_values if present + # Note: fixed_input_parameters is now computed from parameter_values, so no need to load separately + if ( + "parameter_values" in model_data + and model_data["parameter_values"] is not None + ): + try: + model._parameter_values = Serialise._deserialise_parameter_values( + model_data["parameter_values"] + ) + except Exception as e: + raise ValueError(f"Failed to convert parameter_values: {e!s}") from e + else: + model._parameter_values = None + return model @staticmethod @@ -1549,6 +1571,73 @@ def _convert_options(self, d): else: return d + @staticmethod + def _serialise_parameter_values( + parameter_values: pybamm.ParameterValues | None, + ) -> dict | None: + """ + Serializes a ParameterValues object to a JSON-serializable dictionary. + + Parameters + ---------- + parameter_values : :class:`pybamm.ParameterValues` or None + The parameter values to serialize. + + Returns + ------- + dict or None + A JSON-serializable dictionary representation of the parameter values, or None if input is None. + """ + if parameter_values is None: + return None + + parameter_values_dict = {} + for k, v in parameter_values.items(): + if callable(v): + parameter_values_dict[k] = convert_symbol_to_json( + convert_function_to_symbolic_expression(v, k) + ) + else: + parameter_values_dict[k] = convert_symbol_to_json(v) + + return parameter_values_dict + + @staticmethod + def _deserialise_parameter_values( + parameter_values_dict: dict, + ) -> pybamm.ParameterValues: + """ + Deserializes a dictionary back into a ParameterValues object. + + Parameters + ---------- + parameter_values_dict : dict + Dictionary containing the serialized parameter values. + + Returns + ------- + :class:`pybamm.ParameterValues` + The reconstructed ParameterValues object. + """ + deserialized = {} + for key, val in parameter_values_dict.items(): + if isinstance(val, dict) and "type" in val: + deserialized[key] = convert_symbol_from_json(val) + elif isinstance(val, list): + deserialized[key] = val + elif isinstance(val, (numbers.Number | bool)): + deserialized[key] = val + elif isinstance(val, str): + deserialized[key] = val + elif isinstance(val, dict): + deserialized[key] = val + else: + raise ValueError( + f"Unsupported parameter format for key '{key}': {val!r}" + ) + + return pybamm.ParameterValues(deserialized) + def convert_function_to_symbolic_expression(func, name=None): """ @@ -1625,6 +1714,8 @@ def convert_symbol_from_json(json_data): elif json_data["type"] == "Parameter": # Convert stored parameters back to PyBaMM Parameter objects return pybamm.Parameter(json_data["name"]) + elif json_data["type"] == "InputParameter": + return pybamm.InputParameter(json_data["name"]) elif json_data["type"] == "Scalar": # Convert stored numerical values back to PyBaMM Scalar objects return pybamm.Scalar(json_data["value"]) diff --git a/src/pybamm/expression_tree/symbol.py b/src/pybamm/expression_tree/symbol.py index d4f666eded..4be40cf222 100644 --- a/src/pybamm/expression_tree/symbol.py +++ b/src/pybamm/expression_tree/symbol.py @@ -247,13 +247,10 @@ def __init__( # Test shape on everything but nodes that contain the base Symbol class or # the base BinaryOperator class - if pybamm.settings.debug_mode is True: - if not any( - issubclass(pybamm.Symbol, type(x)) - or issubclass(pybamm.BinaryOperator, type(x)) - for x in self.pre_order() - ): - self.test_shape() + if pybamm.settings.debug_mode is True and not any( + isinstance(x, (Symbol | pybamm.BinaryOperator)) for x in self.pre_order() + ): + self.test_shape() @classmethod def _from_json(cls, snippet: dict): diff --git a/src/pybamm/meshes/meshes.py b/src/pybamm/meshes/meshes.py index f717b0a347..ddc295d57b 100644 --- a/src/pybamm/meshes/meshes.py +++ b/src/pybamm/meshes/meshes.py @@ -9,6 +9,39 @@ import pybamm +def compute_var_pts_from_thicknesses(electrode_thicknesses, grid_size): + """ + Compute a ``var_pts`` dictionary using electrode thicknesses and a target cell size (dx). + + Added as per maintainer feedback in issue # to make mesh generation + explicit — ``grid_size`` now represents the mesh cell size in metres. + + Parameters + ---------- + electrode_thicknesses : dict + Domain thicknesses in metres. + grid_size : float + Desired uniform mesh cell size (m). + + Returns + ------- + dict + Mapping of each domain to its computed grid points. + """ + if not isinstance(electrode_thicknesses, dict): + raise TypeError("electrode_thicknesses must be a dictionary") + + if not isinstance(grid_size, (int | float)) or grid_size <= 0: + raise ValueError("grid_size must be a positive number") + + var_pts = {} + for domain, thickness in electrode_thicknesses.items(): + npts = max(round(thickness / grid_size), 2) + var_pts[domain] = {f"x_{domain[0]}": npts} + + return var_pts + + class Mesh(dict): """ Mesh contains a list of submeshes on each subdomain. diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index 9fd7483385..04f4c05f4c 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -73,6 +73,7 @@ def __init__(self, name="Unnamed model"): self._is_standard_form_dae = None self._variables_casadi = {} self._geometry = pybamm.Geometry({}) + self._parameter_values = None # Default behaviour is to use the jacobian self.use_jacobian = True @@ -430,6 +431,26 @@ def default_parameter_values(self): """Returns the default parameter values for the model (an empty set of parameters by default).""" return pybamm.ParameterValues({}) + @property + def parameter_values(self): + """Returns the parameter values for the model.""" + return self._parameter_values + + @parameter_values.setter + def parameter_values(self, parameter_values): + self._parameter_values = parameter_values + + @property + def fixed_input_parameters(self): + """Returns a dictionary of all fixed input parameters from parameter_values.""" + if self._parameter_values is None: + return {} + return { + k: v + for k, v in self._parameter_values.items() + if isinstance(v, pybamm.InputParameter) + } + @property def parameters(self): """Returns a list of all parameter symbols used in the model.""" @@ -736,6 +757,7 @@ def _find_symbols(self, typ): for side in x.keys() ] + list(self.variables.values()) + + list(self.fixed_input_parameters.values()) + [event.expression for event in self.events] ) return list(all_input_parameters) @@ -753,6 +775,7 @@ def _find_symbols_by_submodel(self, typ, submodel): for side in x.keys() ] + list(self._variables_by_submodel[submodel].values()) + + list(self.submodels[submodel].fixed_input_parameters.values()) + [event.expression for event in self.submodels[submodel].events] ) return list(all_input_parameters) @@ -770,6 +793,11 @@ def new_copy(self): new_model._variables = self.variables.copy() new_model._events = self.events.copy() new_model._variables_casadi = self._variables_casadi.copy() + new_model._parameter_values = ( + self._parameter_values.copy() + if self._parameter_values is not None + else None + ) return new_model def update(self, *submodels): @@ -906,7 +934,12 @@ def _build_model(self): self.build_model_equations() def set_initial_conditions_from( - self, solution, inplace=True, return_type="model", mesh=None + self, + solution, + inputs=None, + inplace=True, + return_type="model", + mesh=None, ): """ Update initial conditions with the final states from a Solution object or from @@ -918,6 +951,8 @@ def set_initial_conditions_from( ---------- solution : :class:`pybamm.Solution`, or dict The solution to use to initialize the model + inputs : dict + The dictionary of model input parameters. inplace : bool, optional Whether to modify the model inplace or create a new model (default True) return_type : str, optional @@ -1081,7 +1116,7 @@ def get_variable_state(var): scale, reference = pybamm.Scalar(1), pybamm.Scalar(0) initial_conditions[var] = ( pybamm.Vector(final_state_eval) - reference - ) / scale.evaluate() + ) / scale.evaluate(inputs=inputs) # Also update the concatenated initial conditions if the model is already # discretised diff --git a/src/pybamm/models/full_battery_models/lithium_ion/electrode_soh.py b/src/pybamm/models/full_battery_models/lithium_ion/electrode_soh.py index a9322991b2..b9abb4261f 100644 --- a/src/pybamm/models/full_battery_models/lithium_ion/electrode_soh.py +++ b/src/pybamm/models/full_battery_models/lithium_ion/electrode_soh.py @@ -561,14 +561,14 @@ def _set_up_solve(self, inputs, direction): def _solve_full(self, inputs, ics, direction): sim = self._get_electrode_soh_sims_full(direction) sim.build() - sim.built_model.set_initial_conditions_from(ics) + sim.built_model.set_initial_conditions_from(ics, inputs=inputs) sol = sim.solve([0], inputs=inputs) return sol def _solve_split(self, inputs, ics, direction): x100_sim, x0_sim = self._get_electrode_soh_sims_split(direction) x100_sim.build() - x100_sim.built_model.set_initial_conditions_from(ics) + x100_sim.built_model.set_initial_conditions_from(ics, inputs=inputs) x100_sol = x100_sim.solve([0], inputs=inputs) if self.options["open-circuit potential"] == "MSMR": inputs["Un(x_100)"] = x100_sol["Un(x_100)"].data[0] @@ -577,7 +577,7 @@ def _solve_split(self, inputs, ics, direction): inputs["x_100"] = x100_sol["x_100"].data[0] inputs["y_100"] = x100_sol["y_100"].data[0] x0_sim.build() - x0_sim.built_model.set_initial_conditions_from(ics) + x0_sim.built_model.set_initial_conditions_from(ics, inputs=inputs) x0_sol = x0_sim.solve([0], inputs=inputs) return x0_sol diff --git a/src/pybamm/models/submodels/interface/kinetics/butler_volmer.py b/src/pybamm/models/submodels/interface/kinetics/butler_volmer.py index c3e9ac0fc0..84cee586a3 100644 --- a/src/pybamm/models/submodels/interface/kinetics/butler_volmer.py +++ b/src/pybamm/models/submodels/interface/kinetics/butler_volmer.py @@ -12,7 +12,7 @@ class SymmetricButlerVolmer(BaseKinetics): Submodel which implements the symmetric forward Butler-Volmer equation: .. math:: - j = 2 * j_0(c) * \\sinh(ne * F * \\eta_r(c) / RT) + j = 2 * j_0(c) * \\sinh(ne * F * \\eta_r(c) / 2RT) Parameters ---------- diff --git a/src/pybamm/models/submodels/interface/open_circuit_potential/base_hysteresis_ocp.py b/src/pybamm/models/submodels/interface/open_circuit_potential/base_hysteresis_ocp.py index 934ed5e6cf..cdfc47ac5e 100644 --- a/src/pybamm/models/submodels/interface/open_circuit_potential/base_hysteresis_ocp.py +++ b/src/pybamm/models/submodels/interface/open_circuit_potential/base_hysteresis_ocp.py @@ -84,7 +84,7 @@ def _get_coupled_variables(self, variables): U_eq = self.phase_param.U(sto_surf, T) U_eq_x_av = self.phase_param.U(sto_surf, T) U_lith = self.phase_param.U(sto_surf, T, "lithiation") - U_lith_bulk = self.phase_param.U(sto_bulk, T_bulk) + U_lith_bulk = self.phase_param.U(sto_bulk, T_bulk, "lithiation") U_delith = self.phase_param.U(sto_surf, T, "delithiation") U_delith_bulk = self.phase_param.U(sto_bulk, T_bulk, "delithiation") diff --git a/src/pybamm/parameters/parameter_values.py b/src/pybamm/parameters/parameter_values.py index 2bec326a93..6c6822c26d 100644 --- a/src/pybamm/parameters/parameter_values.py +++ b/src/pybamm/parameters/parameter_values.py @@ -496,6 +496,8 @@ def process_model(self, unprocessed_model, inplace=True): ): raise pybamm.ModelError("Cannot process parameters for empty model") + model.parameter_values = self + new_rhs = {} for variable, equation in unprocessed_model.rhs.items(): pybamm.logger.verbose(f"Processing parameters for {variable!r} (rhs)") @@ -866,12 +868,21 @@ def _process_function_parameter(self, symbol): else: new_children.append(self.process_symbol(child)) - # Get the expression and inputs for the function + # Get the expression and inputs for the function. + # func_args may include arguments that were not explicitly wired up + # in this FunctionParameter (e.g., kwargs with default values). After + # serialisation/deserialisation, we only recover the children that were + # actually connected. + # + # Using strict=True here therefore raises a ValueError when there are + # more args than children. We allow func_args to be longer than + # symbol.children and only build the mapping for the args for which we + # actually have children. expression = function_parameter.child inputs = { arg: child for arg, child in zip( - function_parameter.func_args, symbol.children, strict=True + function_parameter.func_args, symbol.children, strict=False ) } diff --git a/src/pybamm/simulation.py b/src/pybamm/simulation.py index 8f3fcaaad1..08b17905fb 100644 --- a/src/pybamm/simulation.py +++ b/src/pybamm/simulation.py @@ -128,6 +128,7 @@ def __init__( self._model_with_set_params = None self._built_model = None self._built_initial_soc = None + self._built_nominal_capacity = None self.steps_to_built_models = None self.steps_to_built_solvers = None self._mesh = None @@ -163,6 +164,42 @@ def set_up_and_parameterise_experiment(self, solve_kwargs=None): warnings.warn(msg, DeprecationWarning, stacklevel=2) self._set_up_and_parameterise_experiment(solve_kwargs=solve_kwargs) + def _update_experiment_models_for_capacity(self, solve_kwargs=None): + """ + Check if the nominal capacity has changed and update the experiment models + if needed. This re-processes the models without rebuilding the mesh and + discretisation. + """ + current_capacity = self._parameter_values.get( + "Nominal cell capacity [A.h]", None + ) + + if self._built_nominal_capacity == current_capacity: + return + + # Capacity has changed, need to re-process the models + pybamm.logger.info( + f"Nominal capacity changed from {self._built_nominal_capacity} to " + f"{current_capacity}. Re-processing experiment models." + ) + + # Re-parameterise the experiment with the new capacity + self._set_up_and_parameterise_experiment(solve_kwargs) + + # Re-discretise the models + self.steps_to_built_models = {} + self.steps_to_built_solvers = {} + for ( + step, + model_with_set_params, + ) in self.experiment_unique_steps_to_model.items(): + built_model = self._disc.process_model(model_with_set_params, inplace=True) + solver = self._solver.copy() + self.steps_to_built_solvers[step] = solver + self.steps_to_built_models[step] = built_model + + self._built_nominal_capacity = current_capacity + def _set_up_and_parameterise_experiment(self, solve_kwargs=None): """ Create and parameterise the models for each step in the experiment. @@ -266,6 +303,7 @@ def set_initial_state(self, initial_soc, direction=None, inputs=None): # reset self._model_with_set_params = None self._built_model = None + self._built_nominal_capacity = None self.steps_to_built_models = None self.steps_to_built_solvers = None @@ -279,6 +317,7 @@ def set_initial_state(self, initial_soc, direction=None, inputs=None): options=options, inputs=inputs, ) + self._model.parameter_values = self._parameter_values # Save solved initial SOC in case we need to re-build the model self._built_initial_soc = initial_soc @@ -312,7 +351,7 @@ def build(self, initial_soc=None, direction=None, inputs=None): if self._built_model: return - elif self._model.is_discretised: + if self._model.is_discretised: self._model_with_set_params = self._model self._built_model = self._model else: @@ -338,6 +377,8 @@ def build_for_experiment( self.set_initial_state(initial_soc, direction=direction, inputs=inputs) if self.steps_to_built_models: + # Check if we need to update the models due to capacity change + self._update_experiment_models_for_capacity(solve_kwargs) return else: self._set_up_and_parameterise_experiment(solve_kwargs) @@ -366,6 +407,10 @@ def build_for_experiment( self.steps_to_built_solvers[step] = solver self.steps_to_built_models[step] = built_model + self._built_nominal_capacity = self._parameter_values.get( + "Nominal cell capacity [A.h]", None + ) + def solve( self, t_eval=None, @@ -778,7 +823,7 @@ def solve( feasible = False # If none of the cycles worked, raise an error if cycle_num == 1 and step_num == 1: - raise error + raise error from error # Otherwise, just stop this cycle break diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 92c7381f43..7d4709e06b 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -1340,7 +1340,7 @@ def step( model_inputs = self._set_up_model_inputs(model, inputs) # process calculate_sensitivities argument - calculate_sensitivities_list, sensitivities_have_changed = ( + _, sensitivities_have_changed = ( BaseSolver._solve_process_calculate_sensitivities_arg( model_inputs, model, calculate_sensitivities ) @@ -1381,7 +1381,9 @@ def step( else: _, concatenated_initial_conditions = model.set_initial_conditions_from( - old_solution, return_type="ics" + old_solution, + inputs=model_inputs, + return_type="ics", ) model.y0 = concatenated_initial_conditions.evaluate(0, inputs=model_inputs) if using_sensitivities: diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 4a0d9edb07..147bc0314e 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -76,6 +76,8 @@ class IDAKLUSolver(pybamm.BaseSolver): "increment_factor": 1.0, # Enable or disable linear solution scaling "linear_solution_scaling": True, + # Silence Sundials errors during solve + "silence_sundials_errors": False, ## Main solver # Maximum order of the linear multistep method "max_order_bdf": 5, @@ -176,6 +178,7 @@ def __init__( "epsilon_linear_tolerance": 0.05, "increment_factor": 1.0, "linear_solution_scaling": True, + "silence_sundials_errors": False, "max_order_bdf": 5, "max_num_steps": 100000, "dt_init": 0.0, @@ -681,13 +684,17 @@ def _integrate( atol = self._check_atol_type(atol, y0full.size) timer = pybamm.Timer() - solns = self._setup["solver"].solve( - t_eval, - t_interp, - y0full, - ydot0full, - inputs, - ) + try: + solns = self._setup["solver"].solve( + t_eval, + t_interp, + y0full, + ydot0full, + inputs, + ) + except ValueError as e: + # Return from None to replace the C++ runtime error + raise pybamm.SolverError(str(e)) from None integration_time = timer.time() return [ @@ -734,14 +741,15 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict): termination = "final time" elif sol.flag < 0: termination = "failure" + msg = idaklu.sundials_error_message(sol.flag) match self._on_failure: case "warn": warnings.warn( - f"FAILURE {self._solver_flag(sol.flag)}, returning a partial solution.", + msg + ", returning a partial solution.", stacklevel=2, ) case "raise": - raise pybamm.SolverError(f"FAILURE {self._solver_flag(sol.flag)}") + raise pybamm.SolverError(msg) if sol.yp.size > 0: yp = sol.yp.reshape((number_of_timesteps, number_of_states)).T @@ -1002,40 +1010,3 @@ def jaxify( t_interp=t_interp, ) return obj - - @staticmethod - def _solver_flag(flag): - flags = { - 99: "IDA_WARNING: IDASolve succeeded but an unusual situation occurred.", - 2: "IDA_ROOT_RETURN: IDASolve succeeded and found one or more roots.", - 1: "IDA_TSTOP_RETURN: IDASolve succeeded by reaching the specified stopping point.", - 0: "IDA_SUCCESS: Successful function return.", - -1: "IDA_TOO_MUCH_WORK: The solver took mxstep internal steps but could not reach tout.", - -2: "IDA_TOO_MUCH_ACC: The solver could not satisfy the accuracy demanded by the user for some internal step.", - -3: "IDA_ERR_FAIL: Error test failures occurred too many times during one internal time step or minimum step size was reached.", - -4: "IDA_CONV_FAIL: Convergence test failures occurred too many times during one internal time step or minimum step size was reached.", - -5: "IDA_LINIT_FAIL: The linear solver's initialization function failed.", - -6: "IDA_LSETUP_FAIL: The linear solver's setup function failed in an unrecoverable manner.", - -7: "IDA_LSOLVE_FAIL: The linear solver's solve function failed in an unrecoverable manner.", - -8: "IDA_RES_FAIL: The user-provided residual function failed in an unrecoverable manner.", - -9: "IDA_REP_RES_FAIL: The user-provided residual function repeatedly returned a recoverable error flag, but the solver was unable to recover.", - -10: "IDA_RTFUNC_FAIL: The rootfinding function failed in an unrecoverable manner.", - -11: "IDA_CONSTR_FAIL: The inequality constraints were violated and the solver was unable to recover.", - -12: "IDA_FIRST_RES_FAIL: The user-provided residual function failed recoverably on the first call.", - -13: "IDA_LINESEARCH_FAIL: The line search failed.", - -14: "IDA_NO_RECOVERY: The residual function, linear solver setup function, or linear solver solve function had a recoverable failure, but IDACalcIC could not recover.", - -15: "IDA_NLS_INIT_FAIL: The nonlinear solver's init routine failed.", - -16: "IDA_NLS_SETUP_FAIL: The nonlinear solver's setup routine failed.", - -20: "IDA_MEM_NULL: The ida mem argument was NULL.", - -21: "IDA_MEM_FAIL: A memory allocation failed.", - -22: "IDA_ILL_INPUT: One of the function inputs is illegal.", - -23: "IDA_NO_MALLOC: The ida memory was not allocated by a call to IDAInit.", - -24: "IDA_BAD_EWT: Zero value of some error weight component.", - -25: "IDA_BAD_K: The k-th derivative is not available.", - -26: "IDA_BAD_T: The time t is outside the last step taken.", - -27: "IDA_BAD_DKY: The vector argument where derivative should be stored is NULL.", - } - - flag_unknown = "Unknown IDA flag." - - return flags.get(flag, flag_unknown) diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index cc5aac4b9a..f729be6ddc 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -1,3 +1,5 @@ +from __future__ import annotations + # # Solution class # @@ -12,6 +14,10 @@ from scipy.io import savemat import pybamm +from pybamm.expression_tree.operations.replace_symbols import ( + SymbolReplacer, + VariableReplacementMap, +) class NumpyEncoder(json.JSONEncoder): @@ -137,8 +143,8 @@ def __init__( self.solve_time = None self.integration_time = None - # initialize empty variables and data - self._variables = pybamm.FuzzyDict() + # initialize empty variable cache and data + self._variables = {} self._data = pybamm.FuzzyDict() # Add self as sub-solution for compatibility with ProcessedVariable @@ -226,29 +232,31 @@ def check_ys_are_not_too_large(self): # restraint, so if y gets large in the middle then comes back down that is ok y, model = self.all_ys[-1], self.all_models[-1] y = y[:, -1] - if np.any(y > pybamm.settings.max_y_value): - for var in [*model.rhs.keys(), *model.algebraic.keys()]: - var = model.variables[var.name] - # find the statevector corresponding to this variable - statevector = None - for node in var.pre_order(): - if isinstance(node, pybamm.StateVector): - statevector = node - - # there will always be a statevector, but just in case - if statevector is None: # pragma: no cover - raise RuntimeError( - f"Cannot find statevector corresponding to variable {var.name}" - ) - y_var = y[statevector.y_slices[0]] - if np.any(y_var > pybamm.settings.max_y_value): - pybamm.logger.error( - f"Solution for '{var}' exceeds the maximum allowed value " - f"of `{pybamm.settings.max_y_value}. This could be due to " - "incorrect scaling, model formulation, or " - "parameter values. The maximum allowed value is set by " - "'pybammm.settings.max_y_value'." - ) + if np.max(y, initial=0) <= pybamm.settings.max_y_value: + return + + for var in [*model.rhs.keys(), *model.algebraic.keys()]: + var = model.variables[var.name] + # find the statevector corresponding to this variable + statevector = None + for node in var.pre_order(): + if isinstance(node, pybamm.StateVector): + statevector = node + + # there will always be a statevector, but just in case + if statevector is None: # pragma: no cover + raise RuntimeError( + f"Cannot find statevector corresponding to variable {var.name}" + ) + y_var = y[statevector.y_slices[0]] + if np.any(y_var > pybamm.settings.max_y_value): + pybamm.logger.error( + f"Solution for '{var}' exceeds the maximum allowed value " + f"of `{pybamm.settings.max_y_value}. This could be due to " + "incorrect scaling, model formulation, or " + "parameter values. The maximum allowed value is set by " + "'pybammm.settings.max_y_value'." + ) @property def all_ts(self): @@ -411,7 +419,7 @@ def update_summary_variables(self, all_summary_variables): self, cycle_summary_variables=all_summary_variables ) - def update(self, variables): + def update(self, variables: str | list[str]): """Add ProcessedVariables to the dictionary of variables in the solution""" # Single variable if isinstance(variables, str): @@ -421,59 +429,187 @@ def update(self, variables): for variable in variables: self._update_variable(variable) - def _update_variable(self, variable): + def _update_model_variable( + self, + model: pybamm.BaseModel, + var_pybamm: pybamm.Symbol, + time_integral: pybamm.ProcessedVariableTimeIntegral | None, + inputs: dict, + ys_shape: tuple, + cache_key, + ): + _var_casadi = model._variables_casadi.get(cache_key) + if _var_casadi is not None: + return _var_casadi, var_pybamm, time_integral + + var_casadi, var_pybamm, time_integral = self._convert_to_casadi( + var_pybamm, inputs, ys_shape + ) + + # Only cache if it's not a time integral + if time_integral is None: + model._variables_casadi[cache_key] = var_casadi + return var_casadi, var_pybamm, time_integral + + def _update_variable(self, name: str): time_integral = None - pybamm.logger.debug(f"Post-processing {variable}") - vars_pybamm = [ - model.variables_and_events[variable] for model in self.all_models - ] + pybamm.logger.debug(f"Post-processing {name}") # Iterate through all models, some may be in the list several times and # therefore only get set up once - vars_casadi = [] - for i, (model, ys, inputs, var_pybamm) in enumerate( - zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm, strict=True) + vars_pybamm = [model.variables_and_events[name] for model in self.all_models] + vars_casadi = [None] * len(self.all_models) + for i, (model, ys, inputs) in enumerate( + zip(self.all_models, self.all_ys, self.all_inputs, strict=True) ): - if self.variables_returned and var_pybamm.has_symbol_of_classes( + _var_pybamm = vars_pybamm[i] + if self.variables_returned and _var_pybamm.has_symbol_of_classes( pybamm.expression_tree.state_vector.StateVector ): raise KeyError( - f"Cannot process variable '{variable}' as it was not part of the " + f"Cannot process variable '{name}' as it was not part of the " "solve. Please re-run the solve with `output_variables` set to " "include this variable." ) - elif variable in model._variables_casadi: - var_casadi = model._variables_casadi[variable] - else: - time_integral = pybamm.ProcessedVariableTimeIntegral.from_pybamm_var( - var_pybamm, self.all_ys[i].shape[0] + var_casadi, var_pybamm, time_integral = self._update_model_variable( + model, + _var_pybamm, + inputs=inputs, + ys_shape=ys.shape, + time_integral=time_integral, + cache_key=name, + ) + vars_pybamm[i] = var_pybamm + vars_casadi[i] = var_casadi + var = pybamm.process_variable( + name, vars_pybamm, vars_casadi, self, time_integral=time_integral + ) + + self._variables[name] = var + + def _update_observe_variable( + self, + symbol: pybamm.Symbol, + name: str, + replace_variables: bool | None = None, + ): + if self.variables_returned: + raise ValueError( + "Cannot use `observe` if the solver includes `output_variables`. " + "Please re-run the simulation without `output_variables`." + ) + if replace_variables is None: + replace_variables = True + symbol_id = symbol.id + + # Use model hashmap to avoid redundant symbol processing + vars_pybamm_map: dict[pybamm.BaseModel, pybamm.Symbol] = {} + + def _process_symbol( + symbol: pybamm.Symbol, model: pybamm.BaseModel + ) -> pybamm.Symbol: + _var_pybamm = vars_pybamm_map.get(model) + if _var_pybamm is not None: + return _var_pybamm + + if replace_variables: + symbol_replacer = SymbolReplacer( + VariableReplacementMap(model.variables_and_events) ) - if time_integral is not None: - vars_pybamm[i] = time_integral.sum_node.child - var_casadi = self.process_casadi_var( - time_integral.sum_node.child, - inputs, - ys.shape, - ) - if time_integral.post_sum_node is not None: - time_integral.post_sum = self.process_casadi_var( - time_integral.post_sum_node, - inputs, - ys.shape, - ) - else: - var_casadi = self.process_casadi_var( - var_pybamm, - inputs, - ys.shape, - ) - model._variables_casadi[variable] = var_casadi - vars_casadi.append(var_casadi) + symbol = symbol_replacer.process_symbol(symbol) + + var_pybamm = model._parameter_values.process_symbol(symbol) + dims = np.prod(var_pybamm.shape) + if dims > 1: + raise ValueError("`observe` currently only supports 0D variables.") + + # Cache the processed symbol for this model + vars_pybamm_map[model] = var_pybamm + return var_pybamm + + time_integral = None + # Iterate through all models, some may be in the list several times and + # therefore only get set up once + vars_casadi = [None] * len(self.all_models) + vars_pybamm = [None] * len(self.all_models) + for i, (model, ys, inputs) in enumerate( + zip(self.all_models, self.all_ys, self.all_inputs, strict=True) + ): + _var_pybamm = _process_symbol(symbol, model) + var_casadi, var_pybamm, time_integral = self._update_model_variable( + model, + _var_pybamm, + inputs=inputs, + ys_shape=ys.shape, + time_integral=time_integral, + cache_key=symbol_id, + ) + vars_pybamm[i] = var_pybamm + vars_casadi[i] = var_casadi var = pybamm.process_variable( - variable, vars_pybamm, vars_casadi, self, time_integral=time_integral + name, vars_pybamm, vars_casadi, self, time_integral=time_integral ) + self._variables[symbol_id] = var + return var - self._variables[variable] = var + def observe( + self, + symbol: pybamm.Symbol, + name: str | None = None, + replace_variables: bool | None = None, + ): + """ + Observe a `pybamm.Symbol` object from the solution. + Note: this currently only supports 0D variables. + + Parameters + ---------- + symbol : pybamm.Symbol + The symbol to observe. + name : str, optional + The name of the variable. If None, the name is the symbol's id. + replace_variables : bool, optional + Whether to replace ``pybamm.Variable`` objects in the symbol with the + discretized variables. Defaults to True. + + Returns + ------- + :class:`pybamm.ProcessedVariable` + The observed variable. + """ + if not isinstance(symbol, pybamm.Symbol): + try: + # Try to convert the input to a pybamm.Symbol + symbol = symbol * pybamm.Scalar(1) + except Exception: + raise ValueError("Input is not a valid PyBaMM symbol") from None + + value = self._variables.get(symbol.id) + if value is None: + name = name if name is not None else str(symbol.id) + value = self._update_observe_variable( + symbol, name, replace_variables=replace_variables + ) + return value + + def _convert_to_casadi(self, var_pybamm, inputs, ys_shape): + time_integral = pybamm.ProcessedVariableTimeIntegral.from_pybamm_var( + var_pybamm, ys_shape[0] + ) + if time_integral is not None: + var_pybamm = time_integral.sum_node.child + if time_integral.post_sum_node is not None: + time_integral.post_sum = self.process_casadi_var( + time_integral.post_sum_node, + inputs, + ys_shape, + ) + var_casadi = self.process_casadi_var( + var_pybamm, + inputs, + ys_shape, + ) + return var_casadi, var_pybamm, time_integral def process_casadi_var(self, var_pybamm, inputs, ys_shape): t_MX = casadi.MX.sym("t") @@ -533,14 +669,13 @@ def __getitem__(self, key): A variable that can be evaluated at any time or spatial point. The underlying data for this variable is available in its attribute ".data" """ + value = self._variables.get(key) + if value is not None: + return value - # return it if it exists - if key in self._variables: - return self._variables[key] - else: - # otherwise create it, save it and then return it - self.update(key) - return self._variables[key] + # otherwise create it, save it and then return it + self.update(key) + return self._variables[key] def plot(self, output_variables=None, **kwargs): """ diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index 2b62b83db3..3c5099dde7 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -1,3 +1,4 @@ +import logging import os from datetime import datetime @@ -1022,3 +1023,125 @@ def neg_stoich_cutoff(variables): neg_stoich = sol["Negative electrode stoichiometry"].data assert neg_stoich[-1] == pytest.approx(0.5, abs=0.0001) + + def test_simulation_changing_capacity_crate_steps(self): + """Test that C-rate steps are correctly updated when capacity changes""" + model = pybamm.lithium_ion.SPM() + experiment = pybamm.Experiment( + [ + ( + "Discharge at C/5 for 20 minutes", + "Discharge at C/2 for 20 minutes", + "Discharge at 1C for 20 minutes", + ) + ] + ) + param = pybamm.ParameterValues("Chen2020") + sim = pybamm.Simulation(model, experiment=experiment, parameter_values=param) + + # First solve + sol1 = sim.solve(calc_esoh=False) + original_capacity = param["Nominal cell capacity [A.h]"] + + # Check that C-rates correspond to expected currents + I_C5_1 = np.abs(sol1.cycles[0].steps[0]["Current [A]"].data).mean() + I_C2_1 = np.abs(sol1.cycles[0].steps[1]["Current [A]"].data).mean() + I_1C_1 = np.abs(sol1.cycles[0].steps[2]["Current [A]"].data).mean() + + np.testing.assert_allclose(I_C5_1, original_capacity / 5, rtol=1e-2) + np.testing.assert_allclose(I_C2_1, original_capacity / 2, rtol=1e-2) + np.testing.assert_allclose(I_1C_1, original_capacity, rtol=1e-2) + + # Update capacity + new_capacity = 0.9 * original_capacity + sim._parameter_values.update({"Nominal cell capacity [A.h]": new_capacity}) + + # Second solve with updated capacity + sol2 = sim.solve(calc_esoh=False) + + # Check that C-rates now correspond to updated currents + I_C5_2 = np.abs(sol2.cycles[0].steps[0]["Current [A]"].data).mean() + I_C2_2 = np.abs(sol2.cycles[0].steps[1]["Current [A]"].data).mean() + I_1C_2 = np.abs(sol2.cycles[0].steps[2]["Current [A]"].data).mean() + + np.testing.assert_allclose(I_C5_2, new_capacity / 5, rtol=1e-2) + np.testing.assert_allclose(I_C2_2, new_capacity / 2, rtol=1e-2) + np.testing.assert_allclose(I_1C_2, new_capacity, rtol=1e-2) + + # Verify all currents scaled proportionally + np.testing.assert_allclose(I_C5_2 / I_C5_1, 0.9, rtol=1e-2) + np.testing.assert_allclose(I_C2_2 / I_C2_1, 0.9, rtol=1e-2) + np.testing.assert_allclose(I_1C_2 / I_1C_1, 0.9, rtol=1e-2) + + def test_simulation_multiple_cycles_with_capacity_change(self): + """Test capacity changes across multiple experiment cycles""" + model = pybamm.lithium_ion.SPM() + experiment = pybamm.Experiment( + [("Discharge at 1C for 5 minutes", "Charge at 1C for 5 minutes")] * 2 + ) + param = pybamm.ParameterValues("Chen2020") + sim = pybamm.Simulation(model, experiment=experiment, parameter_values=param) + + # First solve + sol1 = sim.solve(calc_esoh=False) + original_capacity = param["Nominal cell capacity [A.h]"] + + # Get discharge currents for both cycles + I_discharge_cycle1 = np.abs(sol1.cycles[0].steps[0]["Current [A]"].data).mean() + I_discharge_cycle2 = np.abs(sol1.cycles[1].steps[0]["Current [A]"].data).mean() + + # Both cycles should use the same capacity initially + np.testing.assert_allclose(I_discharge_cycle1, original_capacity, rtol=1e-2) + np.testing.assert_allclose(I_discharge_cycle2, original_capacity, rtol=1e-2) + + # Update capacity between cycles + new_capacity = 0.85 * original_capacity + sim._parameter_values.update({"Nominal cell capacity [A.h]": new_capacity}) + + # Solve again + sol2 = sim.solve(calc_esoh=False) + + # All cycles in the new solution should use updated capacity + I_discharge_cycle1_new = np.abs( + sol2.cycles[0].steps[0]["Current [A]"].data + ).mean() + I_discharge_cycle2_new = np.abs( + sol2.cycles[1].steps[0]["Current [A]"].data + ).mean() + + np.testing.assert_allclose(I_discharge_cycle1_new, new_capacity, rtol=1e-2) + np.testing.assert_allclose(I_discharge_cycle2_new, new_capacity, rtol=1e-2) + + def test_simulation_logging_with_capacity_change(self, caplog): + """Test that capacity changes are logged appropriately""" + model = pybamm.lithium_ion.SPM() + experiment = pybamm.Experiment([("Discharge at 1C for 10 minutes",)]) + param = pybamm.ParameterValues("Chen2020") + sim = pybamm.Simulation(model, experiment=experiment, parameter_values=param) + + # First solve + sim.solve(calc_esoh=False) + original_capacity = param["Nominal cell capacity [A.h]"] + + # Update capacity + new_capacity = 0.75 * original_capacity + sim._parameter_values.update({"Nominal cell capacity [A.h]": new_capacity}) + + # Set logging level to capture INFO messages + original_log_level = pybamm.logger.level + pybamm.set_logging_level("INFO") + + try: + # Second solve should log capacity change + with caplog.at_level(logging.INFO, logger="pybamm.logger"): + sim.solve(calc_esoh=False) + + # Check that a log message about capacity change was recorded + log_messages = [record.message for record in caplog.records] + capacity_change_logged = any( + "Nominal capacity changed" in msg for msg in log_messages + ) + assert capacity_change_logged + finally: + # Restore original logging level + pybamm.logger.setLevel(original_log_level) diff --git a/tests/unit/test_expression_tree/test_replace_symbols.py b/tests/unit/test_expression_tree/test_replace_symbols.py new file mode 100644 index 0000000000..edf64e8721 --- /dev/null +++ b/tests/unit/test_expression_tree/test_replace_symbols.py @@ -0,0 +1,129 @@ +import pybamm +from pybamm.expression_tree.operations.replace_symbols import ( + SymbolReplacer, + VariableReplacementMap, +) + + +def test_symbol_replacements(): + a = pybamm.Parameter("a") + b = pybamm.Parameter("b") + c = pybamm.Parameter("c") + d = pybamm.Parameter("d") + replacer = SymbolReplacer({a: b, c: d}) + + for symbol_in, symbol_out in [ + (a, b), # just the symbol + (a + a, b + b), # binary operator + (2 * pybamm.sin(a), 2 * pybamm.sin(b)), # function + (3 * b, 3 * b), # no replacement + (a + c, b + d), # two replacements + ]: + replaced_symbol = replacer.process_symbol(symbol_in) + assert replaced_symbol == symbol_out + + var1 = pybamm.Variable("var 1", domain="dom 1") + var2 = pybamm.Variable("var 2", domain="dom 2") + var3 = pybamm.Variable("var 3", domain="dom 1") + conc = pybamm.concatenation(var1, var2) + + replacer = SymbolReplacer({var1: var3}) + replaced_symbol = replacer.process_symbol(conc) + assert replaced_symbol == pybamm.concatenation(var3, var2) + + +def test_process_model(): + model = pybamm.BaseModel() + a = pybamm.Parameter("a") + b = pybamm.Parameter("b") + c = pybamm.Parameter("c") + d = pybamm.Parameter("d") + var1 = pybamm.Variable("var1", domain="test") + var2 = pybamm.Variable("var2", domain="test") + model.rhs = {var1: a * pybamm.grad(var1)} + model.algebraic = {var2: c * var2} + model.initial_conditions = {var1: b, var2: d} + model.boundary_conditions = { + var1: {"left": (c, "Dirichlet"), "right": (d, "Neumann")} + } + model.variables = { + "var1": var1, + "var2": var2, + "grad_var1": pybamm.grad(var1), + "d_var1": d * var1, + } + + replacer = SymbolReplacer( + { + pybamm.Parameter("a"): pybamm.Scalar(4), + pybamm.Parameter("b"): pybamm.Scalar(2), + pybamm.Parameter("c"): pybamm.Scalar(3), + pybamm.Parameter("d"): pybamm.Scalar(42), + } + ) + replacer.process_model(model) + # rhs + var1 = model.variables["var1"] + assert isinstance(model.rhs[var1], pybamm.Multiplication) + assert isinstance(model.rhs[var1].children[0], pybamm.Scalar) + assert isinstance(model.rhs[var1].children[1], pybamm.Gradient) + assert model.rhs[var1].children[0].value == 4 + # algebraic + var2 = model.variables["var2"] + assert isinstance(model.algebraic[var2], pybamm.Multiplication) + assert isinstance(model.algebraic[var2].children[0], pybamm.Scalar) + assert isinstance(model.algebraic[var2].children[1], pybamm.Variable) + assert model.algebraic[var2].children[0].value == 3 + # initial conditions + assert isinstance(model.initial_conditions[var1], pybamm.Scalar) + assert model.initial_conditions[var1].value == 2 + # boundary conditions + bc_key = next(iter(model.boundary_conditions.keys())) + assert isinstance(bc_key, pybamm.Variable) + bc_value = next(iter(model.boundary_conditions.values())) + assert isinstance(bc_value["left"][0], pybamm.Scalar) + assert bc_value["left"][0].value == 3 + assert isinstance(bc_value["right"][0], pybamm.Scalar) + assert bc_value["right"][0].value == 42 + # variables + assert model.variables["var1"] == var1 + assert isinstance(model.variables["grad_var1"], pybamm.Gradient) + assert isinstance(model.variables["grad_var1"].children[0], pybamm.Variable) + assert model.variables["d_var1"] == (pybamm.Scalar(42) * var1) + assert isinstance(model.variables["d_var1"].children[0], pybamm.Scalar) + assert isinstance(model.variables["d_var1"].children[1], pybamm.Variable) + + +def test_variable_replacement_map(): + var1 = pybamm.Variable("Voltage [V]") + var2 = pybamm.Variable("Current [A]") + replacement1 = pybamm.Scalar(3.7) + replacement2 = pybamm.Parameter("I") + + replacement_map = VariableReplacementMap( + { + "Voltage [V]": replacement1, + "Current [A]": replacement2, + } + ) + + # Test __getitem__ + assert replacement_map[var1] is replacement1 + assert replacement_map[var2] is replacement2 + + # Test __contains__ + assert var1 in replacement_map + assert var2 in replacement_map + assert pybamm.Variable("Other [V]") not in replacement_map + + # Test get method + assert replacement_map.get(var1) == replacement1 + assert replacement_map.get(var2) == replacement2 + assert replacement_map.get(pybamm.Variable("Other [V]")) is None + assert replacement_map.get( + pybamm.Variable("Other [V]"), default=pybamm.Scalar(0) + ) == pybamm.Scalar(0) + + # Test that non-Variable symbols return default + assert replacement_map.get(pybamm.Parameter("a")) is None + assert replacement_map.get(pybamm.Scalar(1)) is None diff --git a/tests/unit/test_meshes/test_meshes.py b/tests/unit/test_meshes/test_meshes.py index 8c26a0900f..15ef8317ba 100644 --- a/tests/unit/test_meshes/test_meshes.py +++ b/tests/unit/test_meshes/test_meshes.py @@ -584,6 +584,37 @@ def test_to_json(self): assert mesh_json == expected_json + def test_compute_var_pts_from_thicknesses_cell_size(self): + from pybamm.meshes.meshes import compute_var_pts_from_thicknesses + + electrode_thicknesses = { + "negative electrode": 100e-6, + "separator": 25e-6, + "positive electrode": 100e-6, + } + + cell_size = 5e-6 # 5 micrometres per cell + var_pts = compute_var_pts_from_thicknesses(electrode_thicknesses, cell_size) + + assert isinstance(var_pts, dict) + assert all(isinstance(v, dict) for v in var_pts.values()) + assert var_pts["negative electrode"]["x_n"] == 20 + assert var_pts["separator"]["x_s"] == 5 + assert var_pts["positive electrode"]["x_p"] == 20 + + def test_compute_var_pts_from_thicknesses_invalid_thickness_type(self): + from pybamm.meshes.meshes import compute_var_pts_from_thicknesses + + with pytest.raises(TypeError): + compute_var_pts_from_thicknesses(["not", "a", "dict"], 1e-6) + + def test_compute_var_pts_from_thicknesses_invalid_grid_size(self): + from pybamm.meshes.meshes import compute_var_pts_from_thicknesses + + electrode_thicknesses = {"negative electrode": 100e-6} + with pytest.raises(ValueError): + compute_var_pts_from_thicknesses(electrode_thicknesses, -1e-6) + class TestMeshGenerator: def test_init_name(self): diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 769705cf51..83947b8298 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -1018,6 +1018,25 @@ def test_process_model(self): assert isinstance(model.variables["d_var1"].children[0], pybamm.Scalar) assert isinstance(model.variables["d_var1"].children[1], pybamm.Variable) + # Check fixed_input_parameters - should be empty when no InputParameters + assert hasattr(model, "fixed_input_parameters") + assert model.fixed_input_parameters == {} + + # Test with InputParameters + model2 = pybamm.BaseModel() + input_param1 = pybamm.InputParameter("input1") + input_param2 = pybamm.InputParameter("input2") + model2.rhs = {var1: a * var1} + parameter_values2 = pybamm.ParameterValues( + {"a": 1, "input1": input_param1, "input2": input_param2} + ) + parameter_values2.process_model(model2) + assert hasattr(model2, "fixed_input_parameters") + assert "input1" in model2.fixed_input_parameters + assert "input2" in model2.fixed_input_parameters + assert model2.fixed_input_parameters["input1"] == input_param1 + assert model2.fixed_input_parameters["input2"] == input_param2 + # bad boundary conditions model = pybamm.BaseModel() model.algebraic = {var1: var1} @@ -1298,6 +1317,31 @@ def test_to_json_with_filename(self): finally: os.remove(temp_path) + def test_roundtrip_with_keyword_args(self): + def func_no_kwargs(x): + return 2 * x + + def func_with_kwargs(x, y=1): + return 2 * x + + x = pybamm.Scalar(2) + func_param = pybamm.FunctionParameter("func", {"x": x}) + + parameter_values = pybamm.ParameterValues({"func": func_no_kwargs}) + assert parameter_values.evaluate(func_param) == 4.0 + + serialized = parameter_values.to_json() + parameter_values_loaded = pybamm.ParameterValues.from_json(serialized) + assert parameter_values_loaded.evaluate(func_param) == 4.0 + + parameter_values = pybamm.ParameterValues({"func": func_with_kwargs}) + assert parameter_values.evaluate(func_param) == 4.0 + + serialized = parameter_values.to_json() + parameter_values_loaded = pybamm.ParameterValues.from_json(serialized) + + assert parameter_values_loaded.evaluate(func_param) == 4.0 + def test_convert_symbols_in_dict_with_interpolator(self): """Test convert_symbols_in_dict with interpolator (covers lines 1154-1170).""" import numpy as np diff --git a/tests/unit/test_serialisation/test_serialisation.py b/tests/unit/test_serialisation/test_serialisation.py index d6ddac10e4..d9023a8e54 100644 --- a/tests/unit/test_serialisation/test_serialisation.py +++ b/tests/unit/test_serialisation/test_serialisation.py @@ -617,6 +617,14 @@ def test_serialise_time(self): t2 = convert_symbol_from_json(j) assert isinstance(t2, pybamm.Time) + def test_serialise_input_parameter(self): + """Test InputParameter serialization and deserialization.""" + ip = pybamm.InputParameter("test_param") + j = convert_symbol_to_json(ip) + ip_restored = convert_symbol_from_json(j) + assert isinstance(ip_restored, pybamm.InputParameter) + assert ip_restored.name == "test_param" + def test_convert_symbol_to_json_with_number_and_list(self): for val in (0, 3.14, -7, True): out = convert_symbol_to_json(val) @@ -1337,6 +1345,7 @@ def test_variable_conversion_failure(self, tmp_path): "boundary_conditions": [], "events": [], "variables": {"Bad Variable": {"bad": "structure"}}, + "fixed_input_parameters": {}, }, } @@ -2292,6 +2301,118 @@ def test_load_custom_model_from_dict(self): assert loaded_model.name == "test_dict_model" assert isinstance(loaded_model.rhs, dict) + def test_parameter_values_serialisation(self, tmp_path): + """Test serialization and deserialization of parameter_values.""" + # Create a model with parameter_values + model = pybamm.BaseModel(name="test_param_values_model") + a = pybamm.Variable("a", domain="electrode") + model.rhs = {a: pybamm.Scalar(1) * a} + model.initial_conditions = {a: pybamm.Scalar(1)} + model.algebraic = {} + model.boundary_conditions = {a: {"left": (pybamm.Scalar(0), "Dirichlet")}} + model.events = [] + model.variables = {"a": a} + + # Set parameter_values using ParameterValues.process_model + param_values = pybamm.ParameterValues( + { + "param1": 10.0, + "param2": 20.0, + "input_param": pybamm.InputParameter("input1"), + } + ) + param_values.process_model(model) + + # Verify parameter_values are set + assert hasattr(model, "_parameter_values") + assert model._parameter_values is not None + assert "param1" in model._parameter_values + assert model._parameter_values["param1"] == 10.0 + + # Serialize the model + file_path = tmp_path / "test_param_values.json" + Serialise.save_custom_model(model, filename=str(file_path)) + assert file_path.exists() + + # Load the model back + loaded_model = Serialise.load_custom_model(str(file_path)) + + # Verify parameter_values are correctly loaded + assert hasattr(loaded_model, "_parameter_values") + assert loaded_model._parameter_values is not None + assert "param1" in loaded_model._parameter_values + assert loaded_model._parameter_values["param1"] == 10.0 + assert "param2" in loaded_model._parameter_values + assert loaded_model._parameter_values["param2"] == 20.0 + assert "input_param" in loaded_model._parameter_values + assert isinstance( + loaded_model._parameter_values["input_param"], pybamm.InputParameter + ) + + def test_parameter_values_none(self, tmp_path): + """Test serialization when parameter_values is None.""" + # Create a model without parameter_values + model = pybamm.BaseModel(name="test_no_param_values") + a = pybamm.Variable("a", domain="electrode") + model.rhs = {a: pybamm.Scalar(1) * a} + model.initial_conditions = {a: pybamm.Scalar(1)} + model.algebraic = {} + model.boundary_conditions = {a: {"left": (pybamm.Scalar(0), "Dirichlet")}} + model.events = [] + model.variables = {"a": a} + model._parameter_values = None + + # Serialize and load + file_path = tmp_path / "test_no_param_values.json" + Serialise.save_custom_model(model, filename=str(file_path)) + loaded_model = Serialise.load_custom_model(str(file_path)) + + # Verify parameter_values is None + assert hasattr(loaded_model, "_parameter_values") + assert loaded_model._parameter_values is None + + def test_parameter_values_serialise_from_dict(self): + """Test serialization of parameter_values when serializing to dict.""" + # Create a model with parameter_values + model = pybamm.BaseModel(name="test_param_values_dict") + a = pybamm.Variable("a", domain="electrode") + model.rhs = {a: pybamm.Scalar(1) * a} + model.initial_conditions = {a: pybamm.Scalar(1)} + model.algebraic = {} + model.boundary_conditions = {a: {"left": (pybamm.Scalar(0), "Dirichlet")}} + model.events = [] + model.variables = {"a": a} + + param_values = pybamm.ParameterValues( + { + "test_param": 42.0, + "test_input": pybamm.InputParameter("input1"), + } + ) + param_values.process_model(model) + + # Serialize to dict + model_json = Serialise.serialise_custom_model(model) + + # Verify the JSON structure + assert "parameter_values" in model_json["model"] + assert model_json["model"]["parameter_values"] is not None + assert "test_param" in model_json["model"]["parameter_values"] + assert model_json["model"]["parameter_values"]["test_param"] == 42.0 + + # Load from dict + loaded_model = Serialise.load_custom_model(model_json) + + # Verify it loaded correctly + assert hasattr(loaded_model, "_parameter_values") + assert loaded_model._parameter_values is not None + assert "test_param" in loaded_model._parameter_values + assert loaded_model._parameter_values["test_param"] == 42.0 + assert "test_input" in loaded_model._parameter_values + assert isinstance( + loaded_model._parameter_values["test_input"], pybamm.InputParameter + ) + def test_expression_function_parameter_evaluate(self): """Test _unary_evaluate method of ExpressionFunctionParameter.""" x = pybamm.Variable("x") diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index a86e18b21d..5b55e5e3d8 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -583,7 +583,7 @@ def test_failures(self): solver = pybamm.IDAKLUSolver() t_eval = [0, 3] - with pytest.raises(ValueError): + with pytest.raises(pybamm.SolverError): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): @@ -779,7 +779,7 @@ def test_solver_options(self): options = {option: options_fail[option]} solver = pybamm.IDAKLUSolver(options=options) - with pytest.raises(ValueError): + with pytest.raises(pybamm.SolverError): solver.solve(model, t_eval) def test_with_output_variables(self): @@ -1487,7 +1487,7 @@ def test_on_failure_option(self): model, t_eval=t_eval, t_interp=t_interp, inputs=input_parameters ) assert len(w) > 0 - assert "FAILURE" in str(w[0].message) + assert "_FAIL" in str(w[0].message) def test_no_progress_early_termination(self): # SPM at rest diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 3977c1724a..46c58b2018 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -731,3 +731,94 @@ def test_explicit_time_integral(self, solver_class, use_post_sum, use_output_var atol=1e-2, ) assert isinstance(sol["integral"].sensitivities["a"], np.ndarray) + + def test_observe(self): + """Test the observe method with pybamm symbols, comparing with model variables.""" + # Set up a simple model + model = pybamm.lithium_ion.SPM() + parameter_values = pybamm.ParameterValues("Chen2020") + + # Solve the model + sim = pybamm.Simulation(model, parameter_values=parameter_values) + sol = sim.solve([0, 3600]) + + # Test observing "Voltage [V]" symbol - should match exactly with model variable + voltage_symbol = pybamm.Variable("Voltage [V]") + observed_voltage = sol.observe(voltage_symbol) + + # Compare with the actual variable from solution + actual_voltage = sol["Voltage [V]"] + + # They should match exactly + np.testing.assert_array_equal(observed_voltage.data, actual_voltage.data) + np.testing.assert_array_equal(observed_voltage.entries, actual_voltage.entries) + + # Test with "Current [A]" - another model variable + current_symbol = pybamm.Variable("Current [A]") + observed_current = sol.observe(current_symbol) + actual_current = sol["Current [A]"] + np.testing.assert_array_equal(observed_current.data, actual_current.data) + np.testing.assert_array_equal(observed_current.entries, actual_current.entries) + + # Test that observe returns a ProcessedVariable + assert isinstance(observed_voltage, pybamm.ProcessedVariable) + assert isinstance(observed_current, pybamm.ProcessedVariable) + + # Test that we can call observe multiple times and get the same result + observed_voltage2 = sol.observe(voltage_symbol) + np.testing.assert_array_equal(observed_voltage2.data, observed_voltage.data) + + # Test that the cache works - verify it's the same object (not just equal) + observed_voltage3 = sol.observe(voltage_symbol) + assert observed_voltage3 is observed_voltage # Should be the same cached object + + # Test that observing with a different name still uses cache if symbol.id is the same + observed_voltage4 = sol.observe(voltage_symbol, name="DifferentName") + assert ( + observed_voltage4 is observed_voltage + ) # Should use cache based on symbol.id + + def test_observe_failure(self): + """Test that observe raises an error if the solver includes `output_variables`.""" + model = pybamm.lithium_ion.SPM() + parameter_values = pybamm.ParameterValues("Chen2020") + sim = pybamm.Simulation(model, parameter_values=parameter_values) + sol = sim.solve([0, 3600]) + + c = pybamm.Variable("Positive particle concentration [mol.m-3]") + with pytest.raises( + ValueError, match="`observe` currently only supports 0D variables" + ): + sol.observe(c) + + # test that `output_variables` are unsupported + solver = pybamm.IDAKLUSolver(output_variables=["Voltage [V]"]) + sim = pybamm.Simulation(model, parameter_values=parameter_values, solver=solver) + sol = sim.solve([0, 3600]) + + with pytest.raises( + ValueError, + match="Cannot use `observe` if the solver includes `output_variables`. Please re-run the simulation without `output_variables`.", + ): + sol.observe(pybamm.Variable("Voltage [V]")) + + with pytest.raises(ValueError, match="Input is not a valid PyBaMM symbol"): + sol.observe(None) + + def test_observe_with_numeric_inputs(self): + """Test that observe works with numeric inputs like 0, which get converted to symbols.""" + # Set up a simple model + model = pybamm.lithium_ion.SPM() + sim = pybamm.Simulation(model) + + sol = sim.solve([0, 1]) + + # Test observing a scalar value (0) - should convert to pybamm.Scalar(0) + observed_zero = sol.observe(0) + assert isinstance(observed_zero, pybamm.ProcessedVariable) + # Should be a constant array of zeros + np.testing.assert_array_equal(observed_zero.data, np.zeros(len(sol.t))) + + # Test that numeric inputs are cached correctly + observed_zero2 = sol.observe(0) + assert observed_zero2 is observed_zero # Should be cached