Skip to content

Commit 3e32d20

Browse files
committed
Add test and fix bug
1 parent c550680 commit 3e32d20

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/pybamm/solvers/base_solver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def solve(
860860
for it in model.concatenated_initial_conditions.pre_order()
861861
]
862862
)
863-
if all_inputs_names.issubset(initial_conditions_node_names):
863+
if not initial_conditions_node_names.isdisjoint(all_inputs_names):
864864
raise pybamm.SolverError(
865865
"Input parameters cannot appear in expression "
866866
"for initial conditions."
@@ -910,9 +910,9 @@ def solve(
910910
# If the new initial conditions are different
911911
# and cannot be evaluated directly, set up again
912912
self.set_up(model, model_inputs_list[0], t_eval, ics_only=True)
913-
self._model_set_up[model]["initial conditions"] = (
914-
model.concatenated_initial_conditions
915-
)
913+
self._model_set_up[model][
914+
"initial conditions"
915+
] = model.concatenated_initial_conditions
916916
else:
917917
# Set the standard initial conditions
918918
self._set_initial_conditions(model, t_eval[0], model_inputs_list[0])

tests/unit/test_solvers/test_base_solver.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,33 @@ def test_on_extrapolation_and_on_failure_settings(self):
446446
ValueError, match="on_failure must be 'warn', 'raise', or 'ignore'"
447447
):
448448
base_solver.on_failure = "invalid"
449+
450+
def test_solver_multiple_inputs_initial_conditions_error(self):
451+
452+
y = pybamm.Variable("y")
453+
y0 = pybamm.InputParameter("y0")
454+
k = pybamm.InputParameter("k")
455+
456+
model = pybamm.BaseModel()
457+
model.rhs = {y: -k * y}
458+
model.initial_conditions = {y: y0}
459+
model.variables = {"y": y}
460+
461+
disc = pybamm.Discretisation()
462+
disc.process_model(model)
463+
464+
t_eval = np.linspace(0.0, 1.0, 6)
465+
466+
# Three different ICs so each run is clearly distinct
467+
inputs_list = [
468+
{"y0": 1.0, "k": 0.5},
469+
{"y0": 2.0, "k": 1.0},
470+
{"y0": 3.0, "k": 1.5},
471+
]
472+
473+
solver = pybamm.BaseSolver()
474+
with pytest.raises(
475+
pybamm.SolverError,
476+
match="Input parameters cannot appear in expression for initial conditions",
477+
):
478+
solver.solve(model, t_eval=t_eval, inputs=inputs_list)

0 commit comments

Comments
 (0)