Skip to content

Commit a1a0cf5

Browse files
authored
Merge pull request #22 from AllenNeuralDynamics/refactor-simplify-inference
refactor: simplified inference
2 parents 82f425b + fb0cb8e commit a1a0cf5

File tree

4 files changed

+86
-118
lines changed

4 files changed

+86
-118
lines changed

src/aind_exaspim_image_compression/inference.py

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,11 @@
1818
import torch
1919

2020
from aind_exaspim_image_compression.machine_learning.unet3d import UNet
21-
from aind_exaspim_image_compression.utils import img_util
2221

2322

2423
def predict(
2524
img,
2625
model,
27-
denoised=None,
2826
batch_size=32,
2927
normalization_percentiles=(0.5, 99.9),
3028
patch_size=64,
@@ -64,59 +62,49 @@ def predict(
6462
"""
6563
# Preprocess image
6664
mn, mx = np.percentile(img, normalization_percentiles)
67-
img = (img - mn) / (mx - mn + 1e-5)
65+
img = (img - mn) / (mx - mn + 1e-8)
6866
img = np.clip(img, 0, 5)
6967
while len(img.shape) < 5:
7068
img = img[np.newaxis, ...]
7169

7270
# Initializations
7371
patch_starts_generator = generate_patch_starts(img, patch_size, overlap)
7472
n_starts = count_patches(img, patch_size, overlap)
75-
if denoised is None:
76-
denoised = np.zeros_like(img)
73+
pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None
7774

7875
# Main
79-
pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None
80-
for i in range(0, n_starts, batch_size):
76+
accum_pred = np.zeros(img.shape[2:])
77+
accum_wgt = np.zeros(img.shape[2:])
78+
for _ in range(0, n_starts, batch_size):
8179
# Extract batch and run model
8280
starts = list(itertools.islice(patch_starts_generator, batch_size))
8381
patches = _predict_batch(img, model, starts, patch_size, trim=trim)
8482

85-
# Store result
83+
# Add batch predictions to result
8684
for patch, start in zip(patches, starts):
87-
start = [max(s + trim, 0) for s in start]
88-
end = [start[i] + patch.shape[i] for i in range(3)]
89-
end = [min(e, s) for e, s in zip(end, img.shape[2:])]
90-
denoised[
91-
0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]
92-
] = patch[: end[0] - start[0], : end[1] - start[1], : end[2] - start[2]]
85+
# Compute start and end coordinates
86+
s = [max(si + trim, 0) for si in start]
87+
e = [
88+
min(si + pi, di)
89+
for si, pi, di in zip(s, patch.shape, img.shape[2:])
90+
]
91+
92+
# Create slices
93+
pred_slices = tuple(slice(si, ei) for si, ei in zip(s, e))
94+
patch_slices = tuple(slice(0, ei - si) for si, ei in zip(s, e))
95+
96+
# Add patch prediction to result
97+
accum_pred[pred_slices] += patch[patch_slices]
98+
accum_wgt[pred_slices] += 1
99+
93100
pbar.update(len(starts)) if verbose else None
94101

95-
# Postprocess image
102+
# Postprocess prediction
103+
denoised = accum_pred[:, ...] / (accum_wgt + 1e-8)
96104
denoised = np.clip(denoised * (mx - mn) + mn, 0, 2**16 - 1)
97105
return denoised.astype(np.uint16)
98106

99107

100-
def predict_largescale(
101-
img,
102-
model,
103-
output_path,
104-
compressor,
105-
batch_size=32,
106-
normalization_percentiles=(0.5, 99.9),
107-
patch_size=64,
108-
overlap=12,
109-
output_chunks=(1, 1, 64, 128, 128),
110-
trim=5,
111-
verbose=True
112-
):
113-
# Initializations
114-
denoised = img_util.init_ome_zarr(
115-
img, output_path, compressor=compressor, chunks=output_chunks
116-
)
117-
predict(img, model, denoised=denoised)
118-
119-
120108
def predict_patch(patch, model, normalization_percentiles=(0.5, 99.9)):
121109
"""
122110
Denoises a single 3D patch using the provided model.
@@ -133,15 +121,15 @@ def predict_patch(patch, model, normalization_percentiles=(0.5, 99.9)):
133121
134122
Returns
135123
-------
136-
numpy.ndarray
124+
pred : numpy.ndarray
137125
Denoised 3D patch with the same shape as input patch.
138126
"""
139127
# Preprocess image
140128
mn, mx = np.percentile(patch, normalization_percentiles)
141-
patch = (patch - mn) / (mx - mn + 1e-5)
129+
patch = (patch - mn) / (mx - mn + 1e-8)
142130
patch = np.clip(patch, 0, 5)
143-
while len(img.shape) < 5:
144-
img = img[np.newaxis, ...]
131+
while len(patch.shape) < 5:
132+
patch = patch[np.newaxis, ...]
145133

146134
# Run model
147135
patch = to_tensor(patch)
@@ -269,7 +257,7 @@ def load_model(path, device="cuda"):
269257
270258
Returns
271259
-------
272-
torch.nn.Module
260+
model : torch.nn.Module
273261
UNet model loaded with weights and set to evaluation mode.
274262
"""
275263
model = UNet()

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ def __init__(
4444
patch_shape,
4545
anisotropy=(0.748, 0.748, 1.0),
4646
boundary_buffer=5000,
47-
foreground_sampling_rate=0.2,
47+
foreground_sampling_rate=0.3,
4848
min_brightness=200,
4949
n_examples_per_epoch=300,
5050
normalization_percentiles=(0.5, 99.9),
51-
prefetch_foreground_sampling=12,
52-
sigma_bm4d=10,
51+
prefetch_foreground_sampling=16,
52+
sigma_bm4d=16,
5353
):
5454
# Call parent class
5555
super(TrainDataset, self).__init__()
@@ -290,9 +290,9 @@ def sample_segmentation_voxel(self, brain_id):
290290
291291
Returns
292292
-------
293-
Tuple[int]
293+
best_voxel : Tuple[int]
294294
Voxel coordinate whose patch contains a sufficiently large object
295-
or had the largest object after 32 attempts.
295+
or has the largest object after 5 * self.prefetch attempts.
296296
"""
297297
cnt = 0
298298
best_voxel = self.sample_interior_voxel(brain_id)
@@ -330,8 +330,7 @@ def sample_segmentation_voxel(self, brain_id):
330330

331331
def sample_bright_voxel(self, brain_id):
332332
"""
333-
Samples a voxel coordinate whose surrounding image patch is
334-
sufficiently bright.
333+
Samples a voxel coordinate whose image patch is sufficiently bright.
335334
336335
Parameters
337336
----------
@@ -340,9 +339,9 @@ def sample_bright_voxel(self, brain_id):
340339
341340
Returns
342341
-------
343-
Tuple[int]
342+
brightest_voxel : Tuple[int]
344343
Voxel coordinate whose patch is sufficiently bright or has the
345-
highest observed brightness after 32 attempts.
344+
highest observed brightness after 5 * self.prefetch attempts.
346345
"""
347346
cnt = 0
348347
brightest_voxel = self.sample_interior_voxel(brain_id)
@@ -544,12 +543,13 @@ def __getitem__(self, idx):
544543
545544
Returns
546545
-------
547-
tuple
548-
A tuple containing:
549-
- noise (ndarray): Noisy image patch at the given index.
550-
- denoised (ndarray): Corresponding denoised image patch.
551-
- mn_mx (tuple): Minimum and maximum values used for normalization
552-
of the image patches.
546+
noise : numpy.ndarray
547+
Noisy image patch at the given index.
548+
denoised : numpy.ndarray
549+
Corresponding denoised image patch.
550+
mn_mx : Tuple[int]
551+
Minimum and maximum values used for normalization of the image
552+
patches.
553553
"""
554554
return self.noise[idx], self.denoised[idx], self.mn_mxs[idx]
555555

@@ -563,6 +563,15 @@ class DataLoader:
563563
"""
564564
DataLoader that uses multithreading to fetch image patches from the cloud
565565
to form batches.
566+
567+
Attributes
568+
----------
569+
dataset : torch.utils.data.Dataset
570+
Dataset to iterate over.
571+
batch_size : int
572+
Number of examples in each batch.
573+
patch_shape : Tuple[int]
574+
Shape of image patch expected by the model.
566575
"""
567576

568577
def __init__(self, dataset, batch_size=16):
@@ -629,7 +638,7 @@ def init_datasets(
629638
n_train_examples_per_epoch=100,
630639
n_validate_examples=0,
631640
segmentation_prefixes_path=None,
632-
sigma_bm4d=10,
641+
sigma_bm4d=16,
633642
swc_pointers=None
634643
):
635644
# Initializations

src/aind_exaspim_image_compression/machine_learning/train.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
self.batch_size = batch_size
6868
self.device = device
6969
self.max_epochs = max_epochs
70-
self.log_dir = log_dir
70+
self.log_dir = log_dir
7171

7272
self.codec = blosc.Blosc(cname="zstd", clevel=5, shuffle=blosc.SHUFFLE)
7373
self.criterion = nn.L1Loss()
@@ -162,11 +162,12 @@ def validate_step(self, val_dataloader, epoch):
162162
163163
Returns
164164
-------
165-
tuple
166-
A tuple containing the following:
167-
- float: Average loss over the validation dataset.
168-
- float: Average compression ratio over the validation dataset.
169-
- bool: Indication of whether the model is the best so far.
165+
loss : float
166+
Average loss over the validation dataset.
167+
cratio : float
168+
Average compression ratio over the validation dataset.
169+
is_best : bool
170+
Indication of whether the model is the best so far.
170171
"""
171172
losses = list()
172173
cratios = list()
@@ -186,12 +187,11 @@ def validate_step(self, val_dataloader, epoch):
186187
self.writer.add_scalar("val_cratio", cratio, epoch)
187188

188189
# Check if current model is best so far
189-
if loss < self.best_l1:
190+
is_best = True if loss < self.best_l1 else False
191+
if is_best:
190192
self.best_l1 = loss
191193
self.save_model(epoch)
192-
return loss, cratio, True
193-
else:
194-
return loss, cratio, False
194+
return loss, cratio, is_best
195195

196196
def forward_pass(self, x, y):
197197
"""
@@ -224,7 +224,7 @@ def compute_cratios(self, imgs, mn_mx):
224224
imgs = np.array(imgs.detach().cpu())
225225
for i in range(imgs.shape[0]):
226226
mn, mx = tuple(mn_mx[i, :])
227-
img = imgs[i, 0, ...] * (mx - mn) + mn
227+
img = np.clip(imgs[i, 0, ...] * (mx - mn) + mn, 0, 2**16 - 1)
228228
cratios.append(img_util.compute_cratio(img, self.codec))
229229
if i < 10:
230230
tifffile.imwrite(f"{i}.tiff", img)

0 commit comments

Comments
 (0)