@@ -927,29 +927,121 @@ def set_initial_conditions_from(
927927 """
928928 mesh = mesh or {}
929929 initial_conditions = {}
930+ is_dict_input = isinstance (solution , dict )
930931 if isinstance (solution , pybamm .Solution ):
931932 solution = solution .last_state
932933
934+ def _find_matching_variable (var , solution_model ):
935+ if not (
936+ solution_model .is_discretised and solution_model .y_slices is not None
937+ ):
938+ return None
939+ var_id = var .id
940+ for sol_var in solution_model .y_slices .keys ():
941+ if sol_var .id == var_id :
942+ return sol_var
943+ return None
944+
945+ def _extract_from_y_slices (var , solution_var , solution_model , solution ):
946+ solution_y_slice = solution_model .y_slices [solution_var ][0 ]
947+ y_last = solution .y [:, - 1 ] if solution .y .ndim > 1 else solution .y
948+
949+ # Validate slice bounds
950+ slice_stop = (
951+ solution_y_slice .stop
952+ if solution_y_slice .stop is not None
953+ else len (y_last )
954+ )
955+ slice_start = (
956+ solution_y_slice .start if solution_y_slice .start is not None else 0
957+ )
958+ if slice_start < 0 or slice_stop > len (y_last ):
959+ return None
960+
961+ # Extract scaled state vector values
962+ y_scaled = np .array (y_last [solution_y_slice ])
963+
964+ # Convert from scaled state vector to physical values
965+ # physical = reference + scale * y_scaled
966+ try :
967+ solution_scale = np .asarray (
968+ solution_var .scale .evaluate ()
969+ ) * np .ones_like (y_scaled )
970+ solution_reference = np .asarray (
971+ solution_var .reference .evaluate ()
972+ ) * np .ones_like (y_scaled )
973+ return solution_reference + solution_scale * y_scaled
974+ except (TypeError , ValueError , AttributeError ):
975+ # Fall back to dict lookup
976+ return None
977+
978+ def _extract_final_time_step (var_data ):
979+ var_data = np .array (var_data )
980+ if var_data .ndim == 0 :
981+ return var_data
982+ elif var_data .ndim == 1 :
983+ return np .array (var_data [- 1 :])
984+ elif var_data .ndim == 2 :
985+ return np .array (var_data [:, - 1 ])
986+ elif var_data .ndim == 3 :
987+ return var_data [:, :, - 1 ].flatten (order = "F" )
988+ elif var_data .ndim == 4 :
989+ return var_data [:, :, :, - 1 ].flatten (order = "F" )
990+ else :
991+ raise NotImplementedError ("Variable must be 0D, 1D, 2D, 3D, or 4D" )
992+
933993 def get_final_state_eval (final_state ):
934- if isinstance (solution , pybamm .Solution ):
994+ # If it's a ProcessedVariable, extract .data first
995+ if isinstance (solution , pybamm .Solution ) and hasattr (final_state , "data" ):
935996 final_state = final_state .data
936997
998+ final_state = np .array (final_state )
999+
1000+ # Extract final state from time series
9371001 if final_state .ndim == 0 :
9381002 return np .array ([final_state ])
9391003 elif final_state .ndim == 1 :
940- return final_state [- 1 :]
1004+ # 1D arrays are already final state (from y_slices or processed from dict)
1005+ return np .array (final_state )
9411006 elif final_state .ndim == 2 :
942- return final_state [:, - 1 ]
1007+ return np . array ( final_state [:, - 1 ])
9431008 elif final_state .ndim == 3 :
9441009 return final_state [:, :, - 1 ].flatten (order = "F" )
9451010 elif final_state .ndim == 4 :
9461011 return final_state [:, :, :, - 1 ].flatten (order = "F" )
9471012 else :
948- raise NotImplementedError ("Variable must be 0D, 1D, 2D, or 3D" )
1013+ raise NotImplementedError ("Variable must be 0D, 1D, 2D, 3D, or 4D" )
1014+
1015+ def get_variable_state (var ):
1016+ var_name = var .name
9491017
950- def get_variable_state (var_name ):
1018+ # Try y_slices for discretised models
1019+ if (
1020+ self .is_discretised
1021+ and var in self .y_slices
1022+ and isinstance (solution , pybamm .Solution )
1023+ and len (solution .all_models ) > 0
1024+ ):
1025+ try :
1026+ solution_model = solution .all_models [- 1 ]
1027+ solution_var = _find_matching_variable (var , solution_model )
1028+ if solution_var is not None :
1029+ final_state = _extract_from_y_slices (
1030+ var , solution_var , solution_model , solution
1031+ )
1032+ if final_state is not None :
1033+ return final_state
1034+ except (KeyError , AttributeError , IndexError , TypeError ):
1035+ pass
1036+
1037+ # Fall back to solution[var_name] lookup
9511038 try :
952- return solution [var_name ]
1039+ var_data = solution [var_name ]
1040+ # For dict inputs, extract final time step here
1041+ if is_dict_input :
1042+ return _extract_final_time_step (var_data )
1043+ else :
1044+ return var_data
9531045 except KeyError as e :
9541046 raise pybamm .ModelError (
9551047 "To update a model from a solution, each variable in "
@@ -963,13 +1055,13 @@ def get_variable_state(var_name):
9631055 var , pybamm .Concatenation
9641056 ):
9651057 try :
966- final_state = get_variable_state (var . name )
1058+ final_state = get_variable_state (var )
9671059 final_state_eval = get_final_state_eval (final_state )
9681060 except pybamm .ModelError as e :
9691061 if isinstance (var , pybamm .Concatenation ):
9701062 children = []
9711063 for child in var .orphans :
972- final_state = get_variable_state (child . name )
1064+ final_state = get_variable_state (child )
9731065 final_state_eval = get_final_state_eval (final_state )
9741066 children .append (final_state_eval )
9751067 final_state_eval = np .concatenate (children )
@@ -986,10 +1078,10 @@ def get_variable_state(var_name):
9861078 if self .is_discretised :
9871079 scale , reference = var .scale , var .reference
9881080 else :
989- scale , reference = 1 , 0
1081+ scale , reference = pybamm . Scalar ( 1 ), pybamm . Scalar ( 0 )
9901082 initial_conditions [var ] = (
9911083 pybamm .Vector (final_state_eval ) - reference
992- ) / scale
1084+ ) / scale . evaluate ()
9931085
9941086 # Also update the concatenated initial conditions if the model is already
9951087 # discretised
0 commit comments