Skip to content

Commit 86c7c4d

Browse files
committed
Add test and fix bug
1 parent acff2a2 commit 86c7c4d

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/pybamm/solvers/base_solver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ def solve(
789789
for it in model.concatenated_initial_conditions.pre_order()
790790
]
791791
)
792-
if all_inputs_names.issubset(initial_conditions_node_names):
792+
if not initial_conditions_node_names.isdisjoint(all_inputs_names):
793793
raise pybamm.SolverError(
794794
"Input parameters cannot appear in expression "
795795
"for initial conditions."
@@ -839,9 +839,9 @@ def solve(
839839
# If the new initial conditions are different
840840
# and cannot be evaluated directly, set up again
841841
self.set_up(model, model_inputs_list[0], t_eval, ics_only=True)
842-
self._model_set_up[model]["initial conditions"] = (
843-
model.concatenated_initial_conditions
844-
)
842+
self._model_set_up[model][
843+
"initial conditions"
844+
] = model.concatenated_initial_conditions
845845
else:
846846
# Set the standard initial conditions
847847
self._set_initial_conditions(model, t_eval[0], model_inputs_list[0])

tests/unit/test_solvers/test_base_solver.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,33 @@ def test_on_extrapolation_and_on_failure_settings(self):
446446
ValueError, match="on_failure must be 'warn', 'raise', or 'ignore'"
447447
):
448448
base_solver.on_failure = "invalid"
449+
450+
def test_solver_multiple_inputs_initial_conditions_error(self):
451+
452+
y = pybamm.Variable("y")
453+
y0 = pybamm.InputParameter("y0")
454+
k = pybamm.InputParameter("k")
455+
456+
model = pybamm.BaseModel()
457+
model.rhs = {y: -k * y}
458+
model.initial_conditions = {y: y0}
459+
model.variables = {"y": y}
460+
461+
disc = pybamm.Discretisation()
462+
disc.process_model(model)
463+
464+
t_eval = np.linspace(0.0, 1.0, 6)
465+
466+
# Three different ICs so each run is clearly distinct
467+
inputs_list = [
468+
{"y0": 1.0, "k": 0.5},
469+
{"y0": 2.0, "k": 1.0},
470+
{"y0": 3.0, "k": 1.5},
471+
]
472+
473+
solver = pybamm.BaseSolver()
474+
with pytest.raises(
475+
pybamm.SolverError,
476+
match="Input parameters cannot appear in expression for initial conditions",
477+
):
478+
solver.solve(model, t_eval=t_eval, inputs=inputs_list)

0 commit comments

Comments
 (0)