Skip to content

Commit 8cf13ce

Browse files
authored
bug: fixed compression computation
1 parent fa60e6d commit 8cf13ce

File tree

3 files changed

+164
-22
lines changed

3 files changed

+164
-22
lines changed

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def __getitem__(self, idx):
226226
def get_patch(self, brain_id, voxel):
227227
s, e = img_util.get_start_end(voxel, self.patch_shape)
228228
patch = self.imgs[brain_id][0, 0, s[0]: e[0], s[1]: e[1], s[2]: e[2]]
229-
return patch / np.percentile(patch, 99.9)
229+
return patch #/ np.percentile(patch, 99.9)
230230

231231

232232
# --- Custom Dataloaders ---
@@ -338,10 +338,10 @@ def _load_batch(self, dummy_input):
338338
noise_patches = np.zeros((self.batch_size, 1,) + self.patch_shape)
339339
clean_patches = np.zeros((self.batch_size, 1,) + self.patch_shape)
340340
for i, process in enumerate(as_completed(processes)):
341-
noise, clean = process.result()
341+
noise, clean, _ = process.result()
342342
noise_patches[i, 0, ...] = noise
343343
clean_patches[i, 0, ...] = clean
344-
return to_tensor(noise_patches), to_tensor(clean_patches)
344+
return to_tensor(noise_patches), to_tensor(clean_patches), None
345345

346346

347347
class ValidateN2VDataLoader(DataLoader):
@@ -415,11 +415,13 @@ def _load_batch(self, start_idx):
415415
# Process results
416416
noise_patches = np.zeros((self.batch_size, 1,) + self.patch_shape)
417417
clean_patches = np.zeros((self.batch_size, 1,) + self.patch_shape)
418+
mn_mxes = np.zeros((self.batch_size, 2))
418419
for i, process in enumerate(as_completed(processes)):
419-
noise, clean = process.result()
420+
noise, clean, mn_mx = process.result()
420421
noise_patches[i, 0, ...] = noise
421422
clean_patches[i, 0, ...] = clean
422-
return to_tensor(noise_patches), to_tensor(clean_patches)
423+
mn_mxes[i, :] = mn_mx
424+
return to_tensor(noise_patches), to_tensor(clean_patches), mn_mxes
423425

424426

425427
# --- Helpers ---
@@ -461,7 +463,6 @@ def init_datasets(
461463
brain_id = train_dataset.sample_brain()
462464
voxel = train_dataset.sample_voxel(brain_id)
463465
val_dataset.ingest_example(brain_id, voxel)
464-
465466
return train_dataset, val_dataset
466467

467468

src/aind_exaspim_image_compression/machine_learning/trainer.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
precision_score,
1717
recall_score,
1818
)
19+
from tifffile import imwrite
1920
from torch.optim.lr_scheduler import CosineAnnealingLR
2021
from torch.utils.tensorboard import SummaryWriter
2122

@@ -75,15 +76,18 @@ def run(self, train_dataset, val_dataset, n_upds=25):
7576
train_loss = self.train_step(train_dataloader, epoch)
7677
val_loss, val_cratio, new_best = self.validate_model(val_dataloader, epoch)
7778
if new_best:
79+
print(f"Epoch {epoch}: train_loss={train_loss}, val_loss={val_loss}, val_cratio={val_cratio} - New Best!")
80+
else:
7881
print(f"Epoch {epoch}: train_loss={train_loss}, val_loss={val_loss}, val_cratio={val_cratio}")
82+
7983

8084
# Step scheduler
8185
self.scheduler.step()
8286

8387
def train_step(self, train_dataloader, epoch):
8488
losses = list()
8589
self.model.train()
86-
for x_i, y_i in train_dataloader:
90+
for x_i, y_i, _ in train_dataloader:
8791
# Forward pass
8892
x_i, y_i = x_i.to("cuda"), y_i.to("cuda")
8993
hat_y_i = self.model(x_i)
@@ -100,37 +104,40 @@ def train_step(self, train_dataloader, epoch):
100104
return np.mean(losses)
101105

102106
def validate_model(self, val_dataloader, epoch):
103-
# Run model
104107
losses = list()
105108
cratios = list()
106109
self.model.eval()
107110
with torch.no_grad():
108-
for x_i, y_i in val_dataloader:
109-
x_i, y_i = x_i.to("cuda"), y_i.to("cuda")
110-
hat_y_i = self.model(x_i)
111-
loss = self.criterion(hat_y_i, y_i)
111+
for x, y, mn_mx in val_dataloader:
112+
# Run model
113+
x, y = x.to("cuda"), y.to("cuda")
114+
hat_y = self.model(x)
115+
loss = self.criterion(hat_y, y)
112116

113-
# Store loss for tensorboard
114-
cratios.extend(self.compute_cratios(hat_y_i))
115-
losses.append(loss.detach().cpu())
117+
# Evalute result
118+
cratios.extend(self.compute_cratios(hat_y, mn_mx))
119+
losses.append(loss.detach().cpu())
116120

117121
# Log results
118-
loss, cratio = np.median(losses), np.median(cratios)
122+
loss, cratio = np.mean(losses), np.mean(cratios)
119123
self.writer.add_scalar("val_loss", loss, epoch)
120-
self.writer.add_scalar("val_cratio", cratio, epoch) if cratio < 1000 else None
124+
self.writer.add_scalar("val_cratio", cratio, epoch)
121125
if loss < self.best_l1:
122126
self.save_model(epoch)
123127
self.best_l1 = loss
124128
return loss, cratio, True
125129
else:
126130
return loss, cratio, False
127131

128-
def compute_cratios(self, denoised_patches):
132+
def compute_cratios(self, imgs, mn_mx):
129133
cratios = list()
130-
denoised_patches = 1000 * np.array(denoised_patches.detach().cpu())
131-
for i in range(denoised_patches.shape[0]):
132-
patch = denoised_patches[i, 0, ...].astype(np.uint16)
133-
cratios.append(img_util.compute_cratio(patch, self.codec))
134+
imgs = np.array(imgs.detach().cpu())
135+
for i in range(imgs.shape[0]):
136+
mn, mx = tuple(mn_mx[i, :])
137+
img = (imgs[i, 0, ...] * mx + mn).astype(np.uint16)
138+
cratios.append(
139+
img_util.compute_cratio(img, self.codec)
140+
)
134141
return cratios
135142

136143
def save_model(self, epoch):
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Adapted from https://github.com/milesial/Pytorch-UNet/tree/master/unet"""
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
6+
7+
class UNet(nn.Module):
8+
def __init__(self, width_multiplier=1, trilinear=True, use_ds_conv=False):
9+
# Call parent class
10+
super(UNet, self).__init__()
11+
12+
# Initializations
13+
_channels = (32, 64, 128, 256, 512)
14+
conv_type = DepthwiseSeparableConv3d if use_ds_conv else nn.Conv3d
15+
factor = 2 if trilinear else 1
16+
17+
# Instance attributes
18+
self.channels = [int(c * width_multiplier) for c in _channels]
19+
self.trilinear = trilinear
20+
21+
# Contracting layers
22+
self.inc = DoubleConv(1, self.channels[0], conv_type=conv_type)
23+
self.down1 = Down(self.channels[0], self.channels[1], conv_type=conv_type)
24+
self.down2 = Down(self.channels[1], self.channels[2], conv_type=conv_type)
25+
self.down3 = Down(self.channels[2], self.channels[3], conv_type=conv_type)
26+
self.down4 = Down(self.channels[3], self.channels[4] // factor, conv_type=conv_type)
27+
28+
# Expanding layers
29+
self.up1 = Up(self.channels[4], self.channels[3] // factor, trilinear)
30+
self.up2 = Up(self.channels[3], self.channels[2] // factor, trilinear)
31+
self.up3 = Up(self.channels[2], self.channels[1] // factor, trilinear)
32+
self.up4 = Up(self.channels[1], self.channels[0], trilinear)
33+
self.outc = OutConv(self.channels[0], 1)
34+
35+
def forward(self, x):
36+
# Contracting layers
37+
x1 = self.inc(x)
38+
x2 = self.down1(x1)
39+
x3 = self.down2(x2)
40+
x4 = self.down3(x3)
41+
x5 = self.down4(x4)
42+
43+
# Expanding layers
44+
x = self.up1(x5, x4)
45+
x = self.up2(x, x3)
46+
x = self.up3(x, x2)
47+
x = self.up4(x, x1)
48+
logits = self.outc(x)
49+
return logits
50+
51+
52+
class DoubleConv(nn.Module):
53+
"""(convolution => [BN] => ReLU) * 2"""
54+
55+
def __init__(self, in_channels, out_channels, conv_type=nn.Conv3d, mid_channels=None):
56+
super().__init__()
57+
if not mid_channels:
58+
mid_channels = out_channels
59+
self.double_conv = nn.Sequential(
60+
conv_type(in_channels, mid_channels, kernel_size=3, padding=1),
61+
nn.BatchNorm3d(mid_channels),
62+
nn.ReLU(inplace=True),
63+
conv_type(mid_channels, out_channels, kernel_size=3, padding=1),
64+
nn.BatchNorm3d(out_channels),
65+
nn.ReLU(inplace=True)
66+
)
67+
68+
def forward(self, x):
69+
return self.double_conv(x)
70+
71+
72+
class Down(nn.Module):
73+
"""Downscaling with maxpool then double conv"""
74+
75+
def __init__(self, in_channels, out_channels, conv_type=nn.Conv3d):
76+
super().__init__()
77+
self.maxpool_conv = nn.Sequential(
78+
nn.MaxPool3d(2),
79+
DoubleConv(in_channels, out_channels, conv_type=conv_type)
80+
)
81+
82+
def forward(self, x):
83+
return self.maxpool_conv(x)
84+
85+
86+
class Up(nn.Module):
87+
"""Upscaling then double conv"""
88+
89+
def __init__(self, in_channels, out_channels, trilinear=True):
90+
super().__init__()
91+
92+
# if trilinear, use the normal convolutions to reduce the number of channels
93+
if trilinear:
94+
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
95+
self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2)
96+
else:
97+
self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
98+
self.conv = DoubleConv(in_channels, out_channels)
99+
100+
101+
def forward(self, x1, x2):
102+
x1 = self.up(x1)
103+
# input is CHW
104+
diffY = x2.size()[2] - x1.size()[2]
105+
diffX = x2.size()[3] - x1.size()[3]
106+
107+
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
108+
diffY // 2, diffY - diffY // 2])
109+
# if you have padding issues, see
110+
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
111+
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
112+
x = torch.cat([x2, x1], dim=1)
113+
return self.conv(x)
114+
115+
116+
class OutConv(nn.Module):
117+
def __init__(self, in_channels, out_channels):
118+
super(OutConv, self).__init__()
119+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
120+
121+
def forward(self, x):
122+
return self.conv(x)
123+
124+
125+
class DepthwiseSeparableConv3d(nn.Module):
126+
def __init__(self, nin, nout, kernel_size, padding, kernels_per_layer=1):
127+
super(DepthwiseSeparableConv3d, self).__init__()
128+
self.depthwise = nn.Conv3d(nin, nin * kernels_per_layer, kernel_size=kernel_size, padding=padding, groups=nin)
129+
self.pointwise = nn.Conv3d(nin * kernels_per_layer, nout, kernel_size=1)
130+
131+
def forward(self, x):
132+
out = self.depthwise(x)
133+
out = self.pointwise(out)
134+
return out

0 commit comments

Comments
 (0)