1515
1616import numpy as np
1717import os
18+ import tifffile
1819import torch
1920import torch .nn as nn
2021import torch .optim as optim
@@ -30,6 +31,7 @@ def __init__(
3031 self ,
3132 output_dir ,
3233 batch_size = 8 ,
34+ device = "cuda:0" ,
3335 lr = 1e-3 ,
3436 max_epochs = 200 ,
3537 ):
@@ -42,10 +44,12 @@ def __init__(
4244 Directory that model checkpoints and tensorboard are written to.
4345 batch_size : int, optional
4446 Number of samples per batch during training. Default is 32.
45- lr : float
46- Learning rate.
47- max_epochs : int
48- Maximum number of training epochs.
47+ device : str, optional
48+ GPU device that model is trained on. Default is "cuda:0".
49+ lr : float, optional
50+ Learning rate. Default is 1e-3.
51+ max_epochs : int, optional
52+ Maximum number of training epochs. Default is 200.
4953 """
5054 # Initializations
5155 exp_name = "session-" + datetime .today ().strftime ("%Y%m%d_%H%M" )
@@ -133,6 +137,24 @@ def train_step(self, train_dataloader, epoch):
133137 return np .mean (losses )
134138
135139 def validate_step (self , val_dataloader , epoch ):
140+ """
141+ Validates the model over the provided DataLoader.
142+
143+ Parameters
144+ ----------
145+ val_dataloader : torch.utils.data.DataLoader
146+ DataLoader for the validation dataset.
147+ epoch : int
148+ Current training epoch.
149+
150+ Returns
151+ -------
152+ tuple
153+ A tuple containing the following:
154+ - float: Average loss over the validation dataset.
155+ - float: Average compression ratio over the validation dataset.
156+ - bool: Indication of whether the model is the best so far.
157+ """
136158 losses = list ()
137159 cratios = list ()
138160 with torch .no_grad ():
@@ -190,8 +212,23 @@ def compute_cratios(self, imgs, mn_mx):
190212 mn , mx = tuple (mn_mx [i , :])
191213 img = imgs [i , 0 , ...] * (mx - mn ) + mn
192214 cratios .append (img_util .compute_cratio (img , self .codec ))
215+ if i < 10 :
216+ tifffile .imwrite (f"{ i } .tiff" , img )
193217 return cratios
194218
219+ def load_pretrained_weights (self , model_path ):
220+ """
221+ Loads a pretrained model weights from a checkpoint file.
222+
223+ Parameters
224+ ----------
225+ model_path : str
226+ Path to the checkpoint file containing the saved weights.
227+ """
228+ self .model .load_state_dict (
229+ torch .load (model_path , map_location = device )
230+ )
231+
195232 def save_model (self , epoch ):
196233 """
197234 Saves the current model state to a file.
@@ -202,6 +239,6 @@ def save_model(self, epoch):
202239 Current training epoch.
203240 """
204241 date = datetime .today ().strftime ("%Y%m%d" )
205- filename = f"BM4DNet-{ date } -{ epoch } -{ round ( self .best_l1 , 4 ) } .pth"
242+ filename = f"BM4DNet-{ date } -{ epoch } -{ self .best_l1 :.4f } .pth"
206243 path = os .path .join (self .log_dir , filename )
207244 torch .save (self .model .state_dict (), path )
0 commit comments