Skip to content

Commit f109887

Browse files
committed
Initial evaluation for NIS3D nucleus segmentation
1 parent 1ad3d8a commit f109887

File tree

6 files changed

+226
-4
lines changed

6 files changed

+226
-4
lines changed

scripts/baselines/NIS3D_apply.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import os
2+
import sys
3+
4+
script_dir = "/user/schilling40/u15000/flamingo-tools/scripts/prediction"
5+
sys.path.append(script_dir)
6+
7+
import run_prediction_distance_unet
8+
9+
checkpoint_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/nucleus"
10+
model_name = "NIS3D_supervised_2025-07-17"
11+
model_dir = os.path.join(checkpoint_dir, model_name)
12+
checkpoint = os.path.join(checkpoint_dir, model_name, "best.pt")
13+
14+
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
15+
16+
image_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D/test"
17+
18+
out_dir = os.path.join(cochlea_dir, "predictions", "val_nucleus", "distance_unet_NIS3D") # /distance_unet
19+
20+
boundary_distance_threshold = 0.5
21+
seg_class = "ihc"
22+
23+
block_shape = (128, 128, 128)
24+
halo = (16, 32, 32)
25+
26+
block_shape_str = ",".join([str(b) for b in block_shape])
27+
halo_str = ",".join([str(h) for h in halo])
28+
29+
images = [entry.path for entry in os.scandir(image_dir) if entry.is_file() and "iitest.tif" in entry.path]
30+
31+
for image in images:
32+
sys.argv = [
33+
os.path.join(script_dir, "run_prediction_distance_unet.py"),
34+
f"--input={image}",
35+
f"--output_folder={out_dir}",
36+
f"--model={model_dir}",
37+
f"--block_shape=[{block_shape_str}]",
38+
f"--halo=[{halo_str}]",
39+
"--memory",
40+
"--time",
41+
"--no_masking",
42+
f"--seg_class={seg_class}",
43+
f"--boundary_distance_threshold={boundary_distance_threshold}"
44+
]
45+
46+
run_prediction_distance_unet.main()

scripts/baselines/NIS3D_eval.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import argparse
2+
import json
3+
import multiprocessing as mp
4+
import os
5+
from concurrent import futures
6+
from typing import List
7+
8+
import numpy as np
9+
import tifffile
10+
from tqdm import tqdm
11+
12+
GT_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D/test"
13+
PRED_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/val_nucleus/distance_unet_NIS3D"
14+
15+
16+
def find_overlapping_masks(
17+
arr_base: np.ndarray,
18+
arr_ref: np.ndarray,
19+
label_id_base: int,
20+
min_overlap: float = 0.5,
21+
) -> List[int]:
22+
"""Find masks of segmentation, which have an overlap with undefined mask greater than 0.5.
23+
"""
24+
labels_undefined_mask = []
25+
arr_base_undefined = arr_base == label_id_base
26+
27+
# iterate through segmentation ids in reference mask
28+
ref_ids = list(np.unique(arr_ref)[1:])
29+
for ref_id in ref_ids:
30+
arr_ref_instance = arr_ref == ref_id
31+
32+
intersection = np.logical_and(arr_ref_instance, arr_base_undefined)
33+
overlap_ratio = np.sum(intersection) / np.sum(arr_ref_instance)
34+
if overlap_ratio >= min_overlap:
35+
labels_undefined_mask.append(ref_id)
36+
37+
return labels_undefined_mask
38+
39+
40+
def find_matching_masks(arr_gt, arr_ref, out_path, labels_undefined_mask=[]):
41+
"""For each instance in the reference array, the corresponding mask of the ground truth array,
42+
which has the biggest overlap, is identified.
43+
44+
Args:
45+
arr_gt:
46+
arr_ref:
47+
out_path: Output path for saving dictionary.
48+
labels_undefined_mask: Labels of the reference array to exclude.
49+
"""
50+
seg_ids_ref = [int(i) for i in np.unique(arr_ref)[1:]]
51+
print(f"total number of segmentation masks: {len(seg_ids_ref)}")
52+
seg_ids_ref = [s for s in seg_ids_ref if s not in labels_undefined_mask]
53+
print(f"number of segmentation masks after filtering undefined masks: {len(seg_ids_ref)}")
54+
55+
def compute_overlap(ref_id):
56+
"""Identify ID of segmentation mask with biggest overlap.
57+
Return matched IDs and overlap.
58+
"""
59+
arr_ref_instance = arr_ref == ref_id
60+
61+
seg_ids_gt = np.unique(arr_gt[arr_ref_instance])[1:]
62+
63+
max_overlap = 0
64+
gt_id_match = None
65+
66+
for gt_id in seg_ids_gt:
67+
arr_gt_instance = arr_gt == gt_id
68+
69+
intersection = np.logical_and(arr_ref_instance, arr_gt_instance)
70+
overlap_ratio = np.sum(intersection) / np.sum(arr_ref_instance)
71+
if overlap_ratio > max_overlap:
72+
gt_id_match = int(gt_id.tolist())
73+
max_overlap = np.max([max_overlap, overlap_ratio])
74+
75+
if gt_id_match is not None:
76+
return {
77+
"ref_id": ref_id,
78+
"gt_id": gt_id_match,
79+
"overlap": float(max_overlap.tolist())
80+
}
81+
else:
82+
return None
83+
84+
n_threads = min(16, mp.cpu_count())
85+
print(f"Parallelizing with {n_threads} Threads.")
86+
with futures.ThreadPoolExecutor(n_threads) as pool:
87+
results = list(tqdm(pool.map(compute_overlap, seg_ids_ref), total=len(seg_ids_ref)))
88+
89+
matching_masks = {r['ref_id']: r for r in results if r is not None}
90+
91+
with open(out_path, "w") as f:
92+
json.dump(matching_masks, f, indent='\t', separators=(',', ': '))
93+
94+
95+
def filter_true_positives(output_folder, prefixes, force_overwrite):
96+
""" Filter true positives from segmentation.
97+
Segmentation instances and ground truth labels are filtered symmetrically.
98+
The maximal overlap of each is computed and taken as a true positive if symmetric.
99+
The instance ID, the reference ID, and the overlap are saved in dictionaries.
100+
101+
Args:
102+
output_folder: Output folder for dictionaries.
103+
prefixes: List of prefixes for evaluation. One or multiple of ["Drosophila", "MusMusculus", "Zebrafish"].
104+
force_overwrite: Flag for forced overwrite of existing output files.
105+
"""
106+
if "PRED_DIR" in globals():
107+
pred_dir = PRED_DIR
108+
if "GT_DIR" in globals():
109+
gt_dir = GT_DIR
110+
111+
if prefixes is None:
112+
prefixes = ["Drosophila", "MusMusculus", "Zebrafish"]
113+
114+
for prefix in prefixes:
115+
conf_file = os.path.join(gt_dir, f"{prefix}_1_iitest_confidence.tif")
116+
annot_file = os.path.join(gt_dir, f"{prefix}_1_iitest_annotations.tif")
117+
conf_arr = tifffile.imread(conf_file)
118+
gt_arr = tifffile.imread(annot_file)
119+
120+
seg_file = os.path.join(pred_dir, f"{prefix}_1_iitest_seg.tif")
121+
seg_arr = tifffile.imread(seg_file)
122+
123+
# find largest overlap of ground truth mask with each segmentation instance
124+
out_path = os.path.join(output_folder, f"{prefix}_matching_ref_gt.json")
125+
if os.path.isfile(out_path) and not force_overwrite:
126+
print(f"Skipping the creation of {out_path}. File already exists.")
127+
else:
128+
# exclude detections with more than 50% of pixels in undefined category
129+
if 1 in np.unique(conf_arr)[1:]:
130+
labels_undefined_mask = find_overlapping_masks(conf_arr, seg_arr, label_id_base=1)
131+
else:
132+
labels_undefined_mask = []
133+
print("Array does not contain undefined mask")
134+
135+
find_matching_masks(gt_arr, seg_arr, out_path, labels_undefined_mask=labels_undefined_mask)
136+
137+
# find largest overlap of segmentation instance with each ground truth mask
138+
out_path = os.path.join(output_folder, f"{prefix}_matching_gt_ref.json")
139+
if os.path.isfile(out_path) and not force_overwrite:
140+
print(f"Skipping the creation of {out_path}. File already exists.")
141+
else:
142+
find_matching_masks(seg_arr, gt_arr, out_path)
143+
144+
145+
def main():
146+
parser = argparse.ArgumentParser()
147+
parser.add_argument("--output_folder", "-o", required=True)
148+
parser.add_argument("--prefix", "-p", nargs="+", type=str, default=None)
149+
parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")
150+
args = parser.parse_args()
151+
152+
filter_true_positives(
153+
args.output_folder,
154+
args.prefix,
155+
args.force,
156+
)
157+
158+
159+
if __name__ == "__main__":
160+
main()
File renamed without changes.

scripts/baselines/NIS3D_train.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
3+
export MODEL_NAME="nucleus_NIS3D_supervised_2025-07-17"
4+
5+
export IDIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D
6+
7+
export SCRIPT_DIR=/user/schilling40/u15000/flamingo-tools/scripts/training
8+
9+
python $SCRIPT_DIR/train_distance_unet.py -i $IDIR --name $MODEL_NAME
10+

scripts/baselines/eval_baseline.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ def print_accuracy(eval_dir):
154154
recall = 0
155155
if precision + recall != 0:
156156
f1_score = 2 * precision * recall / (precision + recall)
157-
else: f1_score = 0
157+
else:
158+
f1_score = 0
158159

159160
precision_list.append(precision)
160161
recall_list.append(recall)
@@ -198,9 +199,9 @@ def print_accuracy_ihc():
198199

199200

200201
def main():
201-
#eval_all_sgn()
202-
#eval_all_ihc()
203-
#print_accuracy_sgn()
202+
eval_all_sgn()
203+
eval_all_ihc()
204+
print_accuracy_sgn()
204205
print_accuracy_ihc()
205206

206207

scripts/prediction/run_prediction_distance_unet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def main():
2727
parser.add_argument("--halo", default=None, type=str)
2828
parser.add_argument("--memory", action="store_true", help="Perform prediction in memory and save output as tif.")
2929
parser.add_argument("--time", action="store_true", help="Time prediction process.")
30+
parser.add_argument("--no_masking", action="store_true", help="Do not mask input.")
3031
parser.add_argument("--seg_class", default=None, type=str,
3132
help="Segmentation class to load parameters for masking input.")
3233
parser.add_argument("--center_distance_threshold", default=0.4, type=float,
@@ -67,6 +68,8 @@ def main():
6768
else:
6869
halo = tuple(json.loads(args.halo))
6970

71+
use_mask = ~args.no_masking
72+
7073
if args.time:
7174
start = time.perf_counter()
7275

@@ -75,6 +78,7 @@ def main():
7578
args.input, args.input_key, output_folder=None, model_path=args.model,
7679
scale=scale, min_size=min_size,
7780
block_shape=block_shape, halo=halo,
81+
use_mask=use_mask,
7882
seg_class=args.seg_class,
7983
center_distance_threshold=args.center_distance_threshold,
8084
boundary_distance_threshold=args.boundary_distance_threshold,
@@ -92,6 +96,7 @@ def main():
9296
args.input, args.input_key, output_folder=args.output_folder, model_path=args.model,
9397
scale=scale, min_size=min_size,
9498
block_shape=block_shape, halo=halo,
99+
use_mask=use_mask,
95100
seg_class=args.seg_class,
96101
center_distance_threshold=args.center_distance_threshold,
97102
boundary_distance_threshold=args.boundary_distance_threshold,

0 commit comments

Comments
 (0)