Skip to content

Commit 5d91a0a

Browse files
authored
Merge pull request #14 from AllenNeuralDynamics/feat-load-pretrained
feat: load pretrained weights before training
2 parents c2a21bf + 2adcd2d commit 5d91a0a

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def load_swcs(self, swc_pointer):
146146
# --- Sample Image Patches ---
147147
def __getitem__(self, dummy_input):
148148
"""
149-
Return a pair of noisy and BM4D-denoised image patches, normalized
149+
Return a pair of noisy and BM4D-denoised image patches, normalized
150150
according to percentile-based scaling.
151151
152152
Parameters

src/aind_exaspim_image_compression/machine_learning/train.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717
import os
18+
import tifffile
1819
import torch
1920
import torch.nn as nn
2021
import torch.optim as optim
@@ -30,6 +31,7 @@ def __init__(
3031
self,
3132
output_dir,
3233
batch_size=8,
34+
device="cuda:0",
3335
lr=1e-3,
3436
max_epochs=200,
3537
):
@@ -42,10 +44,12 @@ def __init__(
4244
Directory that model checkpoints and tensorboard logs are written to.
4345
batch_size : int, optional
4446
Number of samples per batch during training. Default is 8.
45-
lr : float
46-
Learning rate.
47-
max_epochs : int
48-
Maximum number of training epochs.
47+
device : str, optional
48+
Device (e.g., a GPU) that the model is trained on. Default is "cuda:0".
49+
lr : float, optional
50+
Learning rate. Default is 1e-3.
51+
max_epochs : int, optional
52+
Maximum number of training epochs. Default is 200.
4953
"""
5054
# Initializations
5155
exp_name = "session-" + datetime.today().strftime("%Y%m%d_%H%M")
@@ -133,6 +137,24 @@ def train_step(self, train_dataloader, epoch):
133137
return np.mean(losses)
134138

135139
def validate_step(self, val_dataloader, epoch):
140+
"""
141+
Validates the model over the provided DataLoader.
142+
143+
Parameters
144+
----------
145+
val_dataloader : torch.utils.data.DataLoader
146+
DataLoader for the validation dataset.
147+
epoch : int
148+
Current training epoch.
149+
150+
Returns
151+
-------
152+
tuple
153+
A tuple containing the following:
154+
- float: Average loss over the validation dataset.
155+
- float: Average compression ratio over the validation dataset.
156+
- bool: Indication of whether the model is the best so far.
157+
"""
136158
losses = list()
137159
cratios = list()
138160
with torch.no_grad():
@@ -190,8 +212,23 @@ def compute_cratios(self, imgs, mn_mx):
190212
mn, mx = tuple(mn_mx[i, :])
191213
img = imgs[i, 0, ...] * (mx - mn) + mn
192214
cratios.append(img_util.compute_cratio(img, self.codec))
215+
if i < 10:
216+
tifffile.imwrite(f"{i}.tiff", img)
193217
return cratios
194218

219+
def load_pretrained_weights(self, model_path):
    """
    Loads pretrained model weights from a checkpoint file.

    The checkpoint tensors are mapped onto the device that the model's
    parameters currently live on, so a checkpoint saved on one device can
    be restored on another.

    Parameters
    ----------
    model_path : str
        Path to the checkpoint file containing the saved weights.
    """
    # The original passed a bare name "device" as map_location, but no such
    # name is defined in this method (it is only an __init__ parameter), so
    # the call would raise NameError. Derive the device from the model
    # itself instead; fall back to CPU for a parameterless model.
    first_param = next(self.model.parameters(), None)
    target = first_param.device if first_param is not None else "cpu"
    self.model.load_state_dict(torch.load(model_path, map_location=target))
231+
195232
def save_model(self, epoch):
196233
"""
197234
Saves the current model state to a file.
@@ -202,6 +239,6 @@ def save_model(self, epoch):
202239
Current training epoch.
203240
"""
204241
date = datetime.today().strftime("%Y%m%d")
205-
filename = f"BM4DNet-{date}-{epoch}-{round(self.best_l1, 4)}.pth"
242+
filename = f"BM4DNet-{date}-{epoch}-{self.best_l1:.4f}.pth"
206243
path = os.path.join(self.log_dir, filename)
207244
torch.save(self.model.state_dict(), path)

src/aind_exaspim_image_compression/utils/img_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _read_zarr(img_path):
8080
fs = gcsfs.GCSFileSystem(anon=False)
8181
store = zarr.storage.FSStore(img_path, fs=fs)
8282
elif _is_s3_path(img_path):
83-
fs = s3fs.S3FileSystem(config_kwargs={"max_pool_connections": 50})
83+
fs = s3fs.S3FileSystem(anon=True)
8484
store = s3fs.S3Map(root=img_path, s3=fs)
8585
else:
8686
store = zarr.DirectoryStore(img_path)

0 commit comments

Comments
 (0)