2222
2323class ISEFlow (torch .nn .Module ):
2424 """
25- The ISEFlow (Flow-based Ice Sheet Emulator) that combines a deep ensemble and a normalizing flow model.
25+ ISEFlow is a hybrid ice sheet emulator that combines a deep ensemble model and a normalizing flow model.
26+
27+ This class provides methods to train, predict, save, and load hybrid models for ice sheet emulation.
28+ It integrates a deep ensemble to capture epistemic uncertainties and a normalizing flow to model aleatoric uncertainties.
29+
30+ Attributes:
31+ device (str): The computing device ('cuda' if available, else 'cpu').
32+ deep_ensemble (DeepEnsemble): The deep ensemble model for epistemic uncertainty.
33+ normalizing_flow (NormalizingFlow): The normalizing flow model for aleatoric uncertainty.
34+ trained (bool): Flag indicating whether the model has been trained.
35+ scaler_path (str or None): Path to the scaler used for output transformation.
2636 """
2737
38+
2839 def __init__ (self , deep_ensemble , normalizing_flow ):
40+ """
41+ Initializes the ISEFlow model with a deep ensemble and a normalizing flow.
42+
43+ Args:
44+ deep_ensemble (DeepEnsemble): A deep ensemble model for epistemic uncertainty estimation.
45+ normalizing_flow (NormalizingFlow): A normalizing flow model for aleatoric uncertainty estimation.
46+
47+ Raises:
48+ ValueError: If `deep_ensemble` is not an instance of DeepEnsemble.
49+ ValueError: If `normalizing_flow` is not an instance of NormalizingFlow.
50+ """
51+
2952 super (ISEFlow , self ).__init__ ()
3053
3154 self .device = "cuda" if torch .cuda .is_available () else "cpu"
@@ -44,8 +67,30 @@ def __init__(self, deep_ensemble, normalizing_flow):
4467 def fit (self , X , y , nf_epochs , de_epochs , batch_size = 64 , X_val = None , y_val = None , save_checkpoints = True , checkpoint_path = 'checkpoint_ensemble' , early_stopping = True ,
4568 sequence_length = 5 , patience = 10 , verbose = True ):
4669 """
47- Fits the hybrid emulator to the training data.
70+ Trains the hybrid emulator using the provided data.
71+
72+ This method trains the normalizing flow model first, then uses its latent representations
73+ to train the deep ensemble model.
74+
75+ Args:
76+ X (array-like): Input feature matrix.
77+ y (array-like): Target values.
78+ nf_epochs (int): Number of training epochs for the normalizing flow.
79+ de_epochs (int): Number of training epochs for the deep ensemble.
80+ batch_size (int, optional): Batch size for training. Defaults to 64.
81+ X_val (array-like, optional): Validation feature matrix. Defaults to None.
82+ y_val (array-like, optional): Validation target values. Defaults to None.
83+ save_checkpoints (bool, optional): Whether to save training checkpoints. Defaults to True.
84+ checkpoint_path (str, optional): Path prefix for saving model checkpoints. Defaults to 'checkpoint_ensemble'.
85+ early_stopping (bool, optional): Whether to use early stopping. Defaults to True.
86+ sequence_length (int, optional): Sequence length for recurrent architectures. Defaults to 5.
87+ patience (int, optional): Number of epochs with no improvement before stopping. Defaults to 10.
88+ verbose (bool, optional): Whether to print training progress. Defaults to True.
89+
90+ Raises:
91+ Warning: If the model has already been trained.
4892 """
93+
4994
5095 if early_stopping is None :
5196 early_stopping = X_val is not None and y_val is not None
@@ -83,7 +128,23 @@ def fit(self, X, y, nf_epochs, de_epochs, batch_size=64, X_val=None, y_val=None,
83128 def forward (self , x , smooth_projection = False ):
84129 """
85130 Performs a forward pass through the hybrid emulator.
131+
132+ Args:
133+ x (array-like): Input data.
134+ smooth_projection (bool, optional): Whether to apply smoothing to projections. Defaults to False.
135+
136+ Returns:
137+ tuple: A tuple containing:
138+ - prediction (numpy.ndarray): Model predictions.
139+ - uncertainties (dict): Dictionary with keys:
140+ - 'total' (numpy.ndarray): Total uncertainty.
141+ - 'epistemic' (numpy.ndarray): Epistemic uncertainty.
142+ - 'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
143+
144+ Raises:
145+ Warning: If the model has not been trained.
86146 """
147+
87148 self .eval ()
88149 x = to_tensor (x ).to (self .device )
89150 if not self .trained :
@@ -102,6 +163,26 @@ def forward(self, x, smooth_projection=False):
102163 return prediction , uncertainties
103164
104165 def predict (self , x , output_scaler = True , smooth_projection = False ):
166+ """
167+ Makes predictions using the trained hybrid emulator.
168+
169+ Args:
170+ x (array-like): Input data.
171+ output_scaler (bool or str, optional): Path to the output scaler or whether to apply scaling. Defaults to True.
172+ smooth_projection (bool, optional): Whether to apply smoothing. Defaults to False.
173+
174+ Returns:
175+ tuple: A tuple containing:
176+ - unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
177+ - uncertainties (dict): Dictionary with keys:
178+ - 'total' (numpy.ndarray): Total uncertainty.
179+ - 'epistemic' (numpy.ndarray): Epistemic uncertainty.
180+ - 'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
181+
182+ Raises:
183+ Warning: If no scaler path is provided.
184+ """
185+
105186 self .eval ()
106187 if output_scaler is True :
107188 output_scaler = os .path .join (self .model_dir , "scaler_y.pkl" )
@@ -134,8 +215,19 @@ def predict(self, x, output_scaler=True, smooth_projection=False):
134215
135216 def save (self , save_dir , input_features = None , output_scaler_path = None ):
136217 """
137- Saves the trained model to the specified directory.
218+ Saves the trained model and related components to a specified directory.
219+
220+ Args:
221+ save_dir (str): Directory where the model should be saved.
222+ input_features (list, optional): List of input feature names. Defaults to None.
223+ output_scaler_path (str, optional): Path to the output scaler. Defaults to None.
224+
225+ Raises:
226+ ValueError: If the model has not been trained.
227+ ValueError: If `save_dir` is a file instead of a directory.
228+ ValueError: If `input_features` is not a list.
138229 """
230+
139231 if not self .trained :
140232 raise ValueError ("This model has not been trained yet. Train the model before saving." )
141233 if save_dir .endswith (".pth" ):
@@ -160,8 +252,20 @@ def save(self, save_dir, input_features=None, output_scaler_path=None):
160252 @staticmethod
161253 def load (model_dir = None , deep_ensemble_path = None , normalizing_flow_path = None ,):
162254 """
163- Loads a trained model from the specified paths.
255+ Loads a trained ISEFlow model from specified paths.
256+
257+ Args:
258+ model_dir (str, optional): Directory containing the saved model. Defaults to None.
259+ deep_ensemble_path (str, optional): Path to the saved deep ensemble model. Defaults to None.
260+ normalizing_flow_path (str, optional): Path to the saved normalizing flow model. Defaults to None.
261+
262+ Returns:
263+ ISEFlow: The loaded ISEFlow model.
264+
265+ Raises:
266+ NotImplementedError: If an unsupported version is specified.
164267 """
268+
165269
166270 if model_dir :
167271 deep_ensemble_path = os .path .join (model_dir , "deep_ensemble.pth" )
@@ -178,14 +282,46 @@ def load(model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None,):
178282
179283
180284class ISEFlow_AIS (ISEFlow ):
285+ """
286+ ISEFlow_AIS is a specialized version of ISEFlow for the Antarctic Ice Sheet (AIS).
287+
288+ This subclass initializes the deep ensemble and normalizing flow models specifically
289+ for AIS and provides a loading method with pre-trained model paths.
290+ """
291+
181292 def __init__ (self ,):
293+ """
294+ Initializes the ISEFlow_AIS model.
295+
296+ Sets the ice sheet type to 'AIS' and initializes pre-configured deep ensemble
297+ and normalizing flow models specific to AIS.
298+ """
299+
182300 self .ice_sheet = "AIS"
183301 deep_ensemble = ISEFlow_AIS_DE ()
184302 normalizing_flow = ISEFlow_AIS_NF ()
185303 super (ISEFlow_AIS , self ).__init__ (deep_ensemble , normalizing_flow )
186304
187305 @staticmethod
188306 def load (version = "v1.0.0" , model_dir = None , deep_ensemble_path = None , normalizing_flow_path = None ):
307+ """
308+ Loads a trained ISEFlow_AIS model.
309+
310+ Args:
311+ version (str, optional): Model version. Defaults to "v1.0.0".
312+ model_dir (str, optional): Directory of the saved model. Defaults to None.
313+ deep_ensemble_path (str, optional): Path to deep ensemble model. Defaults to None.
314+ normalizing_flow_path (str, optional): Path to normalizing flow model. Defaults to None.
315+
316+ Returns:
317+ ISEFlow_AIS: The loaded model.
318+
319+ Raises:
320+ NotImplementedError: If an unsupported version is specified.
321+ """
322+
323+ # TODO: Add support for deep ensemble and normalizing flow paths
324+
189325 if model_dir is None :
190326 if version == "v1.0.0" :
191327 model_dir = ISEFlow_AIS_v1_0_0_path
@@ -230,6 +366,39 @@ def process(
230366 standard_melt_type : str = None ,
231367
232368 ):
369+ """
370+ Processes input data for prediction by applying necessary transformations and encoding.
371+
372+ Args:
373+ year (np.array): Years of the input data.
374+ pr_anomaly (np.array): Precipitation anomaly data.
375+ evspsbl_anomaly (np.array): Evaporation anomaly data.
376+ mrro_anomaly (np.array): Runoff anomaly data.
377+ smb_anomaly (np.array): Surface mass balance anomaly.
378+ ts_anomaly (np.array): Surface temperature anomaly.
379+ ocean_thermal_forcing (np.array): Ocean thermal forcing.
380+ ocean_salinity (np.array): Ocean salinity.
381+ ocean_temperature (np.array): Ocean temperature.
382+ initial_year (int): Initial year for modeling.
383+ numerics (str): Numerical scheme used.
384+ stress_balance (str): Stress balance model.
385+ resolution (int): Resolution of the model.
386+ init_method (str): Initialization method.
387+ melt_in_floating_cells (str): Melt treatment method.
388+ icefront_migration (str): Ice front migration scheme.
389+ ocean_forcing_type (str): Type of ocean forcing applied.
390+ ocean_sensitivity (str): Ocean sensitivity setting.
391+ ice_shelf_fracture (bool): Whether ice shelf fracture is considered.
392+ open_melt_type (str, optional): Type of open melt model. Defaults to None.
393+ standard_melt_type (str, optional): Type of standard melt model. Defaults to None.
394+
395+ Returns:
396+ pd.DataFrame: Processed input data ready for prediction.
397+
398+ Raises:
399+ ValueError: If any input arguments are invalid.
400+ """
401+
233402
234403 if year [0 ] == 2015 :
235404 year = year - 2015
@@ -458,6 +627,18 @@ def predict(
458627 open_melt_type : str = None ,
459628 standard_melt_type : str = None ,
460629 ):
630+ """
631+ Predicts ice sheet evolution using the trained ISEFlow_AIS model.
632+
633+ Args:
634+ (Same as process method)
635+
636+ Returns:
637+ tuple: A tuple containing:
638+ - unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
639+ - uncertainties (dict): Dictionary containing different uncertainty components.
640+ """
641+
461642
462643 data = self .process (
463644 year , pr_anomaly , evspsbl_anomaly , mrro_anomaly , smb_anomaly , ts_anomaly , ocean_thermal_forcing , ocean_salinity , ocean_temperature , initial_year , numerics , stress_balance , resolution , init_method , melt_in_floating_cells , icefront_migration , ocean_forcing_type , ocean_sensitivity , ice_shelf_fracture , open_melt_type , standard_melt_type
@@ -467,17 +648,39 @@ def predict(
467648 return super ().predict (X , output_scaler = f"{ ISEFlow_AIS_v1_0_0_path } /scaler_y.pkl" )
468649
469650class ISEFlow_GrIS (ISEFlow ):
651+ """
652+ ISEFlow_GrIS is a specialized version of ISEFlow for the Greenland Ice Sheet (GrIS).
653+
654+ This subclass initializes the deep ensemble and normalizing flow models specifically
655+ for GrIS and provides a loading method with pre-trained model paths.
656+ """
657+
470658 def __init__ (self ,):
659+ """
660+ Initializes the ISEFlow_GrIS model.
661+
662+ Sets the ice sheet type to 'GrIS' and initializes pre-configured deep ensemble
663+ and normalizing flow models specific to GrIS.
664+ """
665+
471666 self .ice_sheet = "GrIS"
472667 deep_ensemble = ISEFlow_GrIS_DE ()
473668 normalizing_flow = ISEFlow_GrIS_NF ()
474669 super (ISEFlow_GrIS , self ).__init__ (deep_ensemble , normalizing_flow )
475670
476671 @staticmethod
477672 def load (version = "v1.0.0" , model_dir = None , deep_ensemble_path = None , normalizing_flow_path = None ,):
673+ """
674+ Loads a trained ISEFlow_GrIS model.
675+
676+ (Same arguments and return type as `ISEFlow_AIS.load`.)
677+ """
678+
478679 if model_dir is None :
479680 if version == "v1.0.0" :
480681 model_dir = ISEFlow_GrIS_v1_0_0_path
481682 else :
482683 raise NotImplementedError ("Only version v1.0.0 is supported" )
483684 return super (ISEFlow_GrIS , ISEFlow_GrIS ).load (model_dir , deep_ensemble_path , normalizing_flow_path )
685+
686+ # TODO: ISEFlow GrIS process, predict
0 commit comments