Skip to content

Commit 2688952

Browse files
Implement grid-search for IHC segmentation
1 parent 7d600eb commit 2688952

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import os
2+
from glob import glob
3+
4+
import imageio.v3 as imageio
5+
import numpy as np
6+
from skimage.measure import label
7+
from skimage.segmentation import relabel_sequential
8+
9+
from torch_em.util import load_model
10+
from torch_em.util.grid_search import DistanceBasedInstanceSegmentation, instance_segmentation_grid_search
11+
12+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/IHC/2025-07-for-grid-search"
13+
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v4_cochlea_distance_unet_IHC_supervised_2025-07-14" # noqa
14+
15+
16+
def preprocess_gt():
17+
label_paths = sorted(glob(os.path.join(ROOT, "**/*_corrected.tif"), recursive=True))
18+
for label_path in label_paths:
19+
seg = imageio.imread(label_path)
20+
seg = label(seg)
21+
22+
min_size = 750
23+
ids, sizes = np.unique(seg, return_counts=True)
24+
filter_ids = ids[sizes < min_size]
25+
seg[np.isin(seg, filter_ids)] = 0
26+
seg, _, _ = relabel_sequential(seg)
27+
28+
out_path = label_path.replace("_corrected", "_gt")
29+
imageio.imwrite(out_path, seg)
30+
31+
32+
def run_grid_search():
33+
image_paths = sorted(glob(os.path.join(ROOT, "**/*lut3*.tif"), recursive=True))
34+
label_paths = sorted(glob(os.path.join(ROOT, "**/*_gt.tif"), recursive=True))
35+
assert len(image_paths) == len(label_paths), f"{len(image_paths)}, {len(label_paths)}"
36+
result_dir = "ihc-v4-gs"
37+
38+
block_shape = (96, 256, 256)
39+
halo = (8, 64, 64)
40+
model = load_model(MODEL_PATH)
41+
segmenter = DistanceBasedInstanceSegmentation(model, block_shape=block_shape, halo=halo)
42+
43+
grid_search_values = {
44+
"center_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
45+
"boundary_distance_threshold": [0.3, 0.4, 0.5, 0.6, 0.7],
46+
"distance_smoothing": [0.0, 0.6, 1.0, 1.6],
47+
"min_size": [0],
48+
}
49+
best_kwargs, best_score = instance_segmentation_grid_search(
50+
segmenter, image_paths, label_paths, result_dir, grid_search_values=grid_search_values
51+
)
52+
print("Grid-search result:")
53+
print(best_kwargs)
54+
print(best_score)
55+
56+
57+
# TODO plot the grid search results
58+
def evaluate_grid_search():
59+
pass
60+
61+
62+
def main():
63+
# preprocess_gt()
64+
run_grid_search()
65+
evaluate_grid_search()
66+
67+
68+
if __name__ == "__main__":
69+
main()

scripts/validation/IHCs/run_evaluation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def main():
5959
parser.add_argument("-i", "--input", default=ROOT)
6060
parser.add_argument("--folders", default=ANNOTATION_FOLDERS)
6161
parser.add_argument("--result_file", default="results.csv")
62-
parser.add_argument("--segmentation_name", default="IHC_v2")
62+
parser.add_argument("--segmentation_name", default="IHC_v4")
6363
parser.add_argument("--cache_folder")
6464
args = parser.parse_args()
6565
run_evaluation(args.input, args.folders, args.result_file, args.cache_folder, args.segmentation_name)

0 commit comments

Comments
 (0)