Skip to content

Commit d7451be

Browse files
committed
Add test and fix bug
1 parent 78fe787 commit d7451be

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

src/pybamm/solvers/base_solver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def solve(
797797
for it in model.concatenated_initial_conditions.pre_order()
798798
]
799799
)
800-
if all_inputs_names.issubset(initial_conditions_node_names):
800+
if not initial_conditions_node_names.isdisjoint(all_inputs_names):
801801
raise pybamm.SolverError(
802802
"Input parameters cannot appear in expression "
803803
"for initial conditions."
@@ -847,9 +847,9 @@ def solve(
847847
# If the new initial conditions are different
848848
# and cannot be evaluated directly, set up again
849849
self.set_up(model, model_inputs_list[0], t_eval, ics_only=True)
850-
self._model_set_up[model]["initial conditions"] = (
851-
model.concatenated_initial_conditions
852-
)
850+
self._model_set_up[model][
851+
"initial conditions"
852+
] = model.concatenated_initial_conditions
853853
else:
854854
# Set the standard initial conditions
855855
self._set_initial_conditions(model, t_eval[0], model_inputs_list[0])

tests/unit/test_solvers/test_base_solver.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,33 @@ def test_on_extrapolation_and_on_failure_settings(self):
446446
ValueError, match="on_failure must be 'warn', 'raise', or 'ignore'"
447447
):
448448
base_solver.on_failure = "invalid"
449+
450+
def test_solver_multiple_inputs_initial_conditions_error(self):
451+
452+
y = pybamm.Variable("y")
453+
y0 = pybamm.InputParameter("y0")
454+
k = pybamm.InputParameter("k")
455+
456+
model = pybamm.BaseModel()
457+
model.rhs = {y: -k * y}
458+
model.initial_conditions = {y: y0}
459+
model.variables = {"y": y}
460+
461+
disc = pybamm.Discretisation()
462+
disc.process_model(model)
463+
464+
t_eval = np.linspace(0.0, 1.0, 6)
465+
466+
# Three different ICs so each run is clearly distinct
467+
inputs_list = [
468+
{"y0": 1.0, "k": 0.5},
469+
{"y0": 2.0, "k": 1.0},
470+
{"y0": 3.0, "k": 1.5},
471+
]
472+
473+
solver = pybamm.BaseSolver()
474+
with pytest.raises(
475+
pybamm.SolverError,
476+
match="Input parameters cannot appear in expression for initial conditions",
477+
):
478+
solver.solve(model, t_eval=t_eval, inputs=inputs_list)

0 commit comments

Comments
 (0)