Skip to content

Commit 7912e4f

Browse files
authored
refactor: optimized inference
1 parent 9ff78ca commit 7912e4f

File tree

2 files changed

+73
-105
lines changed

2 files changed

+73
-105
lines changed

src/aind_exaspim_image_compression/evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def compute_metrics(
164164
noise = input_noise[5:-5, 5:-5, 5:-5]
165165
denoised_gt = np.maximum(bm4d(noise, 10), 0).astype(int)
166166
denoised = predict_patch(input_noise, self.model)[5:-5, 5:-5, 5:-5]
167-
167+
168168
# Compute metrics
169169
metrics["cratio"].append(compute_cratio(denoised, self.codec))
170170
metrics["cratio_noise"].append(compute_cratio(noise, self.codec))

src/aind_exaspim_image_compression/inference.py

Lines changed: 72 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,18 @@
1010
into a full 3D volume.
1111
1212
"""
13-
13+
from concurrent.futures import (
14+
ThreadPoolExecutor,
15+
as_completed,
16+
)
1417
from tqdm import tqdm
1518

1619
import numpy as np
1720
import torch
1821

1922

2023
def predict(
21-
img, model, batch_size=32, patch_size=64, overlap=16, verbose=True
24+
img, model, batch_size=32, patch_size=64, overlap=16, trim=5, verbose=True
2225
):
2326
"""
2427
Denoises a 3D image by processing patches in batches and running deep
@@ -46,51 +49,76 @@ def predict(
4649
preds : List[numpy.ndarray]
4750
List of predicted patches (3D arrays) matching the patch size.
4851
"""
49-
# Initializations
50-
batch_coords, batch_inputs, mn_mx = list(), list(), list()
52+
# Adjust image dimenions
5153
while len(img.shape) < 5:
5254
img = img[np.newaxis, ...]
53-
coords = generate_coords(img, patch_size, overlap)
55+
56+
# Initializations
57+
starts = generate_patch_starts(img, patch_size, overlap)
58+
denoised = np.zeros_like(img, dtype=np.uint16)
5459

5560
# Main
56-
pbar = tqdm(total=len(coords), desc="Denoise") if verbose else None
57-
preds = list()
58-
for idx, (i, j, k) in enumerate(coords):
59-
# Get end coord
60-
i_end = min(i + patch_size, img.shape[2])
61-
j_end = min(j + patch_size, img.shape[3])
62-
k_end = min(k + patch_size, img.shape[4])
63-
64-
# Get patch
65-
patch = img[0, 0, i:i_end, j:j_end, k:k_end]
61+
pbar = tqdm(total=len(starts), desc="Denoise") if verbose else None
62+
for i in range(0, len(starts), batch_size):
63+
# Run model
64+
starts_i = starts[i:min(i + batch_size, len(starts))]
65+
patches_i = _predict_batch(img, model, starts_i, patch_size, trim)
66+
67+
# Store result
68+
for patch, start in zip(patches_i, starts_i):
69+
start = [max(s + trim, 0) for s in start]
70+
end = [start[i] + patch.shape[i] for i in range(3)]
71+
end = [min(e, s) for e, s in zip(end, img.shape[2:])]
72+
denoised[
73+
0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]
74+
] = patch[: end[0] - start[0], : end[1] - start[1], : end[2] - start[2]]
75+
pbar.update(len(starts_i)) if verbose else None
76+
return denoised
77+
78+
79+
def _predict_batch(img, model, starts, patch_size, trim=5):
80+
# Subroutine
81+
def read_patch(i):
82+
start = starts[i]
83+
end = [min(s + patch_size, d) for s, d in zip(start, img.shape[2:])]
84+
patch = img[0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]]
6685
mn, mx = np.percentile(patch, 5), np.percentile(patch, 99.9)
67-
patch = (patch - mn) / mx
68-
mn_mx.append((mn, mx))
69-
70-
# Store patch
71-
patch = add_padding(patch, patch_size)
72-
batch_inputs.append(patch)
73-
batch_coords.append((i, j, k))
74-
75-
# If batch is full or it's the last patch
76-
if len(batch_inputs) == batch_size or idx == len(coords) - 1:
77-
# Run model
78-
input_tensor = batch_to_tensor(np.stack(batch_inputs))
79-
with torch.no_grad():
80-
output_tensor = model(input_tensor)
81-
82-
# Store result
83-
output_tensor = output_tensor.cpu()
84-
for cnt in range(output_tensor.shape[0]):
85-
mn, mx = mn_mx[cnt]
86-
patch = np.array(output_tensor[cnt, 0, ...])
87-
preds.append(np.maximum(patch * mx + mn, 0))
88-
pbar.update(1) if verbose else None
89-
90-
batch_coords.clear()
91-
batch_inputs.clear()
92-
mn_mx.clear()
93-
return stitch(img, coords, preds)
86+
patch = add_padding((patch - mn) / mx, patch_size)
87+
return i, patch, (mn, mx)
88+
89+
# Main
90+
with ThreadPoolExecutor() as executor:
91+
# Read patches
92+
threads = list()
93+
for i in range(len(starts)):
94+
threads.append(executor.submit(read_patch, i))
95+
96+
# Compile batch
97+
inputs = np.zeros((len(starts),) + (patch_size,) * 3)
98+
mn_mx = len(starts) * [None]
99+
for thread in as_completed(threads):
100+
i, patch_i, mn_mx_i = thread.result()
101+
mn_mx[i] = mn_mx_i
102+
inputs[i, ...] = patch_i
103+
104+
# Run model
105+
inputs = batch_to_tensor(inputs)
106+
with torch.no_grad():
107+
outputs = model(inputs)
108+
outputs = np.array(outputs.cpu()).squeeze(1)
109+
110+
# Store result
111+
preds = list()
112+
start, end = trim, patch_size - trim
113+
for i in range(outputs.shape[0]):
114+
mn, mx = mn_mx[i]
115+
pred = np.maximum(outputs[i] * mx + mn, 0).astype(np.uint16)
116+
preds.append(pred[start:end, start:end, start:end])
117+
return preds
118+
119+
120+
def predict_largescale(img, model):
121+
pass
94122

95123

96124
def predict_patch(patch, model):
@@ -117,67 +145,7 @@ def predict_patch(patch, model):
117145

118146
# Process output
119147
pred = np.array(output_tensor.cpu())
120-
return np.maximum(pred[0, 0, ...] * mx + mn, 0).astype(int)
121-
122-
123-
def stitch(img, coords, preds, patch_size=64, trim=5):
124-
"""
125-
Stitches overlapping 3D patches back into a full denoised image by
126-
averaging overlapping regions, with optional trimming of patch borders.
127-
128-
Parameters
129-
----------
130-
img : numpy.ndarray
131-
Original image array of shape (batch, channels, depth, height, width).
132-
coords : List[Tuple[int]]
133-
List of starting (i, j, k) coordinates for each patch.
134-
preds : List[numpy.ndarray]
135-
Predicted patches with shape (patch_size, patch_size, patch_size).
136-
patch_size : int, optional
137-
Size of each cubic patch. Default is 64.
138-
trim : int, optional
139-
Number of voxels to trim from each side of a patch before stitching.
140-
Default is 5.
141-
142-
Returns
143-
-------
144-
numpy.ndarray
145-
Reconstructed image with patches stitched and overlapping areas
146-
averaged.
147-
"""
148-
denoised_accum = np.zeros_like(img, dtype=np.float32)
149-
weight_map = np.zeros_like(img, dtype=np.float32)
150-
for (i, j, k), pred in zip(coords, preds):
151-
# Trim prediction
152-
start, end = trim, patch_size - trim
153-
pred = pred[start:end, start:end, start:end]
154-
155-
# Adjust insertion indices
156-
i_start = i + trim
157-
j_start = j + trim
158-
k_start = k + trim
159-
160-
i_end = i_start + pred.shape[0]
161-
j_end = j_start + pred.shape[1]
162-
k_end = k_start + pred.shape[2]
163-
164-
# Clip to image bounds (for safety)
165-
i_end = min(i_end, img.shape[2])
166-
j_end = min(j_end, img.shape[3])
167-
k_end = min(k_end, img.shape[4])
168-
169-
i_start = max(i_start, 0)
170-
j_start = max(j_start, 0)
171-
k_start = max(k_start, 0)
172-
173-
denoised_accum[
174-
0, 0, i_start:i_end, j_start:j_end, k_start:k_end
175-
] += pred[: i_end - i_start, : j_end - j_start, : k_end - k_start]
176-
weight_map[0, 0, i_start:i_end, j_start:j_end, k_start:k_end] += 1
177-
178-
# Average accumulated
179-
weight_map[weight_map == 0] = 1
180-
return denoised_accum / weight_map
148+
return np.maximum(pred[0, 0, ...] * mx + mn, 0).astype(np.uint16)
181149

182150

183151
# --- Helpers ---
@@ -205,7 +173,7 @@ def add_padding(patch, patch_size):
205173
return np.pad(patch, pad_width, mode="constant", constant_values=0)
206174

207175

208-
def generate_coords(img, patch_size, overlap):
176+
def generate_patch_starts(img, patch_size, overlap):
209177
"""
210178
Generates starting coordinates for 3D patches extracted from an image
211179
tensor, based on specified patch size and overlap.

0 commit comments

Comments
 (0)