Skip to content

Commit e1f2fe7

Browse files
authored
refactor: simplified inference
1 parent 31b35df commit e1f2fe7

File tree

2 files changed

+67
-35
lines changed

2 files changed

+67
-35
lines changed

src/aind_exaspim_image_compression/evaluate.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
@author: Anna Grim
55
@email: anna.grim@alleninstitute.org
66
7-
Code for evaluating a denoising model on an image dataset.
7+
Code for evaluating a denoising model.
88
99
"""
1010

@@ -20,7 +20,7 @@
2020
from aind_exaspim_image_compression.utils import img_util, util
2121

2222

23-
class Evaluator:
23+
class SupervisedEvaluator:
2424
def __init__(self, img_paths, model, output_dir):
2525
# Instance attributes
2626
self.codec = blosc.Blosc(cname="zstd", clevel=6, shuffle=blosc.SHUFFLE)
@@ -52,7 +52,7 @@ def load_images(self):
5252
self.noise_imgs[block_id] = img
5353

5454
output_path = os.path.join(noise_dir, block_id)
55-
img_util.plot_mips(img[0, 0, ...], output_path=output_path)
55+
img_util.plot_mips(img, output_path=output_path)
5656

5757
# --- Main ---
5858
def run(self, model_path):
@@ -63,14 +63,13 @@ def run(self, model_path):
6363
util.mkdir(results_dir)
6464

6565
# Generate prediction
66-
rows = list(self.noise_imgs.keys())
66+
rows = sorted(list(self.noise_imgs.keys()))
6767
df = pd.DataFrame(index=rows, columns=["cratio", "ssim"])
6868
desc = "Denoise Blocks"
6969
for block_id, noise in tqdm(self.noise_imgs.items(), desc=desc):
7070
# Run model
7171
coords, preds = predict(noise, self.model, verbose=False)
7272
denoised = stitch(noise, coords, preds)
73-
denoised = denoised.astype(np.uint16)
7473

7574
# Compute metrics
7675
df.loc[block_id, "cratio"] = img_util.compute_cratio(
@@ -79,17 +78,17 @@ def run(self, model_path):
7978
df.loc[block_id, "ssim"] = img_util.compute_ssim3D(
8079
noise[0, 0, ...],
8180
denoised[0, 0, ...],
82-
data_range=np.percentile(noise, 99.9),
81+
data_range=np.max(noise),
8382
)
8483

8584
# Save MIPs
8685
output_path = os.path.join(results_dir, block_id)
87-
img_util.plot_mips(denoised[0, 0, ...], output_path=output_path)
86+
img_util.plot_mips(denoised, output_path=output_path)
8887

8988
# Save metrics
9089
path = os.path.join(results_dir, "results.csv")
9190
df.to_csv(path, index=True)
92-
return model_name, df
91+
return df
9392

9493
# --- Helpers ---
9594
def find_img_name(self, img_path):

src/aind_exaspim_image_compression/inference.py

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ def predict(
4141
4242
Returns
4343
-------
44-
coords : list of tuples
44+
coords : List[Tuple[int]]
4545
List of (i, j, k) starting coordinates of patches processed.
46-
preds : list of numpy.ndarray
46+
preds : List[numpy.ndarray]
4747
List of predicted patches (3D arrays) matching the patch size.
48-
4948
"""
5049
# Initializations
5150
batch_coords, batch_inputs, mn_mx = list(), list(), list()
@@ -74,7 +73,7 @@ def predict(
7473
# If batch is full or it's the last patch
7574
if len(batch_inputs) == batch_size or idx == len(coords) - 1:
7675
# Run model
77-
input_tensor = to_tensor(np.stack(batch_inputs))
76+
input_tensor = batch_to_tensor(np.stack(batch_inputs))
7877
with torch.no_grad():
7978
output_tensor = model(input_tensor)
8079

@@ -92,6 +91,33 @@ def predict(
9291
return coords, preds
9392

9493

94+
def predict_patch(patch, model):
95+
"""
96+
Denoise a single 3D patch using the provided model.
97+
98+
Parameters
99+
----------
100+
patch : numpy.ndarray
101+
3D input patch to denoise.
102+
model : torch.nn.Module
103+
PyTorch model used for prediction.
104+
105+
Returns
106+
-------
107+
numpy.ndarray
108+
Denoised 3D patch with the same shape as input patch.
109+
"""
110+
# Run model
111+
mn, mx = np.percentile(patch, 5), np.percentile(patch, 99.9)
112+
patch = to_tensor((patch - mn) / max(mx, 1))
113+
with torch.no_grad():
114+
output_tensor = model(patch)
115+
116+
# Process output
117+
pred = np.array(output_tensor.cpu())
118+
return np.maximum(pred[0, 0, ...] * mx + mn, 0).astype(int)
119+
120+
95121
def stitch(img, coords, preds, patch_size=64, trim=5):
96122
"""
97123
Stitches overlapping 3D patches back into a full denoised image by
@@ -116,28 +142,22 @@ def stitch(img, coords, preds, patch_size=64, trim=5):
116142
numpy.ndarray
117143
Reconstructed image with patches stitched and overlapping areas
118144
averaged.
119-
120145
"""
121146
denoised_accum = np.zeros_like(img, dtype=np.float32)
122147
weight_map = np.zeros_like(img, dtype=np.float32)
123148
for (i, j, k), pred in zip(coords, preds):
124-
# Determine how much to trim
125-
trim_start = trim
126-
trim_end = patch_size - trim
127-
128149
# Trim prediction
129-
pred_trimmed = pred[
130-
trim_start:trim_end, trim_start:trim_end, trim_start:trim_end
131-
]
150+
start, end = trim, patch_size - trim
151+
pred = pred[start:end, start:end, start:end]
132152

133153
# Adjust insertion indices
134154
i_start = i + trim
135155
j_start = j + trim
136156
k_start = k + trim
137157

138-
i_end = i_start + pred_trimmed.shape[0]
139-
j_end = j_start + pred_trimmed.shape[1]
140-
k_end = k_start + pred_trimmed.shape[2]
158+
i_end = i_start + pred.shape[0]
159+
j_end = j_start + pred.shape[1]
160+
k_end = k_start + pred.shape[2]
141161

142162
# Clip to image bounds (for safety)
143163
i_end = min(i_end, img.shape[2])
@@ -150,9 +170,7 @@ def stitch(img, coords, preds, patch_size=64, trim=5):
150170

151171
denoised_accum[
152172
0, 0, i_start:i_end, j_start:j_end, k_start:k_end
153-
] += pred_trimmed[
154-
: i_end - i_start, : j_end - j_start, : k_end - k_start
155-
]
173+
] += pred[: i_end - i_start, : j_end - j_start, : k_end - k_start]
156174
weight_map[0, 0, i_start:i_end, j_start:j_end, k_start:k_end] += 1
157175

158176
# Average accumulated
@@ -176,7 +194,6 @@ def add_padding(patch, patch_size):
176194
-------
177195
numpy.ndarray
178196
Zero-padded patch with shape (patch_size, patch_size, patch_size).
179-
180197
"""
181198
pad_width = [
182199
(0, patch_size - patch.shape[0]),
@@ -205,7 +222,6 @@ def generate_coords(img, patch_size, overlap):
205222
coords : List[Tuple[int]]
206223
List of (depth_start, height_start, width_start) coordinates for image
207224
patches.
208-
209225
"""
210226
coords = list()
211227
stride = patch_size - overlap
@@ -218,20 +234,37 @@ def generate_coords(img, patch_size, overlap):
218234

219235
def to_tensor(arr):
220236
"""
221-
Converts a NumPy array to a PyTorch tensor with an added channel dimension,
222-
and moves it to the GPU.
237+
Converts a NumPy array to a PyTorch tensor and moves it to the
238+
GPU.
223239
224240
Parameters
225241
----------
226242
arr : numpy.ndarray
227-
Input array to be converted.
243+
Array to be converted.
228244
229245
Returns
230246
-------
231247
torch.Tensor
232-
Input array as a float tensor on the CUDA device, with shape
233-
(batch_size, 1, ...).
248+
Tensor on GPU, with shape (1, 1, depth, height, width).
249+
"""
250+
while(len(arr.shape)) < 5:
251+
arr = arr[np.newaxis, ...]
252+
return torch.tensor(arr).to("cuda", dtype=torch.float)
253+
234254

255+
def batch_to_tensor(arr):
256+
"""
257+
Converts a NumPy array containing a batch of inputs to a PyTorch tensor
258+
and moves it to the GPU.
259+
260+
Parameters
261+
----------
262+
arr : numpy.ndarray
263+
Array to be converted, with shape (batch_size, depth, height, width).
264+
265+
Returns
266+
-------
267+
torch.Tensor
268+
Tensor on GPU, with shape (batch_size, 1, depth, height, width).
235269
"""
236-
dtype = torch.float
237-
return torch.tensor(arr[:, np.newaxis, ...]).to("cuda", dtype=dtype)
270+
return to_tensor(arr[:, np.newaxis, ...])

0 commit comments

Comments
 (0)