Skip to content

Commit 86fec31

Browse files
Merge pull request #5274 from agriyakhetarpal/roundtrip-serialisation-fix
Don't be too strict with func_args longer than symbol.children
2 parents 1319b4a + eda4ebe commit 86fec31

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/pybamm/parameters/parameter_values.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -866,12 +866,21 @@ def _process_function_parameter(self, symbol):
866866
else:
867867
new_children.append(self.process_symbol(child))
868868

869-
# Get the expression and inputs for the function
869+
# Get the expression and inputs for the function.
870+
# func_args may include arguments that were not explicitly wired up
871+
# in this FunctionParameter (e.g., kwargs with default values). After
872+
# serialisation/deserialisation, we only recover the children that were
873+
# actually connected.
874+
#
875+
# Using strict=True here therefore raises a ValueError when there are
876+
# more args than children. We allow func_args to be longer than
877+
# symbol.children and only build the mapping for the args for which we
878+
# actually have children.
870879
expression = function_parameter.child
871880
inputs = {
872881
arg: child
873882
for arg, child in zip(
874-
function_parameter.func_args, symbol.children, strict=True
883+
function_parameter.func_args, symbol.children, strict=False
875884
)
876885
}
877886

tests/unit/test_parameters/test_parameter_values.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,31 @@ def test_to_json_with_filename(self):
12981298
finally:
12991299
os.remove(temp_path)
13001300

1301+
def test_roundtrip_with_keyword_args(self):
1302+
def func_no_kwargs(x):
1303+
return 2 * x
1304+
1305+
def func_with_kwargs(x, y=1):
1306+
return 2 * x
1307+
1308+
x = pybamm.Scalar(2)
1309+
func_param = pybamm.FunctionParameter("func", {"x": x})
1310+
1311+
parameter_values = pybamm.ParameterValues({"func": func_no_kwargs})
1312+
assert parameter_values.evaluate(func_param) == 4.0
1313+
1314+
serialized = parameter_values.to_json()
1315+
parameter_values_loaded = pybamm.ParameterValues.from_json(serialized)
1316+
assert parameter_values_loaded.evaluate(func_param) == 4.0
1317+
1318+
parameter_values = pybamm.ParameterValues({"func": func_with_kwargs})
1319+
assert parameter_values.evaluate(func_param) == 4.0
1320+
1321+
serialized = parameter_values.to_json()
1322+
parameter_values_loaded = pybamm.ParameterValues.from_json(serialized)
1323+
1324+
assert parameter_values_loaded.evaluate(func_param) == 4.0
1325+
13011326
def test_convert_symbols_in_dict_with_interpolator(self):
13021327
"""Test convert_symbols_in_dict with interpolator (covers lines 1154-1170)."""
13031328
import numpy as np

0 commit comments

Comments
 (0)