Skip to content

Commit 67d3e78

Browse files
rtimmsclaudepre-commit-ci[bot]
authored
Preserve custom variables and events in built-in model to_config (#5411)
* Preserve custom variables and events in built-in model to_config When users modify a built-in model (add/remove variables, change events), to_config() now detects these changes and includes them as optional override keys (extra_variables, removed_variables, events) in the compact config format. from_config() applies them after constructing the fresh model. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * style: pre-commit fixes * update changelog --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 07556c9 commit 67d3e78

File tree

4 files changed

+265
-2
lines changed

4 files changed

+265
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Features
44

55
- Improved the performance of processed variables by replacing `casadi.vertcat` input stacking with numpy vectors. ([#5413](https://github.com/pybamm-team/PyBaMM/pull/5413))
6+
- Preserve custom variables and events in built-in model to_config ([#5411](https://github.com/pybamm-team/PyBaMM/pull/5411))
67
- Allow out of bounds initial state of charge to enable initialising a simulation at a voltage outside the voltage limits. ([#5386](https://github.com/pybamm-team/PyBaMM/pull/5386))
78
- Added `cache_esoh` option to `Simulation` that caches the electrode SOH computation across repeated `solve` calls, avoiding redundant recalculation when eSOH-relevant parameters have not changed. The cached eSOH solver/simulation object is also reused on cache misses to skip expensive model rebuilding. ([#5408](https://github.com/pybamm-team/PyBaMM/pull/5408))
89
- Eliminated the mass matrix inverse and temporary dense matrix objects when building the simulation. ([#5391](https://github.com/pybamm-team/PyBaMM/pull/5391))

docs/source/api/models/submodels/electrolyte_conductivity/full_conductivity.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@ Full Model
33

44
.. autoclass:: pybamm.electrolyte_conductivity.Full
55
:members:
6-
:inherited-members:

src/pybamm/models/base_model.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pybamm
1717
from pybamm.expression_tree.operations.serialise import Serialise
18+
from pybamm.models.event import EventType
1819
from pybamm.models.symbol_processor import SymbolProcessor
1920

2021

@@ -2041,6 +2042,82 @@ def _find_builtin_module(cls):
20412042
return mod_name
20422043
return None
20432044

2045+
def serialise_builtin_overrides(self, model_config: dict) -> None:
2046+
"""Detect and serialise user modifications to a built-in model.
2047+
2048+
Compares the current model's ``variables`` keys and ``events``
2049+
against a freshly-constructed reference model (same class and
2050+
options). Any differences are written into *model_config* using
2051+
the following optional keys:
2052+
2053+
``custom_variables``
2054+
``dict[str, json]`` – variables the user *added* (keys not
2055+
present in the reference model). Each value is the symbolic
2056+
expression serialised via
2057+
:func:`convert_symbol_to_json`.
2058+
``removed_variables``
2059+
``list[str]`` – variable names that exist in the reference
2060+
model but have been deleted by the user.
2061+
``events``
2062+
``list[dict]`` – full serialised events list, included
2063+
**only** when the events differ from the reference model
2064+
(by count or by name set).
2065+
2066+
If the model is unmodified, no extra keys are added and the
2067+
config stays identical to the previous compact format.
2068+
2069+
.. note::
2070+
2071+
Only added/removed variable *keys* are tracked. Overwriting
2072+
an existing variable's expression with a new one is **not**
2073+
detected in this version.
2074+
2075+
Parameters
2076+
----------
2077+
model_config : dict
2078+
The compact config dict (mutated in-place).
2079+
"""
2080+
from pybamm.expression_tree.operations.serialise import (
2081+
convert_symbol_to_json,
2082+
)
2083+
2084+
# Build a pristine reference with the same options.
2085+
model_cls = type(self)
2086+
options = model_config.get("options", {})
2087+
ref = model_cls(options=options) if options else model_cls()
2088+
2089+
# --- Variables diff (added / removed keys only) ---
2090+
current_keys = set(self.variables.keys())
2091+
ref_keys = set(ref.variables.keys())
2092+
2093+
added = current_keys - ref_keys
2094+
removed = ref_keys - current_keys
2095+
2096+
if added:
2097+
model_config["custom_variables"] = {
2098+
name: convert_symbol_to_json(self.variables[name])
2099+
for name in sorted(added)
2100+
}
2101+
if removed:
2102+
model_config["removed_variables"] = sorted(removed)
2103+
2104+
# --- Events diff (full replacement when different) ---
2105+
current_event_names = {e.name for e in self.events}
2106+
ref_event_names = {e.name for e in ref.events}
2107+
2108+
if (
2109+
len(self.events) != len(ref.events)
2110+
or current_event_names != ref_event_names
2111+
):
2112+
model_config["events"] = [
2113+
{
2114+
"name": event.name,
2115+
"expression": convert_symbol_to_json(event.expression),
2116+
"event_type": event.event_type.value,
2117+
}
2118+
for event in self.events
2119+
]
2120+
20442121
def to_config(
20452122
self,
20462123
filename: str | Path | None = None,
@@ -2082,6 +2159,13 @@ def to_config(
20822159
}
20832160
if hasattr(self, "options"):
20842161
model_config["options"] = dict(self.options)
2162+
2163+
# Detect user modifications to variables and events.
2164+
# A fresh reference model is instantiated with the same options
2165+
# so we can diff against it. Only added/removed variable *keys*
2166+
# are tracked; overwriting an existing variable's expression is
2167+
# NOT detected (out of scope for v1).
2168+
self.serialise_builtin_overrides(model_config)
20852169
else:
20862170
# Custom / user-defined model — full serialised format
20872171
model_config = {
@@ -2212,11 +2296,60 @@ def from_config(config: str | dict) -> BaseModel:
22122296
k: tuple(v) if isinstance(v, list) else v
22132297
for k, v in options.items()
22142298
}
2215-
return model_cls(options=options) if options else model_cls()
2299+
model = model_cls(options=options) if options else model_cls()
2300+
BaseModel.apply_builtin_overrides(model, data)
2301+
return model
22162302

22172303
# Fallback: raw to_json dict (no "type" key)
22182304
return Serialise.load_custom_model(data)
22192305

2306+
@staticmethod
2307+
def apply_builtin_overrides(model: BaseModel, data: dict) -> None:
2308+
"""Apply variable and event overrides from a built-in config.
2309+
2310+
This is the inverse of
2311+
:meth:`serialise_builtin_overrides`. It mutates *model*
2312+
in-place.
2313+
2314+
Parameters
2315+
----------
2316+
model : BaseModel
2317+
A freshly-constructed built-in model.
2318+
data : dict
2319+
The config dictionary, which may contain
2320+
``custom_variables``, ``removed_variables``, and/or
2321+
``events``.
2322+
"""
2323+
from pybamm.expression_tree.operations.serialise import (
2324+
convert_symbol_from_json,
2325+
)
2326+
2327+
# --- Custom variables ---
2328+
extra = data.get("custom_variables")
2329+
if extra:
2330+
for name, expr_json in extra.items():
2331+
model.variables[name] = convert_symbol_from_json(expr_json)
2332+
2333+
# --- Removed variables ---
2334+
removed = data.get("removed_variables")
2335+
if removed:
2336+
for name in removed:
2337+
model.variables.pop(name, None)
2338+
2339+
# --- Events override ---
2340+
events_data = data.get("events")
2341+
if events_data is not None:
2342+
model.events = [
2343+
pybamm.Event(
2344+
e["name"],
2345+
convert_symbol_from_json(e["expression"]),
2346+
EventType(e["event_type"])
2347+
if isinstance(e["event_type"], int)
2348+
else e["event_type"],
2349+
)
2350+
for e in events_data
2351+
]
2352+
22202353

22212354
def load_model(filename, battery_model: BaseModel | None = None):
22222355
"""

tests/unit/test_models/test_model_to_json.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,133 @@ def test_to_config_builtin_file_round_trip(self, tmp_path):
369369
assert file_path.exists()
370370
loaded = pybamm.BaseModel.from_config(str(file_path))
371371
assert type(loaded) is type(model)
372+
373+
# ---- Built-in model override tests ----
374+
375+
def test_to_config_builtin_unmodified_has_no_overrides(self):
376+
"""Unmodified built-in model config has no override keys."""
377+
model = pybamm.lithium_ion.SPM()
378+
config = model.to_config()
379+
assert config["type"] == "SPM"
380+
assert "custom_variables" not in config
381+
assert "removed_variables" not in config
382+
assert "events" not in config
383+
384+
def test_to_config_builtin_with_extra_variable_round_trip(self):
385+
"""Added variable survives to_config / from_config round-trip."""
386+
model = pybamm.lithium_ion.SPM()
387+
model.variables["My variable"] = 2 * model.variables["Voltage [V]"]
388+
config = model.to_config()
389+
390+
# Still compact format with override
391+
assert config["type"] == "SPM"
392+
assert "custom_variables" in config
393+
assert "My variable" in config["custom_variables"]
394+
395+
loaded = pybamm.BaseModel.from_config(config)
396+
assert type(loaded) is type(model)
397+
assert "My variable" in loaded.variables
398+
399+
def test_to_config_builtin_with_cleared_events_round_trip(self):
400+
"""Clearing events survives to_config / from_config round-trip."""
401+
model = pybamm.lithium_ion.SPM()
402+
model.events = []
403+
config = model.to_config()
404+
405+
assert config["type"] == "SPM"
406+
assert "events" in config
407+
assert config["events"] == []
408+
409+
loaded = pybamm.BaseModel.from_config(config)
410+
assert type(loaded) is type(model)
411+
assert loaded.events == []
412+
413+
def test_to_config_builtin_with_removed_variable_round_trip(self):
414+
"""Removed variable is absent after from_config round-trip."""
415+
model = pybamm.lithium_ion.SPM()
416+
assert "Voltage [V]" in model.variables
417+
del model.variables["Voltage [V]"]
418+
config = model.to_config()
419+
420+
assert config["type"] == "SPM"
421+
assert "removed_variables" in config
422+
assert "Voltage [V]" in config["removed_variables"]
423+
424+
loaded = pybamm.BaseModel.from_config(config)
425+
assert type(loaded) is type(model)
426+
assert "Voltage [V]" not in loaded.variables
427+
428+
def test_to_config_builtin_with_added_event_round_trip(self):
429+
"""Custom event added to built-in model survives round-trip."""
430+
model = pybamm.lithium_ion.SPM()
431+
original_count = len(model.events)
432+
model.events.append(
433+
pybamm.Event(
434+
"my_custom_event",
435+
pybamm.Scalar(1),
436+
pybamm.EventType.TERMINATION,
437+
)
438+
)
439+
config = model.to_config()
440+
441+
assert config["type"] == "SPM"
442+
assert "events" in config
443+
assert len(config["events"]) == original_count + 1
444+
445+
loaded = pybamm.BaseModel.from_config(config)
446+
assert type(loaded) is type(model)
447+
event_names = {e.name for e in loaded.events}
448+
assert "my_custom_event" in event_names
449+
450+
def test_to_config_builtin_combined_variable_and_event_changes(self):
451+
"""Matches example.py: add variable + clear events, round-trip."""
452+
model = pybamm.lithium_ion.SPM()
453+
model.events = []
454+
model.variables["My variable"] = 2 * model.variables["Voltage [V]"]
455+
config = model.to_config()
456+
457+
assert config["type"] == "SPM"
458+
assert "custom_variables" in config
459+
assert "events" in config
460+
assert config["events"] == []
461+
462+
loaded = pybamm.BaseModel.from_config(config)
463+
assert type(loaded) is type(model)
464+
assert "My variable" in loaded.variables
465+
assert loaded.events == []
466+
467+
def test_to_config_builtin_event_type_survives_json_round_trip(self):
468+
"""Event type enum survives JSON serialisation round-trip."""
469+
model = pybamm.lithium_ion.SPM()
470+
model.events.append(
471+
pybamm.Event(
472+
"custom_event",
473+
pybamm.Scalar(1),
474+
pybamm.EventType.TERMINATION,
475+
)
476+
)
477+
config = model.to_config()
478+
479+
# JSON round-trip (would fail if event_type is a raw enum)
480+
json_str = json.dumps(config)
481+
reloaded = json.loads(json_str)
482+
483+
loaded = pybamm.BaseModel.from_config(reloaded)
484+
custom = [e for e in loaded.events if e.name == "custom_event"]
485+
assert len(custom) == 1
486+
assert custom[0].event_type == pybamm.EventType.TERMINATION
487+
488+
def test_to_config_builtin_overrides_json_serializable(self):
489+
"""Config with overrides survives JSON round-trip."""
490+
model = pybamm.lithium_ion.SPM()
491+
model.events = []
492+
model.variables["My variable"] = 2 * model.variables["Voltage [V]"]
493+
config = model.to_config()
494+
495+
json_str = json.dumps(config)
496+
assert isinstance(json_str, str)
497+
reloaded = json.loads(json_str)
498+
499+
loaded = pybamm.BaseModel.from_config(reloaded)
500+
assert "My variable" in loaded.variables
501+
assert loaded.events == []

0 commit comments

Comments
 (0)