Skip to content

Commit 9293963

Browse files
authored
Merge pull request #41 from computational-cell-analytics/add_baselines
The baseline methods were evaluated for the initial submission of the paper. The branch has been brought to a satisfactory state and can now be merged into the master branch.
2 parents 64321fd + 26af401 commit 9293963

17 files changed

+1222
-1
lines changed

flamingo_tools/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def compute_matches_for_annotated_slice(
217217
A dictionary with keys 'tp_objects', 'tp_annotations' 'fp' and 'fn', mapping to the respective ids.
218218
"""
219219
assert segmentation.ndim in (2, 3)
220-
coordinates = ["axis-0", "axis-1"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
220+
coordinates = ["axis-1", "axis-2"] if segmentation.ndim == 2 else ["axis-0", "axis-1", "axis-2"]
221221
segmentation_ids = np.unique(segmentation)[1:]
222222

223223
# Crop to the minimal enclosing bounding box of points and segmented objects.

scripts/baselines/NIS3D_apply.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import os
2+
import sys
3+
4+
# Location of the prediction scripts; appended to sys.path so that
# run_prediction_distance_unet can be imported as a module below.
script_dir = "/user/schilling40/u15000/flamingo-tools/scripts/prediction"
sys.path.append(script_dir)

import run_prediction_distance_unet  # noqa: E402  (needs the sys.path entry above)

# Model checkpoint location.
checkpoint_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/nucleus"
model_name = "NIS3D_supervised_2025-07-17"
model_dir = os.path.join(checkpoint_dir, model_name)
# Not used further down in this script.
checkpoint = os.path.join(model_dir, "best.pt")

cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"

# Input images and output folder for the predictions.
image_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D/test"
out_dir = os.path.join(cochlea_dir, "predictions", "val_nucleus", "distance_unet_NIS3D")  # /distance_unet

boundary_distance_threshold = 0.5
# NOTE(review): "ihc" is forwarded as --seg_class although the model folder is
# named "nucleus" - confirm that this is intended.
seg_class = "ihc"

block_shape = (128, 128, 128)
halo = (16, 32, 32)

# Comma-separated representations, embedded in "[...]" on the command line.
block_shape_str = ",".join(map(str, block_shape))
halo_str = ",".join(map(str, halo))

# Every regular file in image_dir whose path contains "iitest.tif".
images = [item.path for item in os.scandir(image_dir) if item.is_file() and "iitest.tif" in item.path]

# Arguments that are identical for every image.
script_path = os.path.join(script_dir, "run_prediction_distance_unet.py")
shared_args = [
    f"--output_folder={out_dir}",
    f"--model={model_dir}",
    f"--block_shape=[{block_shape_str}]",
    f"--halo=[{halo_str}]",
    "--memory",
    "--time",
    "--no_masking",
    f"--seg_class={seg_class}",
    f"--boundary_distance_threshold={boundary_distance_threshold}",
]

for image in images:
    # The prediction script reads its options from sys.argv, so the command
    # line is emulated before calling its main().
    sys.argv = [script_path, f"--input={image}", *shared_args]
    run_prediction_distance_unet.main()

scripts/baselines/NIS3D_eval.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import argparse
2+
import json
3+
import multiprocessing as mp
4+
import os
5+
from concurrent import futures
6+
from typing import List
7+
8+
import numpy as np
9+
import tifffile
10+
from tqdm import tqdm
11+
12+
# Folder from which the confidence and annotation volumes are read.
GT_DIR = (
    "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
    "/training_data/nucleus/2025-07_NIS3D/test"
)
# Folder from which the predicted segmentations are read.
PRED_DIR = (
    "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
    "/predictions/val_nucleus/distance_unet_NIS3D"
)
14+
15+
16+
def find_overlapping_masks(
    arr_base: np.ndarray,
    arr_ref: np.ndarray,
    label_id_base: int,
    min_overlap: float = 0.5,
) -> List[int]:
    """Find instances of `arr_ref` that lie mostly inside one label of `arr_base`.

    For every non-zero label in `arr_ref`, the fraction of its voxels that
    coincide with `arr_base == label_id_base` is computed. Labels whose
    fraction is at least `min_overlap` are returned.

    Args:
        arr_base: Label array that contains the region of interest
            (e.g. a confidence map with an "undefined" label).
        arr_ref: Instance label array with the same shape as `arr_base`.
            The value 0 is treated as background.
        label_id_base: Label value in `arr_base` that defines the region.
        min_overlap: Minimal fraction of an instance that has to lie inside
            the region for the instance to be reported.

    Returns:
        The matching labels of `arr_ref` as Python ints, in ascending order.
    """
    arr_base_undefined = arr_base == label_id_base

    # Count voxels per instance once for the full volume and once restricted
    # to the region, instead of scanning the whole volume for every label.
    ref_ids, ref_sizes = np.unique(arr_ref, return_counts=True)
    inside_ids, inside_sizes = np.unique(arr_ref[arr_base_undefined], return_counts=True)
    size_inside = dict(zip(inside_ids.tolist(), inside_sizes.tolist()))

    labels_undefined_mask = []
    for ref_id, ref_size in zip(ref_ids.tolist(), ref_sizes.tolist()):
        # Skip the background explicitly; dropping the first unique value
        # would discard a real instance when the array has no background.
        if ref_id == 0:
            continue
        if size_inside.get(ref_id, 0) / ref_size >= min_overlap:
            labels_undefined_mask.append(ref_id)

    return labels_undefined_mask
38+
39+
40+
def find_matching_masks(arr_gt, arr_ref, out_path, labels_undefined_mask=None):
    """Match every instance of `arr_ref` to the `arr_gt` instance with the largest overlap.

    The overlap is the number of shared voxels divided by the size of the
    reference instance. Reference instances that only cover background (0)
    in `arr_gt` get no entry. The result is written as JSON to `out_path`;
    it maps each reference id to a dictionary with the keys "ref_id",
    "gt_id" and "overlap".

    Args:
        arr_gt: Label array in which the matching partner is searched.
        arr_ref: Label array whose instances are iterated. Must have the same
            shape as `arr_gt`. The value 0 is treated as background.
        out_path: Output path for saving the dictionary as JSON.
        labels_undefined_mask: Labels of the reference array to exclude.
            None (the default) excludes nothing.
    """
    excluded = set() if labels_undefined_mask is None else {int(label) for label in labels_undefined_mask}

    # Filter the background explicitly; dropping the first unique value
    # would discard a real instance when the array has no background.
    seg_ids_ref = [int(i) for i in np.unique(arr_ref) if i != 0]
    print(f"total number of segmentation masks: {len(seg_ids_ref)}")
    seg_ids_ref = [s for s in seg_ids_ref if s not in excluded]
    print(f"number of segmentation masks after filtering undefined masks: {len(seg_ids_ref)}")

    def compute_overlap(ref_id):
        """Return the best match for `ref_id`, or None if it only covers background."""
        arr_ref_instance = arr_ref == ref_id

        # The counts of the labels found inside the instance are exactly the
        # intersection sizes, so no per-candidate volume scan is needed.
        gt_ids, counts = np.unique(arr_gt[arr_ref_instance], return_counts=True)
        foreground = gt_ids != 0
        gt_ids, counts = gt_ids[foreground], counts[foreground]
        if gt_ids.size == 0:
            return None

        # argmax returns the first maximum, i.e. the smallest id among ties.
        best = int(np.argmax(counts))
        return {
            "ref_id": ref_id,
            "gt_id": int(gt_ids[best]),
            "overlap": float(counts[best] / np.sum(arr_ref_instance)),
        }

    n_threads = min(16, mp.cpu_count())
    print(f"Parallelizing with {n_threads} Threads.")
    with futures.ThreadPoolExecutor(n_threads) as pool:
        results = list(tqdm(pool.map(compute_overlap, seg_ids_ref), total=len(seg_ids_ref)))

    matching_masks = {r['ref_id']: r for r in results if r is not None}

    with open(out_path, "w") as f:
        json.dump(matching_masks, f, indent='\t', separators=(',', ': '))
93+
94+
95+
def filter_true_positives(output_folder, prefixes, force_overwrite):
    """Compute the overlap matching between prediction and ground truth in both directions.

    For every prefix, the confidence and annotation volumes are read from
    GT_DIR and the segmentation from PRED_DIR. Two JSON files are written
    to `output_folder`:

    - "<prefix>_matching_ref_gt.json": best ground truth match for every
      segmentation instance. Segmentation instances with at least half of
      their voxels inside confidence label 1 are excluded beforehand.
    - "<prefix>_matching_gt_ref.json": best segmentation match for every
      ground truth instance.

    Each file contains the matched ids and the overlap per instance.

    Args:
        output_folder: Output folder for dictionaries. Created if missing.
        prefixes: List of prefixes for evaluation. One or multiple of
            ["Drosophila", "MusMusculus", "Zebrafish"]. None selects all three.
        force_overwrite: Flag for forced overwrite of existing output files.
    """
    pred_dir = PRED_DIR
    gt_dir = GT_DIR

    if prefixes is None:
        prefixes = ["Drosophila", "MusMusculus", "Zebrafish"]

    # Writing the JSON files fails if the folder does not exist yet.
    os.makedirs(output_folder, exist_ok=True)

    for prefix in prefixes:
        conf_file = os.path.join(gt_dir, f"{prefix}_1_iitest_confidence.tif")
        annot_file = os.path.join(gt_dir, f"{prefix}_1_iitest_annotations.tif")
        conf_arr = tifffile.imread(conf_file)
        gt_arr = tifffile.imread(annot_file)

        seg_file = os.path.join(pred_dir, f"{prefix}_1_iitest_seg.tif")
        seg_arr = tifffile.imread(seg_file)

        # find largest overlap of ground truth mask with each segmentation instance
        out_path = os.path.join(output_folder, f"{prefix}_matching_ref_gt.json")
        if os.path.isfile(out_path) and not force_overwrite:
            print(f"Skipping the creation of {out_path}. File already exists.")
        else:
            # exclude detections with more than 50% of pixels in undefined category
            # Test for label 1 directly, so that it is also found when it is
            # the smallest value of the confidence volume.
            if np.any(conf_arr == 1):
                labels_undefined_mask = find_overlapping_masks(conf_arr, seg_arr, label_id_base=1)
            else:
                labels_undefined_mask = []
                print("Array does not contain undefined mask")

            find_matching_masks(gt_arr, seg_arr, out_path, labels_undefined_mask=labels_undefined_mask)

        # find largest overlap of segmentation instance with each ground truth mask
        out_path = os.path.join(output_folder, f"{prefix}_matching_gt_ref.json")
        if os.path.isfile(out_path) and not force_overwrite:
            print(f"Skipping the creation of {out_path}. File already exists.")
        else:
            find_matching_masks(seg_arr, gt_arr, out_path)
143+
144+
145+
def main():
    """Read the command line options and run the matching for the selected prefixes."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--output_folder", "-o", required=True)
    cli.add_argument("--prefix", "-p", nargs="+", type=str, default=None)
    cli.add_argument("--force", action="store_true", help="Forcefully overwrite output.")
    options = cli.parse_args()

    filter_true_positives(
        output_folder=options.output_folder,
        prefixes=options.prefix,
        force_overwrite=options.force,
    )


if __name__ == "__main__":
    main()
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import os
2+
3+
import numpy as np
4+
import tifffile
5+
6+
# Root folder of the original NIS3D data and output folders of the prepared splits.
NIS3D_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/NIS3D"
TRAIN_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D/train"
VAL_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D/val"
# The path separator before "test" was previously a "(".
TEST_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D/test"

# Each dictionary describes one sample:
#   data_dir / name  -> folder that holds the input files
#   conf_file        -> confidence volume inside that folder
#   gt_file          -> ground truth label volume inside that folder
#   output_dir / output_name -> where the filtered annotations are written

# ---Training data---

# clear: contains only 2,3,4 as seg ids
# (presumably the values of the confidence file - confirm)
train_dict_01 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/cross-image/train"),
    "name": "Drosophila_2",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": TRAIN_DIR,
    "output_name": "Drosophila_2_annotations.tif",
}

# contains 1, 2, 3, 4
train_dict_02 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/cross-image/train"),
    "name": "Zebrafish_2",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": TRAIN_DIR,
    "output_name": "Zebrafish_2_annotations.tif",
}

# contains 1, 3, 4
train_dict_03 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/cross-image/train"),
    "name": "MusMusculus_2",
    "conf_file": "scoreOfConfidence.tif",
    "gt_file": "gt.tif",
    "output_dir": TRAIN_DIR,
    "output_name": "MusMusculus_2_annotations.tif",
}

# ---Validation data---

val_dict_01 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/in-image/train"),
    "name": "Drosophila_1",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": VAL_DIR,
    "output_name": "Drosophila_1_iitrain_annotations.tif",
}

val_dict_02 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/in-image/train"),
    "name": "Zebrafish_1",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": VAL_DIR,
    "output_name": "Zebrafish_1_iitrain_annotations.tif",
}

val_dict_03 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/in-image/train"),
    "name": "MusMusculus_1",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": VAL_DIR,
    "output_name": "MusMusculus_1_iitrain_annotations.tif",
}

# ---Test data---

test_dict_01 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/in-image/test"),
    "name": "Drosophila_1",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": TEST_DIR,
    "output_name": "Drosophila_1_iitest_annotations.tif",
}

test_dict_02 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/in-image/test"),
    "name": "Zebrafish_1",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": TEST_DIR,
    "output_name": "Zebrafish_1_iitest_annotations.tif",
}

test_dict_03 = {
    "data_dir": os.path.join(NIS3D_DIR, "suggestive_splitting/in-image/test"),
    "name": "MusMusculus_1",
    "conf_file": "ConfidenceScore.tif",
    "gt_file": "GroundTruth.tif",
    "output_dir": TEST_DIR,
    "output_name": "MusMusculus_1_iitest_annotations.tif",
}
100+
101+
102+
def filter_unmasked_data(conf_path, in_path, out_path):
    """Remove annotations inside confidence label 1 and save the result.

    All voxels of the ground truth whose confidence value is 1 are set to 0.
    If the confidence volume does not contain the value 1, the ground truth
    is written unchanged. The number of instances is printed in both cases.

    Args:
        conf_path: Path to the confidence volume (tif).
        in_path: Path to the ground truth label volume (tif). It is indexed
            with a mask derived from the confidence volume, so both volumes
            are assumed to have the same shape.
        out_path: Path for the filtered label volume (tif).
    """
    conf = tifffile.imread(conf_path)
    gt = tifffile.imread(in_path)

    def count_instances(labels):
        # Count non-zero labels explicitly; dropping the first unique value
        # is off by one when the volume has no background.
        return int(np.count_nonzero(np.unique(labels)))

    # Test for label 1 directly, so that it is also found when it is the
    # smallest value of the confidence volume.
    undefined = conf == 1
    if undefined.any():
        print(f"Number of instances before filtering: {count_instances(gt)}")
        gt[undefined] = 0
        print(f"Number of instances after filtering: {count_instances(gt)}")
    else:
        print(f"Number of instances: {count_instances(gt)}")

    tifffile.imwrite(out_path, gt)
117+
118+
119+
def process_data_dicts(data_dicts):
    """Run filter_unmasked_data for every sample description.

    Args:
        data_dicts: Iterable of dictionaries with the keys "data_dir", "name",
            "conf_file", "gt_file", "output_dir" and "output_name".
    """
    for entry in data_dicts:
        # Input files are located in <data_dir>/<name>/.
        sample_dir = os.path.join(entry["data_dir"], entry["name"])
        filter_unmasked_data(
            os.path.join(sample_dir, entry["conf_file"]),
            in_path=os.path.join(sample_dir, entry["gt_file"]),
            out_path=os.path.join(entry["output_dir"], entry["output_name"]),
        )
130+
131+
132+
def prepare_training_data():
    """Prepare annotation data based on NIS3D data.

    Intended data layout: the cross-image data of half of the samples is used
    for training. The other half of the samples is divided into validation
    and test data via the in-image split, so that every remaining sample is
    split in half.

    Only the three test dictionaries are processed here; the train and
    validation dictionaries are not passed on.
    """
    test_dicts = [test_dict_01, test_dict_02, test_dict_03]
    process_data_dicts(test_dicts)

scripts/baselines/NIS3D_train.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
# Start train_distance_unet.py for the NIS3D nucleus data.

# Passed to the training script via --name.
export MODEL_NAME="nucleus_NIS3D_supervised_2025-07-17"

# Input directory, passed to the training script via -i.
export IDIR="/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/nucleus/2025-07_NIS3D"

# Folder that contains train_distance_unet.py.
export SCRIPT_DIR="/user/schilling40/u15000/flamingo-tools/scripts/training"

# Expansions are quoted so that paths with spaces or glob characters
# reach the script as single arguments.
python "$SCRIPT_DIR/train_distance_unet.py" -i "$IDIR" --name "$MODEL_NAME"
10+

0 commit comments

Comments
 (0)