Skip to content

Commit 18a273c

Browse files
committed
sphinx integration
1 parent 190d5cd commit 18a273c

File tree

13 files changed

+1331
-291
lines changed

13 files changed

+1331
-291
lines changed

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ ISE, or ice-sheet emulators, is a package for end-to-end creation and analysis o
55

66
The main features of ISE include loading and processing of ISMIP6 sea level contribution simulations,
77
data preparation and feature engineering for machine learning, and training and testing of trained neural network emulators.
8+
This repository also contains the code to use ISEFlow, as presented in the paper "ISEFlow: A Flow-Based Neural Network Emulator for Improved Sea Level Projections and Uncertainty Quantification".
89

910
.. toctree::
1011
:maxdepth: 1

ise/models/ISEFlow/ISEFlow.py

Lines changed: 207 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,33 @@
2222

2323
class ISEFlow(torch.nn.Module):
2424
"""
25-
The ISEFlow (Flow-based Ice Sheet Emulator) that combines a deep ensemble and a normalizing flow model.
25+
ISEFlow is a hybrid ice sheet emulator that combines a deep ensemble model and a normalizing flow model.
26+
27+
This class provides methods to train, predict, save, and load hybrid models for ice sheet emulation.
28+
It integrates a deep ensemble to capture epistemic uncertainties and a normalizing flow to model aleatoric uncertainties.
29+
30+
Attributes:
31+
device (str): The computing device ('cuda' if available, else 'cpu').
32+
deep_ensemble (DeepEnsemble): The deep ensemble model for epistemic uncertainty.
33+
normalizing_flow (NormalizingFlow): The normalizing flow model for aleatoric uncertainty.
34+
trained (bool): Flag indicating whether the model has been trained.
35+
scaler_path (str or None): Path to the scaler used for output transformation.
2636
"""
2737

38+
2839
def __init__(self, deep_ensemble, normalizing_flow):
40+
"""
41+
Initializes the ISEFlow model with a deep ensemble and a normalizing flow.
42+
43+
Args:
44+
deep_ensemble (DeepEnsemble): A deep ensemble model for epistemic uncertainty estimation.
45+
normalizing_flow (NormalizingFlow): A normalizing flow model for aleatoric uncertainty estimation.
46+
47+
Raises:
48+
ValueError: If `deep_ensemble` is not an instance of DeepEnsemble.
49+
ValueError: If `normalizing_flow` is not an instance of NormalizingFlow.
50+
"""
51+
2952
super(ISEFlow, self).__init__()
3053

3154
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -44,8 +67,30 @@ def __init__(self, deep_ensemble, normalizing_flow):
4467
def fit(self, X, y, nf_epochs, de_epochs, batch_size=64, X_val=None, y_val=None, save_checkpoints=True, checkpoint_path='checkpoint_ensemble', early_stopping=True,
4568
sequence_length=5, patience=10, verbose=True):
4669
"""
47-
Fits the hybrid emulator to the training data.
70+
Trains the hybrid emulator using the provided data.
71+
72+
This method trains the normalizing flow model first, then uses its latent representations
73+
to train the deep ensemble model.
74+
75+
Args:
76+
X (array-like): Input feature matrix.
77+
y (array-like): Target values.
78+
nf_epochs (int): Number of training epochs for the normalizing flow.
79+
de_epochs (int): Number of training epochs for the deep ensemble.
80+
batch_size (int, optional): Batch size for training. Defaults to 64.
81+
X_val (array-like, optional): Validation feature matrix. Defaults to None.
82+
y_val (array-like, optional): Validation target values. Defaults to None.
83+
save_checkpoints (bool, optional): Whether to save training checkpoints. Defaults to True.
84+
checkpoint_path (str, optional): Path prefix for saving model checkpoints. Defaults to 'checkpoint_ensemble'.
85+
early_stopping (bool, optional): Whether to use early stopping. Defaults to True.
86+
sequence_length (int, optional): Sequence length for recurrent architectures. Defaults to 5.
87+
patience (int, optional): Number of epochs with no improvement before stopping. Defaults to 10.
88+
verbose (bool, optional): Whether to print training progress. Defaults to True.
89+
90+
Raises:
91+
Warning: If the model has already been trained.
4892
"""
93+
4994

5095
if early_stopping is None:
5196
early_stopping = X_val is not None and y_val is not None
@@ -83,7 +128,23 @@ def fit(self, X, y, nf_epochs, de_epochs, batch_size=64, X_val=None, y_val=None,
83128
def forward(self, x, smooth_projection=False):
84129
"""
85130
Performs a forward pass through the hybrid emulator.
131+
132+
Args:
133+
x (array-like): Input data.
134+
smooth_projection (bool, optional): Whether to apply smoothing to projections. Defaults to False.
135+
136+
Returns:
137+
tuple: A tuple containing:
138+
- prediction (numpy.ndarray): Model predictions.
139+
- uncertainties (dict): Dictionary with keys:
140+
- 'total' (numpy.ndarray): Total uncertainty.
141+
- 'epistemic' (numpy.ndarray): Epistemic uncertainty.
142+
- 'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
143+
144+
Raises:
145+
Warning: If the model has not been trained.
86146
"""
147+
87148
self.eval()
88149
x = to_tensor(x).to(self.device)
89150
if not self.trained:
@@ -102,6 +163,26 @@ def forward(self, x, smooth_projection=False):
102163
return prediction, uncertainties
103164

104165
def predict(self, x, output_scaler=True, smooth_projection=False):
166+
"""
167+
Makes predictions using the trained hybrid emulator.
168+
169+
Args:
170+
x (array-like): Input data.
171+
output_scaler (bool or str, optional): Path to the output scaler or whether to apply scaling. Defaults to True.
172+
smooth_projection (bool, optional): Whether to apply smoothing. Defaults to False.
173+
174+
Returns:
175+
tuple: A tuple containing:
176+
- unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
177+
- uncertainties (dict): Dictionary with keys:
178+
- 'total' (numpy.ndarray): Total uncertainty.
179+
- 'epistemic' (numpy.ndarray): Epistemic uncertainty.
180+
- 'aleatoric' (numpy.ndarray): Aleatoric uncertainty.
181+
182+
Raises:
183+
Warning: If no scaler path is provided.
184+
"""
185+
105186
self.eval()
106187
if output_scaler is True:
107188
output_scaler = os.path.join(self.model_dir, "scaler_y.pkl")
@@ -134,8 +215,19 @@ def predict(self, x, output_scaler=True, smooth_projection=False):
134215

135216
def save(self, save_dir, input_features=None, output_scaler_path=None):
136217
"""
137-
Saves the trained model to the specified directory.
218+
Saves the trained model and related components to a specified directory.
219+
220+
Args:
221+
save_dir (str): Directory where the model should be saved.
222+
input_features (list, optional): List of input feature names. Defaults to None.
223+
output_scaler_path (str, optional): Path to the output scaler. Defaults to None.
224+
225+
Raises:
226+
ValueError: If the model has not been trained.
227+
ValueError: If `save_dir` is a file instead of a directory.
228+
ValueError: If `input_features` is not a list.
138229
"""
230+
139231
if not self.trained:
140232
raise ValueError("This model has not been trained yet. Train the model before saving.")
141233
if save_dir.endswith(".pth"):
@@ -160,8 +252,20 @@ def save(self, save_dir, input_features=None, output_scaler_path=None):
160252
@staticmethod
161253
def load(model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None,):
162254
"""
163-
Loads a trained model from the specified paths.
255+
Loads a trained ISEFlow model from specified paths.
256+
257+
Args:
258+
model_dir (str, optional): Directory containing the saved model. Defaults to None.
259+
deep_ensemble_path (str, optional): Path to the saved deep ensemble model. Defaults to None.
260+
normalizing_flow_path (str, optional): Path to the saved normalizing flow model. Defaults to None.
261+
262+
Returns:
263+
ISEFlow: The loaded ISEFlow model.
264+
265+
Raises:
266+
NotImplementedError: If an unsupported version is specified.
164267
"""
268+
165269

166270
if model_dir:
167271
deep_ensemble_path = os.path.join(model_dir, "deep_ensemble.pth")
@@ -178,14 +282,46 @@ def load(model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None,):
178282

179283

180284
class ISEFlow_AIS(ISEFlow):
285+
"""
286+
ISEFlow_AIS is a specialized version of ISEFlow for the Antarctic Ice Sheet (AIS).
287+
288+
This subclass initializes the deep ensemble and normalizing flow models specifically
289+
for AIS and provides a loading method with pre-trained model paths.
290+
"""
291+
181292
def __init__(self,):
293+
"""
294+
Initializes the ISEFlow_AIS model.
295+
296+
Sets the ice sheet type to 'AIS' and initializes pre-configured deep ensemble
297+
and normalizing flow models specific to AIS.
298+
"""
299+
182300
self.ice_sheet = "AIS"
183301
deep_ensemble = ISEFlow_AIS_DE()
184302
normalizing_flow = ISEFlow_AIS_NF()
185303
super(ISEFlow_AIS, self).__init__(deep_ensemble, normalizing_flow)
186304

187305
@staticmethod
188306
def load(version="v1.0.0", model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None):
307+
"""
308+
Loads a trained ISEFlow_AIS model.
309+
310+
Args:
311+
version (str, optional): Model version. Defaults to "v1.0.0".
312+
model_dir (str, optional): Directory of the saved model. Defaults to None.
313+
deep_ensemble_path (str, optional): Path to deep ensemble model. Defaults to None.
314+
normalizing_flow_path (str, optional): Path to normalizing flow model. Defaults to None.
315+
316+
Returns:
317+
ISEFlow_AIS: The loaded model.
318+
319+
Raises:
320+
NotImplementedError: If an unsupported version is specified.
321+
"""
322+
323+
# TODO: Add support for deep ensemble and normalizing flow paths
324+
189325
if model_dir is None:
190326
if version == "v1.0.0":
191327
model_dir = ISEFlow_AIS_v1_0_0_path
@@ -230,6 +366,39 @@ def process(
230366
standard_melt_type: str=None,
231367

232368
):
369+
"""
370+
Processes input data for prediction by applying necessary transformations and encoding.
371+
372+
Args:
373+
year (np.array): Years of the input data.
374+
pr_anomaly (np.array): Precipitation anomaly data.
375+
evspsbl_anomaly (np.array): Evaporation anomaly data.
376+
mrro_anomaly (np.array): Runoff anomaly data.
377+
smb_anomaly (np.array): Surface mass balance anomaly.
378+
ts_anomaly (np.array): Surface temperature anomaly.
379+
ocean_thermal_forcing (np.array): Ocean thermal forcing.
380+
ocean_salinity (np.array): Ocean salinity.
381+
ocean_temperature (np.array): Ocean temperature.
382+
initial_year (int): Initial year for modeling.
383+
numerics (str): Numerical scheme used.
384+
stress_balance (str): Stress balance model.
385+
resolution (int): Resolution of the model.
386+
init_method (str): Initialization method.
387+
melt_in_floating_cells (str): Melt treatment method.
388+
icefront_migration (str): Ice front migration scheme.
389+
ocean_forcing_type (str): Type of ocean forcing applied.
390+
ocean_sensitivity (str): Ocean sensitivity setting.
391+
ice_shelf_fracture (bool): Whether ice shelf fracture is considered.
392+
open_melt_type (str, optional): Type of open melt model. Defaults to None.
393+
standard_melt_type (str, optional): Type of standard melt model. Defaults to None.
394+
395+
Returns:
396+
pd.DataFrame: Processed input data ready for prediction.
397+
398+
Raises:
399+
ValueError: If any input arguments are invalid.
400+
"""
401+
233402

234403
if year[0] == 2015:
235404
year = year - 2015
@@ -458,6 +627,18 @@ def predict(
458627
open_melt_type: str=None,
459628
standard_melt_type: str=None,
460629
):
630+
"""
631+
Predicts ice sheet evolution using the trained ISEFlow_AIS model.
632+
633+
Args:
634+
(Same as process method)
635+
636+
Returns:
637+
tuple: A tuple containing:
638+
- unscaled_predictions (numpy.ndarray): Model predictions in the original scale.
639+
- uncertainties (dict): Dictionary containing different uncertainty components.
640+
"""
641+
461642

462643
data = self.process(
463644
year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly, ocean_thermal_forcing, ocean_salinity, ocean_temperature, initial_year, numerics, stress_balance, resolution, init_method, melt_in_floating_cells, icefront_migration, ocean_forcing_type, ocean_sensitivity, ice_shelf_fracture, open_melt_type, standard_melt_type
@@ -467,17 +648,39 @@ def predict(
467648
return super().predict(X, output_scaler=f"{ISEFlow_AIS_v1_0_0_path}/scaler_y.pkl")
468649

469650
class ISEFlow_GrIS(ISEFlow):
651+
"""
652+
ISEFlow_GrIS is a specialized version of ISEFlow for the Greenland Ice Sheet (GrIS).
653+
654+
This subclass initializes the deep ensemble and normalizing flow models specifically
655+
for GrIS and provides a loading method with pre-trained model paths.
656+
"""
657+
470658
def __init__(self,):
659+
"""
660+
Initializes the ISEFlow_GrIS model.
661+
662+
Sets the ice sheet type to 'GrIS' and initializes pre-configured deep ensemble
663+
and normalizing flow models specific to GrIS.
664+
"""
665+
471666
self.ice_sheet = "GrIS"
472667
deep_ensemble = ISEFlow_GrIS_DE()
473668
normalizing_flow = ISEFlow_GrIS_NF()
474669
super(ISEFlow_GrIS, self).__init__(deep_ensemble, normalizing_flow)
475670

476671
@staticmethod
477672
def load(version="v1.0.0", model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None,):
673+
"""
674+
Loads a trained ISEFlow_GrIS model.
675+
676+
(Same arguments and return type as `ISEFlow_AIS.load`.)
677+
"""
678+
478679
if model_dir is None:
479680
if version == "v1.0.0":
480681
model_dir = ISEFlow_GrIS_v1_0_0_path
481682
else:
482683
raise NotImplementedError("Only version v1.0.0 is supported")
483684
return super(ISEFlow_GrIS, ISEFlow_GrIS).load(model_dir, deep_ensemble_path, normalizing_flow_path)
685+
686+
# TODO: ISEFlow GrIS process, predict

ise/models/ISEFlow/de.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,23 @@
44

55

66
class ISEFlow_AIS_DE(DeepEnsemble):
7+
"""
8+
ISEFlow Deep ensemble model for Antarctic Ice Sheet (AIS) emulation.
9+
10+
This class implements an ensemble of Long Short-Term Memory (LSTM) networks
11+
to predict ice sheet dynamics using deep learning. It extends the `DeepEnsemble`
12+
class and combines multiple LSTM models to enhance predictive performance.
13+
14+
Attributes:
15+
input_size (int): The number of input features, Defaults to 99.
16+
output_size (int): The number of output features, Defaults to 1.
17+
iseflow_ais_ensemble (list): A list of LSTM models with different architectures and loss functions.
18+
19+
Inherits from:
20+
DeepEnsemble: A base class for deep ensemble models.
21+
22+
"""
23+
724
def __init__(self, ):
825

926
self.input_size = 99
@@ -25,6 +42,22 @@ def __init__(self, ):
2542

2643

2744
class ISEFlow_GrIS_DE(DeepEnsemble):
45+
"""
46+
ISEFlow Deep ensemble model for Greenland Ice Sheet (GrIS) emulation.
47+
48+
This class constructs an ensemble of LSTM models to predict ice sheet behavior
49+
for the Greenland Ice Sheet (GrIS). It extends the `DeepEnsemble` framework
50+
and integrates multiple LSTM-based predictors to improve accuracy.
51+
52+
Attributes:
53+
input_size (int): The number of input features (90).
54+
output_size (int): The number of output features (1).
55+
iseflow_gris_ensemble (list): A list of LSTM models with varying architectures and loss functions.
56+
57+
Inherits from:
58+
DeepEnsemble: A base class for deep ensemble models.
59+
"""
60+
2861
def __init__(self,):
2962
self.input_size = 90
3063
self.output_size = 1

0 commit comments

Comments
 (0)