Skip to content

Commit 7a1a0ca

Browse files
authored
minor updates
1 parent f052aa9 commit 7a1a0ca

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

src/aind_exaspim_image_compression/inference.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from numcodecs import blosc
1818
from tqdm import tqdm
1919

20+
import itertools
2021
import numpy as np
2122
import torch
2223

@@ -63,26 +64,27 @@ def predict(
6364
img = img[np.newaxis, ...]
6465

6566
# Initializations
66-
starts = generate_patch_starts(img, patch_size, overlap)
67+
starts_generator = generate_patch_starts(img, patch_size, overlap)
68+
n_starts = count_patches(img, patch_size, overlap)
6769
if denoised is None:
6870
denoised = np.zeros_like(img, dtype=np.uint16)
6971

7072
# Main
71-
pbar = tqdm(total=len(starts), desc="Denoise") if verbose else None
72-
for i in range(0, len(starts), batch_size):
73+
pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None
74+
for i in range(0, n_starts, batch_size):
7375
# Run model
74-
starts_i = starts[i:min(i + batch_size, len(starts))]
75-
patches_i = _predict_batch(img, model, starts_i, patch_size, trim)
76+
starts = list(itertools.islice(starts_generator, batch_size))
77+
patches = _predict_batch(img, model, starts, patch_size, trim)
7678

7779
# Store result
78-
for patch, start in zip(patches_i, starts_i):
80+
for patch, start in zip(patches, starts):
7981
start = [max(s + trim, 0) for s in start]
8082
end = [start[i] + patch.shape[i] for i in range(3)]
8183
end = [min(e, s) for e, s in zip(end, img.shape[2:])]
8284
denoised[
8385
0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]
8486
] = patch[: end[0] - start[0], : end[1] - start[1], : end[2] - start[2]]
85-
pbar.update(len(starts_i)) if verbose else None
87+
pbar.update(len(starts)) if verbose else None
8688
return denoised
8789

8890

@@ -220,8 +222,33 @@ def generate_patch_starts(img, patch_size, overlap):
220222
for i in range(0, img.shape[2] - patch_size + stride, stride):
221223
for j in range(0, img.shape[3] - patch_size + stride, stride):
222224
for k in range(0, img.shape[4] - patch_size + stride, stride):
223-
coords.append((i, j, k))
224-
return coords
225+
yield (i, j, k)
226+
227+
228+
def count_patches(img, patch_size, overlap):
229+
"""
230+
Counts the number of patches within a 3D image for a given patch size
231+
and overlap between patches.
232+
233+
Parameters
234+
----------
235+
img : torch.Tensor or numpy.ndarray
236+
Input image tensor with shape (batch, channels, depth, height, width).
237+
patch_size : int
238+
The size of each cubic patch along each spatial dimension.
239+
overlap : int
240+
Number of voxels that adjacent patches overlap.
241+
242+
Returns
243+
-------
244+
int
245+
Number of patches.
246+
"""
247+
stride = patch_size - overlap
248+
d_range = range(0, img.shape[2] - patch_size + stride, stride)
249+
h_range = range(0, img.shape[3] - patch_size + stride, stride)
250+
w_range = range(0, img.shape[4] - patch_size + stride, stride)
251+
return len(d_range) * len(h_range) * len(w_range)
225252

226253

227254
def load_model(path, device="cuda"):

0 commit comments

Comments
 (0)