File tree Expand file tree Collapse file tree 2 files changed +36
-2
lines changed
tests/unit/test_parameters Expand file tree Collapse file tree 2 files changed +36
-2
lines changed Original file line number Diff line number Diff line change @@ -866,12 +866,21 @@ def _process_function_parameter(self, symbol):
866866 else :
867867 new_children .append (self .process_symbol (child ))
868868
869- # Get the expression and inputs for the function
869+ # Get the expression and inputs for the function.
870+ # func_args may include arguments that were not explicitly wired up
871+ # in this FunctionParameter (e.g., kwargs with default values). After
872+ # serialisation/deserialisation, we only recover the children that were
873+ # actually connected.
874+ #
875+ # Using strict=True here therefore raises a ValueError when there are
876+ # more args than children. We allow func_args to be longer than
877+ # symbol.children and only build the mapping for the args for which we
878+ # actually have children.
870879 expression = function_parameter .child
871880 inputs = {
872881 arg : child
873882 for arg , child in zip (
874- function_parameter .func_args , symbol .children , strict = True
883+ function_parameter .func_args , symbol .children , strict = False
875884 )
876885 }
877886
Original file line number Diff line number Diff line change @@ -1298,6 +1298,31 @@ def test_to_json_with_filename(self):
12981298 finally :
12991299 os .remove (temp_path )
13001300
1301+ def test_roundtrip_with_keyword_args (self ):
1302+ def func_no_kwargs (x ):
1303+ return 2 * x
1304+
1305+ def func_with_kwargs (x , y = 1 ):
1306+ return 2 * x
1307+
1308+ x = pybamm .Scalar (2 )
1309+ func_param = pybamm .FunctionParameter ("func" , {"x" : x })
1310+
1311+ parameter_values = pybamm .ParameterValues ({"func" : func_no_kwargs })
1312+ assert parameter_values .evaluate (func_param ) == 4.0
1313+
1314+ serialized = parameter_values .to_json ()
1315+ parameter_values_loaded = pybamm .ParameterValues .from_json (serialized )
1316+ assert parameter_values_loaded .evaluate (func_param ) == 4.0
1317+
1318+ parameter_values = pybamm .ParameterValues ({"func" : func_with_kwargs })
1319+ assert parameter_values .evaluate (func_param ) == 4.0
1320+
1321+ serialized = parameter_values .to_json ()
1322+ parameter_values_loaded = pybamm .ParameterValues .from_json (serialized )
1323+
1324+ assert parameter_values_loaded .evaluate (func_param ) == 4.0
1325+
13011326 def test_convert_symbols_in_dict_with_interpolator (self ):
13021327 """Test convert_symbols_in_dict with interpolator (covers lines 1154-1170)."""
13031328 import numpy as np
You can’t perform that action at this time.
0 commit comments