Skip to content

Commit 7c8a126

Browse files
authored
feat: evaluation routines
1 parent c0ef46b commit 7c8a126

File tree

2 files changed

+133
-23
lines changed

2 files changed

+133
-23
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
Created on Tue June 3 12:00:00 2025
3+
4+
@author: Anna Grim
5+
@email: anna.grim@alleninstitute.org
6+
7+
Code for evaluating a denoising model on an image dataset.
8+
9+
"""
10+
11+
from numcodecs import blosc
12+
from tqdm import tqdm
13+
14+
import numpy as np
15+
import os
16+
import pandas as pd
17+
import torch
18+
19+
from aind_exaspim_image_compression.inference import predict, stitch
20+
from aind_exaspim_image_compression.utils import img_util, util
21+
22+
23+
class Evaluator:
    """
    Evaluates a denoising model on a set of image blocks: loads each block,
    runs inference, computes compression-ratio and SSIM metrics, and writes
    MIP plots plus a CSV of results to an output directory.
    """

    def __init__(self, img_paths, model, output_dir):
        """
        Instantiates an Evaluator and eagerly loads all image blocks.

        Parameters
        ----------
        img_paths : List[str]
            Paths to image blocks. Each path must contain a component with
            the substring "block_", which is used as the block ID.
        model : torch.nn.Module
            Denoising model to evaluate; it is set to eval mode and moved
            to the GPU here.
        output_dir : str
            Directory where MIPs and metrics are written. Any previous
            contents are deleted.
        """
        # Instance attributes
        self.codec = blosc.Blosc(cname="zstd", clevel=6, shuffle=blosc.SHUFFLE)
        self.img_paths = img_paths
        self.model = model
        self.model.eval().to("cuda")

        # Initialize output directory (delete=True wipes any previous run)
        self.output_dir = output_dir
        util.mkdir(output_dir, delete=True)

        # Load images
        self.load_images()

    def load_images(self):
        """
        Reads every image block, records its compression ratio under the
        given codec, and saves a MIP plot of the raw (noisy) data.
        """
        # Initialize MIPs directory
        noise_dir = os.path.join(self.output_dir, "noise_mips")
        util.mkdir(noise_dir)

        # Read images
        self.noise_imgs = dict()
        self.noise_cratios = dict()
        for img_path in self.img_paths:
            block_id = self.find_img_name(img_path)
            img = img_util.read(img_path)
            self.noise_cratios[block_id] = img_util.compute_cratio(
                img, self.codec
            )
            self.noise_imgs[block_id] = img

            # Save MIPs of the noisy input; img[0, 0, ...] assumes a 5D
            # layout (e.g. TCZYX) — TODO confirm against img_util.read.
            # NOTE: a stray "break" previously ended this loop after the
            # first image, so run() only ever evaluated one block; removed.
            output_path = os.path.join(noise_dir, block_id)
            img_util.plot_mips(img[0, 0, ...], output_path=output_path)

    # --- Main ---
    def run(self, model_path):
        """
        Evaluates the checkpoint at "model_path" on all loaded blocks.

        Parameters
        ----------
        model_path : str
            Path to a saved state dict compatible with "self.model".

        Returns
        -------
        Tuple[str, pandas.DataFrame]
            Checkpoint filename and per-block metrics with columns
            "cratio" and "ssim".
        """
        # Initializations
        self.model.load_state_dict(torch.load(model_path))
        model_name = os.path.basename(model_path)
        results_dir = os.path.join(self.output_dir, model_name)
        util.mkdir(results_dir)

        # Generate prediction
        rows = list(self.noise_imgs.keys())
        df = pd.DataFrame(index=rows, columns=["cratio", "ssim"])
        desc = "Denoise Blocks"
        for block_id, noise in tqdm(self.noise_imgs.items(), desc=desc):
            # Run model
            coords, preds = predict(noise, self.model, verbose=False)
            denoised = stitch(noise, coords, preds)
            denoised = denoised.astype(np.uint16)

            # Compute metrics
            df.loc[block_id, "cratio"] = img_util.compute_cratio(
                denoised, self.codec
            )
            df.loc[block_id, "ssim"] = img_util.compute_ssim3D(
                noise[0, 0, ...],
                denoised[0, 0, ...],
                data_range=np.percentile(noise, 99.9),
            )

            # Save MIPs
            output_path = os.path.join(results_dir, block_id)
            img_util.plot_mips(denoised[0, 0, ...], output_path=output_path)

        # Save metrics
        path = os.path.join(results_dir, "results.csv")
        df.to_csv(path, index=True)
        return model_name, df

    # --- Helpers ---
    def find_img_name(self, img_path):
        """
        Extracts the block ID — the path component containing "block_" —
        from "img_path".

        Parameters
        ----------
        img_path : str
            Image path with "/"-separated components.

        Returns
        -------
        str
            The first path component containing "block_".

        Raises
        ------
        ValueError
            If no path component contains "block_".
        """
        for part in img_path.split("/"):
            if "block_" in part:
                return part
        raise ValueError(f"Block ID not found in {img_path}")

src/aind_exaspim_image_compression/inference.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,27 @@
1111
import numpy as np
1212
import torch
1313

14-
from aind_exaspim_image_compression.utils import img_util, util
14+
from aind_exaspim_image_compression.utils import img_util
1515

1616

17-
def predict(img, model, batch_size=32, patch_size=64, overlap=16):
17+
def predict(
18+
img, model, batch_size=32, patch_size=64, overlap=16, verbose=True
19+
):
1820
# Initializations
1921
batch_coords, batch_inputs, mn_mx = list(), list(), list()
2022
coords = generate_coords(img, patch_size, overlap)
2123

2224
# Main
23-
pbar = tqdm(total=len(coords), desc="Denoise")
25+
pbar = tqdm(total=len(coords), desc="Denoise") if verbose else None
2426
preds = list()
2527
for idx, (i, j, k) in enumerate(coords):
2628
# Get end coord
27-
i_end = min(i + patch_size, img.shape[0])
28-
j_end = min(j + patch_size, img.shape[1])
29-
k_end = min(k + patch_size, img.shape[2])
29+
i_end = min(i + patch_size, img.shape[2])
30+
j_end = min(j + patch_size, img.shape[3])
31+
k_end = min(k + patch_size, img.shape[4])
3032

3133
# Get patch
32-
patch = img[i:i_end, j:j_end, k:k_end]
34+
patch = img[0, 0, i:i_end, j:j_end, k:k_end]
3335
mn, mx = np.percentile(patch, 5), np.percentile(patch, 99.9)
3436
patch = (patch - mn) / mx
3537
mn_mx.append((mn, mx))
@@ -42,7 +44,7 @@ def predict(img, model, batch_size=32, patch_size=64, overlap=16):
4244
# If batch is full or it's the last patch
4345
if len(batch_inputs) == batch_size or idx == len(coords) - 1:
4446
# Run model
45-
input_tensor =to_tensor(np.stack(batch_inputs))
47+
input_tensor = to_tensor(np.stack(batch_inputs))
4648
with torch.no_grad():
4749
output_tensor = model(input_tensor)
4850

@@ -51,8 +53,8 @@ def predict(img, model, batch_size=32, patch_size=64, overlap=16):
5153
for cnt in range(output_tensor.shape[0]):
5254
mn, mx = mn_mx[cnt]
5355
patch = np.array(output_tensor[cnt, 0, ...])
54-
preds.append(patch * mx + mn)
55-
pbar.update(1)
56+
preds.append(np.maximum(patch * mx + mn, 0))
57+
pbar.update(1) if verbose else None
5658

5759
batch_coords.clear()
5860
batch_inputs.clear()
@@ -63,14 +65,15 @@ def predict(img, model, batch_size=32, patch_size=64, overlap=16):
6365
def stitch(img, coords, preds, patch_size=64, trim=5):
6466
denoised_accum = np.zeros_like(img, dtype=np.float32)
6567
weight_map = np.zeros_like(img, dtype=np.float32)
66-
6768
for (i, j, k), pred in zip(coords, preds):
6869
# Determine how much to trim
6970
trim_start = trim
7071
trim_end = patch_size - trim
7172

7273
# Trim prediction
73-
pred_trimmed = pred[trim_start:trim_end, trim_start:trim_end, trim_start:trim_end]
74+
pred_trimmed = pred[
75+
trim_start:trim_end, trim_start:trim_end, trim_start:trim_end
76+
]
7477

7578
# Adjust insertion indices
7679
i_start = i + trim
@@ -82,23 +85,29 @@ def stitch(img, coords, preds, patch_size=64, trim=5):
8285
k_end = k_start + pred_trimmed.shape[2]
8386

8487
# Clip to image bounds (for safety)
85-
i_end = min(i_end, img.shape[0])
86-
j_end = min(j_end, img.shape[1])
87-
k_end = min(k_end, img.shape[2])
88+
i_end = min(i_end, img.shape[2])
89+
j_end = min(j_end, img.shape[3])
90+
k_end = min(k_end, img.shape[4])
8891

8992
i_start = max(i_start, 0)
9093
j_start = max(j_start, 0)
9194
k_start = max(k_start, 0)
9295

93-
denoised_accum[i_start:i_end, j_start:j_end, k_start:k_end] += pred_trimmed[:i_end - i_start, :j_end - j_start, :k_end - k_start]
94-
weight_map[i_start:i_end, j_start:j_end, k_start:k_end] += 1
96+
denoised_accum[
97+
0, 0, i_start:i_end, j_start:j_end, k_start:k_end
98+
] += pred_trimmed[
99+
: i_end - i_start, : j_end - j_start, : k_end - k_start
100+
]
101+
weight_map[0, 0, i_start:i_end, j_start:j_end, k_start:k_end] += 1
95102

96103
# Average accumulated
97104
weight_map[weight_map == 0] = 1
98105
denoised = denoised_accum / weight_map
99106

100107
# Fill boundary trim
101-
fill_value = np.percentile(denoised[trim:-trim, trim:-trim, trim:-trim], 10)
108+
fill_value = np.percentile(
109+
denoised[..., trim:-trim, trim:-trim, trim:-trim], 10
110+
)
102111
return img_util.fill_boundary(denoised, trim, fill_value)
103112

104113

@@ -109,18 +118,19 @@ def add_padding(patch, patch_size):
109118
(0, patch_size - patch.shape[1]),
110119
(0, patch_size - patch.shape[2]),
111120
]
112-
return np.pad(patch, pad_width, mode='constant', constant_values=0)
121+
return np.pad(patch, pad_width, mode="constant", constant_values=0)
113122

114123

115124
def generate_coords(img, patch_size, overlap):
    """
    Generates the start coordinates of overlapping patches that tile the
    spatial axes (2, 3, 4) of a 5D image.

    Parameters
    ----------
    img : numpy.ndarray
        5D image whose last three axes are tiled.
    patch_size : int
        Edge length of each cubic patch.
    overlap : int
        Number of voxels adjacent patches share along each axis.

    Returns
    -------
    List[Tuple[int, int, int]]
        (i, j, k) start coordinates, ordered i-major.
    """
    step = patch_size - overlap
    i_stop, j_stop, k_stop = (
        img.shape[axis] - patch_size + step for axis in (2, 3, 4)
    )
    return [
        (i, j, k)
        for i in range(0, i_stop, step)
        for j in range(0, j_stop, step)
        for k in range(0, k_stop, step)
    ]
123132

124133

125134
def to_tensor(arr, device="cuda"):
    """
    Converts a batch of 3D patches into a 5D float tensor on "device".

    Parameters
    ----------
    arr : numpy.ndarray
        Array of shape (batch, D, H, W); a channel axis is inserted so the
        result has shape (batch, 1, D, H, W).
    device : str, optional
        Target device for the tensor. Default is "cuda"; was previously
        hard-coded, which blocked CPU-only use and testing.

    Returns
    -------
    torch.Tensor
        Float tensor of shape (batch, 1, D, H, W) on "device".
    """
    return torch.tensor(arr[:, np.newaxis, ...]).to(device, dtype=torch.float)

0 commit comments

Comments
 (0)