Skip to content

Commit f73e368

Browse files
Set initial conditions from y slice (#5257)
* allow get state from y_slice * add tests * fix set initial conditions from dict * changelog * remove local test script * coverage * more cov * style: pre-commit fixes * coverage again * more coverage and remove extra ifs * remove another unnecessary function --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 051a6c0 commit f73e368

File tree

4 files changed

+422
-12
lines changed

4 files changed

+422
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)
22

33
## Features
4-
4+
- Allow setting initial conditions from `y_slices` of a `Solution` object. ([#5257](https://github.com/pybamm-team/PyBaMM/pull/5257))
55
- Added docstring to `FuzzyDict.copy` explaining its return value and behavior. ([#5242](https://github.com/pybamm-team/PyBaMM/pull/5242))
66

77
## Bug fixes

src/pybamm/models/base_model.py

Lines changed: 102 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -927,29 +927,121 @@ def set_initial_conditions_from(
927927
"""
928928
mesh = mesh or {}
929929
initial_conditions = {}
930+
is_dict_input = isinstance(solution, dict)
930931
if isinstance(solution, pybamm.Solution):
931932
solution = solution.last_state
932933

934+
def _find_matching_variable(var, solution_model):
935+
if not (
936+
solution_model.is_discretised and solution_model.y_slices is not None
937+
):
938+
return None
939+
var_id = var.id
940+
for sol_var in solution_model.y_slices.keys():
941+
if sol_var.id == var_id:
942+
return sol_var
943+
return None
944+
945+
def _extract_from_y_slices(var, solution_var, solution_model, solution):
946+
solution_y_slice = solution_model.y_slices[solution_var][0]
947+
y_last = solution.y[:, -1] if solution.y.ndim > 1 else solution.y
948+
949+
# Validate slice bounds
950+
slice_stop = (
951+
solution_y_slice.stop
952+
if solution_y_slice.stop is not None
953+
else len(y_last)
954+
)
955+
slice_start = (
956+
solution_y_slice.start if solution_y_slice.start is not None else 0
957+
)
958+
if slice_start < 0 or slice_stop > len(y_last):
959+
return None
960+
961+
# Extract scaled state vector values
962+
y_scaled = np.array(y_last[solution_y_slice])
963+
964+
# Convert from scaled state vector to physical values
965+
# physical = reference + scale * y_scaled
966+
try:
967+
solution_scale = np.asarray(
968+
solution_var.scale.evaluate()
969+
) * np.ones_like(y_scaled)
970+
solution_reference = np.asarray(
971+
solution_var.reference.evaluate()
972+
) * np.ones_like(y_scaled)
973+
return solution_reference + solution_scale * y_scaled
974+
except (TypeError, ValueError, AttributeError):
975+
# Fall back to dict lookup
976+
return None
977+
978+
def _extract_final_time_step(var_data):
979+
var_data = np.array(var_data)
980+
if var_data.ndim == 0:
981+
return var_data
982+
elif var_data.ndim == 1:
983+
return np.array(var_data[-1:])
984+
elif var_data.ndim == 2:
985+
return np.array(var_data[:, -1])
986+
elif var_data.ndim == 3:
987+
return var_data[:, :, -1].flatten(order="F")
988+
elif var_data.ndim == 4:
989+
return var_data[:, :, :, -1].flatten(order="F")
990+
else:
991+
raise NotImplementedError("Variable must be 0D, 1D, 2D, 3D, or 4D")
992+
933993
def get_final_state_eval(final_state):
934-
if isinstance(solution, pybamm.Solution):
994+
# If it's a ProcessedVariable, extract .data first
995+
if isinstance(solution, pybamm.Solution) and hasattr(final_state, "data"):
935996
final_state = final_state.data
936997

998+
final_state = np.array(final_state)
999+
1000+
# Extract final state from time series
9371001
if final_state.ndim == 0:
9381002
return np.array([final_state])
9391003
elif final_state.ndim == 1:
940-
return final_state[-1:]
1004+
# 1D arrays are already final state (from y_slices or processed from dict)
1005+
return np.array(final_state)
9411006
elif final_state.ndim == 2:
942-
return final_state[:, -1]
1007+
return np.array(final_state[:, -1])
9431008
elif final_state.ndim == 3:
9441009
return final_state[:, :, -1].flatten(order="F")
9451010
elif final_state.ndim == 4:
9461011
return final_state[:, :, :, -1].flatten(order="F")
9471012
else:
948-
raise NotImplementedError("Variable must be 0D, 1D, 2D, or 3D")
1013+
raise NotImplementedError("Variable must be 0D, 1D, 2D, 3D, or 4D")
1014+
1015+
def get_variable_state(var):
1016+
var_name = var.name
9491017

950-
def get_variable_state(var_name):
1018+
# Try y_slices for discretised models
1019+
if (
1020+
self.is_discretised
1021+
and var in self.y_slices
1022+
and isinstance(solution, pybamm.Solution)
1023+
and len(solution.all_models) > 0
1024+
):
1025+
try:
1026+
solution_model = solution.all_models[-1]
1027+
solution_var = _find_matching_variable(var, solution_model)
1028+
if solution_var is not None:
1029+
final_state = _extract_from_y_slices(
1030+
var, solution_var, solution_model, solution
1031+
)
1032+
if final_state is not None:
1033+
return final_state
1034+
except (KeyError, AttributeError, IndexError, TypeError):
1035+
pass
1036+
1037+
# Fall back to solution[var_name] lookup
9511038
try:
952-
return solution[var_name]
1039+
var_data = solution[var_name]
1040+
# For dict inputs, extract final time step here
1041+
if is_dict_input:
1042+
return _extract_final_time_step(var_data)
1043+
else:
1044+
return var_data
9531045
except KeyError as e:
9541046
raise pybamm.ModelError(
9551047
"To update a model from a solution, each variable in "
@@ -963,13 +1055,13 @@ def get_variable_state(var_name):
9631055
var, pybamm.Concatenation
9641056
):
9651057
try:
966-
final_state = get_variable_state(var.name)
1058+
final_state = get_variable_state(var)
9671059
final_state_eval = get_final_state_eval(final_state)
9681060
except pybamm.ModelError as e:
9691061
if isinstance(var, pybamm.Concatenation):
9701062
children = []
9711063
for child in var.orphans:
972-
final_state = get_variable_state(child.name)
1064+
final_state = get_variable_state(child)
9731065
final_state_eval = get_final_state_eval(final_state)
9741066
children.append(final_state_eval)
9751067
final_state_eval = np.concatenate(children)
@@ -986,10 +1078,10 @@ def get_variable_state(var_name):
9861078
if self.is_discretised:
9871079
scale, reference = var.scale, var.reference
9881080
else:
989-
scale, reference = 1, 0
1081+
scale, reference = pybamm.Scalar(1), pybamm.Scalar(0)
9901082
initial_conditions[var] = (
9911083
pybamm.Vector(final_state_eval) - reference
992-
) / scale
1084+
) / scale.evaluate()
9931085

9941086
# Also update the concatenated initial conditions if the model is already
9951087
# discretised

0 commit comments

Comments
 (0)