Skip to content

Commit cd04ddc

Browse files
committed
checkpointer logging, script fixes post-refactor
1 parent 69c7b33 commit cd04ddc

File tree

9 files changed

+136
-211
lines changed

9 files changed

+136
-211
lines changed

ise/models/density_estimators/normalizing_flow.py

Lines changed: 45 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -87,32 +87,43 @@ def fit(self, X, y, epochs=100, batch_size=64, save_checkpoints=True, checkpoint
8787
checkpointer = CheckpointSaver(self, self.optimizer, checkpoint_path, verbose)
8888
checkpointer.best_loss = best_loss
8989

90-
for epoch in range(start_epoch, epochs + 1):
91-
epoch_loss = []
92-
for i, (x, y) in enumerate(data_loader):
93-
x = x.to(self.device).view(x.shape[0], -1)
94-
y = y.to(self.device)
95-
self.optimizer.zero_grad()
96-
loss = torch.mean(-self.flow.log_prob(inputs=y, context=x))
97-
loss.backward()
98-
self.optimizer.step()
99-
epoch_loss.append(loss.item())
100-
average_epoch_loss = sum(epoch_loss) / len(epoch_loss)
101-
102-
if save_checkpoints:
103-
checkpointer(average_epoch_loss)
104-
if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
105-
if verbose:
106-
print("Early stopping")
107-
break
108-
90+
if start_epoch < epochs:
91+
for epoch in range(start_epoch, epochs + 1):
92+
epoch_loss = []
93+
for i, (x, y) in enumerate(data_loader):
94+
x = x.to(self.device).view(x.shape[0], -1)
95+
y = y.to(self.device)
96+
self.optimizer.zero_grad()
97+
loss = torch.mean(-self.flow.log_prob(inputs=y, context=x))
98+
loss.backward()
99+
self.optimizer.step()
100+
epoch_loss.append(loss.item())
101+
average_epoch_loss = sum(epoch_loss) / len(epoch_loss)
102+
103+
if save_checkpoints:
104+
checkpointer(average_epoch_loss, epoch)
105+
if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
106+
if verbose:
107+
print("Early stopping")
108+
break
109+
110+
if verbose:
111+
print(f"[epoch/total]: [{epoch}/{epochs}], loss: {average_epoch_loss}{f' -- {checkpointer.log}' if save_checkpoints else ''}")
112+
else:
109113
if verbose:
110-
print(f"[epoch/total]: [{epoch}/{epochs}], loss: {average_epoch_loss}{f' -- {checkpointer.log}' if early_stopping else ''}")
111-
114+
print(f"Training already completed ({epochs}/{epochs}).")
115+
112116
self.trained = True
113117

114-
if early_stopping:
115-
self.load_state_dict(torch.load(checkpoint_path))
118+
if save_checkpoints:
119+
checkpoint = torch.load(checkpoint_path)
120+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint.keys():
121+
self.load_state_dict(checkpoint['model_state_dict'])
122+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
123+
self.best_loss = checkpoint['best_loss']
124+
self.epochs_trained = checkpoint['epoch']
125+
else:
126+
self.load_state_dict(checkpoint)
116127
os.remove(checkpoint_path)
117128

118129
def sample(self, features, num_samples, return_type="numpy"):
@@ -150,7 +161,9 @@ def save(self, path):
150161
metadata = {
151162
"input_size": self.num_input_features,
152163
"output_size": self.num_predicted_sle,
153-
"device": self.device
164+
"device": self.device,
165+
"best_loss": self.best_loss,
166+
"epochs_trained": self.epochs_trained,
154167
}
155168
metadata_path = path + "_metadata.json"
156169

@@ -177,11 +190,14 @@ def load(path):
177190

178191
checkpoint = torch.load(path, map_location="cpu" if not torch.cuda.is_available() else None)
179192

180-
#
181-
# model.load_state_dict(checkpoint['model_state_dict'])
182-
model.load_state_dict(checkpoint)
183-
# model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
184-
# model.trained = checkpoint['trained']
193+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint.keys():
194+
model.load_state_dict(checkpoint['model_state_dict'])
195+
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
196+
model.trained = checkpoint['trained']
197+
else:
198+
model.load_state_dict(checkpoint)
199+
model.trained = True
200+
185201
model.trained = True
186202
model.to(model.device)
187203
model.eval()

ise/models/predictors/deep_ensemble.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,8 @@ def save(self, model_path):
109109
"output_size": member.output_size,
110110
"trained": member.trained,
111111
"path": os.path.join("ensemble_members", f"member_{i+1}.pth"),
112+
"best_loss": float(member.best_loss),
113+
"epochs_trained": int(member.epochs_trained),
112114
}
113115
for i, member in enumerate(self.ensemble_members)
114116
],
@@ -129,6 +131,10 @@ def save(self, model_path):
129131
member_path = os.path.join(ensemble_dir, f"member_{i+1}.pth")
130132
torch.save(member.state_dict(), member_path)
131133
print(f"Ensemble Member {i+1} saved to {member_path}")
134+
135+
print('Removing checkpoints after saving to model directory...')
136+
[os.remove(member.checkpoint_path) for member in self.ensemble_members if hasattr(member, "checkpoint_path")]
137+
132138

133139
@classmethod
134140
def load(cls, model_path):

ise/models/predictors/lstm.py

Lines changed: 48 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@ def fit(
8282
# Check if a checkpoint exists and load it
8383
start_epoch = 1
8484
best_loss = float("inf")
85+
self.checkpoint_path = checkpoint_path
8586
if os.path.exists(checkpoint_path):
8687
checkpoint = torch.load(checkpoint_path)
8788
self.load_state_dict(checkpoint['model_state_dict'])
@@ -134,47 +135,58 @@ def fit(
134135
checkpointer.best_loss = best_loss
135136

136137
# Training loop
137-
for epoch in range(start_epoch, epochs + 1):
138-
self.train()
139-
batch_losses = []
140-
for i, (x, y) in enumerate(data_loader):
141-
x = x.to(self.device)
142-
y = y.to(self.device)
143-
self.optimizer.zero_grad()
144-
y_pred = self.forward(x)
145-
loss = self.criterion(y_pred, y) # Renamed to 'loss' for clarity
146-
loss.backward()
147-
self.optimizer.step()
148-
batch_losses.append(loss.item())
149-
150-
# Print average batch loss and validation loss (if provided)
151-
if validate:
152-
val_preds = self.predict(
153-
X_val, sequence_length=sequence_length, batch_size=batch_size
154-
).to(self.device)
155-
val_loss = F.mse_loss(val_preds.squeeze(), y_val.squeeze())
156-
157-
if save_checkpoints:
158-
checkpointer(val_loss)
159-
160-
if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
161-
if verbose:
162-
print("Early stopping")
163-
break
164-
165-
if verbose:
166-
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {sum(batch_losses) / len(batch_losses)}, val mse: {val_loss:.6f} -- {getattr(checkpointer, 'log', '')}")
167-
else:
168-
average_batch_loss = sum(batch_losses) / len(batch_losses)
169-
if verbose:
170-
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {average_batch_loss}")
138+
if start_epoch < epochs:
139+
for epoch in range(start_epoch, epochs + 1):
140+
self.train()
141+
batch_losses = []
142+
for i, (x, y) in enumerate(data_loader):
143+
x = x.to(self.device)
144+
y = y.to(self.device)
145+
self.optimizer.zero_grad()
146+
y_pred = self.forward(x)
147+
loss = self.criterion(y_pred, y) # Renamed to 'loss' for clarity
148+
loss.backward()
149+
self.optimizer.step()
150+
batch_losses.append(loss.item())
151+
152+
# Print average batch loss and validation loss (if provided)
153+
if validate:
154+
val_preds = self.predict(
155+
X_val, sequence_length=sequence_length, batch_size=batch_size
156+
).to(self.device)
157+
val_loss = F.mse_loss(val_preds.squeeze(), y_val.squeeze())
158+
159+
if save_checkpoints:
160+
checkpointer(val_loss, epoch)
161+
162+
if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
163+
if verbose:
164+
print("Early stopping")
165+
break
166+
167+
if verbose:
168+
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {sum(batch_losses) / len(batch_losses)}, val mse: {val_loss:.6f} -- {getattr(checkpointer, 'log', '')}")
169+
else:
170+
average_batch_loss = sum(batch_losses) / len(batch_losses)
171+
if verbose:
172+
print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {average_batch_loss}")
173+
else:
174+
if verbose:
175+
print(f"Training already completed ({epochs}/{epochs}).")
171176

172177
self.trained = True
173178

174179
# loads best model
175180
if save_checkpoints:
176-
self.load_state_dict(torch.load(checkpoint_path))
177-
os.remove(checkpoint_path)
181+
checkpoint = torch.load(checkpoint_path)
182+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint.keys():
183+
self.load_state_dict(checkpoint['model_state_dict'])
184+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
185+
self.best_loss = checkpoint['best_loss']
186+
self.epochs_trained = checkpoint['epoch']
187+
else:
188+
self.load_state_dict(checkpoint)
189+
# os.remove(checkpoint_path)
178190

179191
def predict(self, X, sequence_length=5, batch_size=64, dataclass=EmulatorDataset):
180192
self.eval()

ise/utils/training.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,16 +8,19 @@ def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, che
88
self.optimizer = optimizer
99
self.best_loss = float('inf')
1010
self.verbose = verbose
11+
self.log = None
1112

12-
def __call__(self, loss, epoch, save_best_only=True, path=None):
13+
def __call__(self, loss, epoch, save_best_only=True,):
1314
is_better = self._determine_if_better(loss) if save_best_only else True
1415

1516
if is_better or not save_best_only: # Save if loss improves or save_best_only is False
16-
self.save_checkpoint(epoch, loss, path)
17+
self.save_checkpoint(epoch, loss, self.checkpoint_path)
1718
if self.verbose:
18-
print(f"Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving checkpoint.")
19+
self.log = f"Loss decreased ({self.best_loss:.6f} --> {loss:.6f}). Saving checkpoint to {self.checkpoint_path}."
1920
self._update_best_loss(loss)
2021
return True
22+
else:
23+
self.log = ""
2124
return False
2225

2326
def _determine_if_better(self, loss: float):
@@ -36,8 +39,8 @@ def save_checkpoint(self, epoch, loss, path: str = None):
3639
'best_loss': self.best_loss,
3740
}
3841
torch.save(checkpoint, checkpoint_path)
39-
if self.verbose:
40-
print(f"Checkpoint saved to {checkpoint_path}")
42+
# if self.verbose:
43+
# print(f"Checkpoint saved to {checkpoint_path}")
4144

4245
def load_checkpoint(self, path: str = None):
4346
checkpoint_path = path or self.checkpoint_path
@@ -57,8 +60,8 @@ def __init__(self, model, optimizer, checkpoint_path='checkpoint.pt', patience=1
5760
self.counter = 0
5861
self.early_stop = False
5962

60-
def __call__(self, loss, epoch, save_best_only=True, path=None):
61-
saved = super().__call__(loss, epoch, save_best_only, path)
63+
def __call__(self, loss, epoch, save_best_only=True,):
64+
saved = super().__call__(loss, epoch, save_best_only,)
6265
if saved:
6366
self.counter = 0 # Reset counter if the model improved
6467
else:

manuscripts/ISEFlow/scripts/get_best_nn.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -31,15 +31,20 @@ def get_best_nn(data_directory, export_directory, iterations=10, with_chars=True
3131
X_val_df, _ = f.get_X_y(pd.read_csv(f"{data_directory}/val.csv"), 'sectors', return_format='pandas', with_chars=with_chars)
3232
X_val, y_val = f.get_X_y(pd.read_csv(f"{data_directory}/val.csv"), 'sectors', return_format='numpy', with_chars=with_chars)
3333
cur_time = time.time()
34-
de = DeepEnsemble(num_predictors=num_predictors, forcing_size=X_train.shape[1], )
35-
nf = NormalizingFlow(forcing_size=X_train.shape[1])
34+
de = DeepEnsemble(num_ensemble_members=num_predictors, input_size=X_train.shape[1], )
35+
nf = NormalizingFlow(input_size=X_train.shape[1])
3636
emulator = ISEFlow(de, nf)
3737

3838
nf_epochs = 100
3939
de_epochs = 100
4040
train_time_start = time.time()
4141
print('\n\nTraining model with ', num_predictors, 'predictors,', nf_epochs, 'NF epochs, and', de_epochs, 'DE epochs')
42-
emulator.fit(X_train, y_train, X_val=X_val, y_val=y_val, early_stopping=True, patience=20, delta=1e-5, nf_epochs=nf_epochs, de_epochs=de_epochs, early_stopping_path=f"checkpoint_{ice_sheet}")
42+
emulator.fit(
43+
X_train, y_train, X_val=X_val, y_val=y_val,
44+
save_checkpoints=True, checkpoint_path=f"checkpoint_{ice_sheet}",
45+
early_stopping=True, patience=20,
46+
nf_epochs=nf_epochs, de_epochs=de_epochs,
47+
)
4348
train_time_end = time.time()
4449
total_train_time = (train_time_end - train_time_start) / 60.0
4550

@@ -70,7 +75,7 @@ def get_best_nn(data_directory, export_directory, iterations=10, with_chars=True
7075
if __name__ == '__main__':
7176
ICE_SHEET = 'GrIS'
7277
WITH_CHARS = True
73-
ITERATIONS = 10
78+
ITERATIONS = 1
7479
DATA_DIRECTORY = f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/data/ml/{ICE_SHEET}/'
7580
EXPORT_DIRECTORY = f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/models/all_variables/{"with_characteristics" if WITH_CHARS else "without_characteristics"}/{ICE_SHEET}/'
7681
get_best_nn(DATA_DIRECTORY, EXPORT_DIRECTORY, iterations=ITERATIONS, with_chars=WITH_CHARS)

manuscripts/ISEFlow/scripts/get_best_onlytemp_nn.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -82,8 +82,8 @@ def get_optimal_temponly_model(ice_sheet, out_dir, iterations=10, with_chars=Fal
8282
cur_time = time.time()
8383

8484
# Initialize the model
85-
de = DeepEnsemble(num_predictors=num_predictors, forcing_size=X_train.shape[1])
86-
nf = NormalizingFlow(forcing_size=X_train.shape[1])
85+
de = DeepEnsemble(num_ensemble_members=num_predictors, input_size=X_train.shape[1])
86+
nf = NormalizingFlow(input_size=X_train.shape[1])
8787
emulator = ISEFlow(de, nf)
8888

8989
# Randomly choose epochs for normalizing flow and deep ensemble training
@@ -94,10 +94,10 @@ def get_optimal_temponly_model(ice_sheet, out_dir, iterations=10, with_chars=Fal
9494
train_time_start = time.time()
9595
print(f"\n\nTraining model with {num_predictors} predictors, {nf_epochs} NF epochs, and {de_epochs} DE epochs")
9696
emulator.fit(
97-
X_train, y_train, X_val, y_val,
98-
early_stopping=True, patience=10, delta=1e-5,
97+
X_train, y_train, X_val=X_val, y_val=y_val,
98+
save_checkpoints=True, checkpoint_path=f"{ice_sheet}_onlysmb_checkpoint.pt",
99+
early_stopping=True, patience=10,
99100
nf_epochs=nf_epochs, de_epochs=de_epochs,
100-
early_stopping_path=f"{ice_sheet}_onlysmb_checkpoint.pt"
101101
)
102102
train_time_end = time.time()
103103
total_train_time = (train_time_end - train_time_start) / 60.0
@@ -147,9 +147,16 @@ def get_optimal_temponly_model(ice_sheet, out_dir, iterations=10, with_chars=Fal
147147
with_chars = False
148148

149149
# Call the main function to start the model training process
150+
# get_optimal_temponly_model(
151+
# ice_sheet,
152+
# f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/models/isolated_variables/SMB_only/{ice_sheet}/',
153+
# iterations=iterations,
154+
# with_chars=with_chars
155+
# )
156+
150157
get_optimal_temponly_model(
151158
ice_sheet,
152-
f'/oscar/home/pvankatw/data/pvankatw/pvankatw-bfoxkemp/ISEFlow/models/isolated_variables/SMB_only/{ice_sheet}/',
159+
f'/users/pvankatw/research/ise/delete/',
153160
iterations=iterations,
154161
with_chars=with_chars
155162
)

0 commit comments

Comments (0)