|
17 | 17 | from numcodecs import blosc |
18 | 18 | from tqdm import tqdm |
19 | 19 |
|
| 20 | +import itertools |
20 | 21 | import numpy as np |
21 | 22 | import torch |
22 | 23 |
|
@@ -63,26 +64,27 @@ def predict( |
63 | 64 | img = img[np.newaxis, ...] |
64 | 65 |
|
65 | 66 | # Initializations |
66 | | - starts = generate_patch_starts(img, patch_size, overlap) |
| 67 | + starts_generator = generate_patch_starts(img, patch_size, overlap) |
| 68 | + n_starts = count_patches(img, patch_size, overlap) |
67 | 69 | if denoised is None: |
68 | 70 | denoised = np.zeros_like(img, dtype=np.uint16) |
69 | 71 |
|
70 | 72 | # Main |
71 | | - pbar = tqdm(total=len(starts), desc="Denoise") if verbose else None |
72 | | - for i in range(0, len(starts), batch_size): |
| 73 | + pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None |
| 74 | + for i in range(0, n_starts, batch_size): |
73 | 75 | # Run model |
74 | | - starts_i = starts[i:min(i + batch_size, len(starts))] |
75 | | - patches_i = _predict_batch(img, model, starts_i, patch_size, trim) |
| 76 | + starts = list(itertools.islice(starts_generator, batch_size)) |
| 77 | + patches = _predict_batch(img, model, starts, patch_size, trim) |
76 | 78 |
|
77 | 79 | # Store result |
78 | | - for patch, start in zip(patches_i, starts_i): |
| 80 | + for patch, start in zip(patches, starts): |
79 | 81 | start = [max(s + trim, 0) for s in start] |
80 | 82 | end = [start[i] + patch.shape[i] for i in range(3)] |
81 | 83 | end = [min(e, s) for e, s in zip(end, img.shape[2:])] |
82 | 84 | denoised[ |
83 | 85 | 0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2] |
84 | 86 | ] = patch[: end[0] - start[0], : end[1] - start[1], : end[2] - start[2]] |
85 | | - pbar.update(len(starts_i)) if verbose else None |
| 87 | + pbar.update(len(starts)) if verbose else None |
86 | 88 | return denoised |
87 | 89 |
|
88 | 90 |
|
@@ -220,8 +222,33 @@ def generate_patch_starts(img, patch_size, overlap): |
220 | 222 | for i in range(0, img.shape[2] - patch_size + stride, stride): |
221 | 223 | for j in range(0, img.shape[3] - patch_size + stride, stride): |
222 | 224 | for k in range(0, img.shape[4] - patch_size + stride, stride): |
223 | | - coords.append((i, j, k)) |
224 | | - return coords |
| 225 | + yield (i, j, k) |
| 226 | + |
| 227 | + |
def count_patches(img, patch_size, overlap):
    """
    Computes how many cubic patches tile a 3D image for a given patch
    size and inter-patch overlap.

    Parameters
    ----------
    img : torch.Tensor or numpy.ndarray
        Input image tensor with shape (batch, channels, depth, height, width).
    patch_size : int
        The size of each cubic patch along each spatial dimension.
    overlap : int
        Number of voxels that adjacent patches overlap.

    Returns
    -------
    int
        Number of patches.
    """
    step = patch_size - overlap
    total = 1
    # Multiply the patch counts along depth, height, and width. len(range(...))
    # is zero when a dimension is smaller than patch_size, zeroing the product.
    for extent in img.shape[2:5]:
        total *= len(range(0, extent - patch_size + step, step))
    return total
225 | 252 |
|
226 | 253 |
|
227 | 254 | def load_model(path, device="cuda"): |
|
0 commit comments