Skip to content

Commit a6b2de8

Browse files
Implement IHC grid search WIP
1 parent 0b7e4b2 commit a6b2de8

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
from glob import glob
3+
4+
import imageio.v3 as imageio
5+
import numpy as np
6+
from skimage.measure import label
7+
from skimage.segmentation import relabel_sequential
8+
9+
from torch_em.util import load_model
10+
from torch_em.util.grid_search import DistanceBasedInstanceSegmentation, instance_segmentation_grid_search
11+
12+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/" # noqa
13+
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/" # noqa
14+
15+
GRID_SEARCH_VALUES = {
16+
"center_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
17+
"boundary_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
18+
"distance_smoothing": [0.0, 0.6, 1.0, 1.6],
19+
"min_size": [0, 100, 250],
20+
}
21+
22+
23+
def preprocess_gt():
24+
label_paths = sorted(glob(os.path.join(ROOT, "**/*_corrected.tif"), recursive=True))
25+
for label_path in label_paths:
26+
seg = imageio.imread(label_path)
27+
seg = label(seg)
28+
29+
min_size = 750
30+
ids, sizes = np.unique(seg, return_counts=True)
31+
filter_ids = ids[sizes < min_size]
32+
seg[np.isin(seg, filter_ids)] = 0
33+
seg, _, _ = relabel_sequential(seg)
34+
35+
out_path = label_path.replace("_corrected", "_gt")
36+
imageio.imwrite(out_path, seg)
37+
38+
39+
def run_grid_search():
40+
image_paths = sorted(glob(os.path.join(ROOT, "**/*lut3*.tif"), recursive=True))
41+
label_paths = sorted(glob(os.path.join(ROOT, "**/*_gt.tif"), recursive=True))
42+
assert len(image_paths) == len(label_paths), f"{len(image_paths)}, {len(label_paths)}"
43+
result_dir = "ihc-v4-gs"
44+
45+
block_shape = (96, 256, 256)
46+
halo = (8, 64, 64)
47+
model = load_model(MODEL_PATH)
48+
segmenter = DistanceBasedInstanceSegmentation(model, block_shape=block_shape, halo=halo)
49+
50+
best_kwargs, best_score = instance_segmentation_grid_search(
51+
segmenter, image_paths, label_paths, result_dir, grid_search_values=GRID_SEARCH_VALUES
52+
)
53+
print("Grid-search result:")
54+
print(best_kwargs)
55+
print(best_score)
56+
57+
58+
# TODO plot the grid search results
59+
def evaluate_grid_search():
60+
import matplotlib.pyplot as plt
61+
import pandas as pd
62+
import seaborn as sns
63+
64+
result_dir = "ihc-v4-gs"
65+
criterion = "SA50"
66+
67+
gs_files = glob(os.path.join(result_dir, "*.csv"))
68+
gs_result = pd.concat([pd.read_csv(gs_file) for gs_file in gs_files])
69+
70+
grid_search_parameters = list(GRID_SEARCH_VALUES.keys())
71+
72+
# Retrieve only the relevant columns and group by the gridsearch columns
73+
# and compute the mean value
74+
gs_result = gs_result[grid_search_parameters + [criterion]].reset_index()
75+
gs_result = gs_result.groupby(grid_search_parameters).mean().reset_index()
76+
77+
# Find the best score and best result
78+
best_score, best_idx = gs_result[criterion].max(), gs_result[criterion].idxmax()
79+
best_params = gs_result.iloc[best_idx]
80+
best_kwargs = {k: v for k, v in zip(grid_search_parameters, best_params)}
81+
82+
print("Best parameters:")
83+
print(best_kwargs)
84+
print("With score:", best_score)
85+
86+
fig, axes = plt.subplots(3)
87+
for i, (idx, col) in enumerate([
88+
("center_distance_threshold", "boundary_distance_threshold"),
89+
("distance_smoothing", "boundary_distance_threshold"),
90+
("distance_smoothing", "center_distance_threshold"),
91+
]):
92+
res = gs_result.groupby([idx, col]).mean().reset_index()
93+
res = res[[idx, col, criterion]]
94+
res = res.pivot(index=idx, columns=col, values=criterion)
95+
sns.heatmap(res, cmap="viridis", annot=True, fmt=".2g", cbar_kws={"label": criterion}, ax=axes[i])
96+
axes[i].set_xlabel(idx)
97+
axes[i].set_xlabel(col)
98+
plt.show()
99+
100+
101+
def main():
102+
# preprocess_gt()
103+
# run_grid_search()
104+
evaluate_grid_search()
105+
106+
107+
if __name__ == "__main__":
108+
main()

0 commit comments

Comments
 (0)