
Commit 39f28d8

Update synapse validation
1 parent ecebe20 commit 39f28d8

File tree

5 files changed: +355 −56 lines changed


flamingo_tools/segmentation/synapse_detection.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def map_and_filter_detections(
     Args:
         segmentation: The IHC segmentation.
         detections: The synapse marker detections.
-        max_distance: The maximal distance for a valid match of synapse markers to IHCs.
+        max_distance: The maximal distance in micrometer for a valid match of synapse markers to IHCs.
         resolution: The resolution / voxel size of the data in micrometer.
         n_threads: The number of threads for parallelizing the mapping of detections to objects.
         verbose: Whether to print the progress of the mapping procedure.
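Note: max_distance is now documented in micrometer. A minimal usage sketch of the call (the toy segmentation array and the spot_id/z/y/x detection table layout are assumptions taken from the validation script further below, not part of this commit):

import numpy as np
import pandas as pd
from flamingo_tools.segmentation.synapse_detection import map_and_filter_detections

# Toy IHC label image and two synapse marker detections (all values are illustrative).
segmentation = np.zeros((16, 64, 64), dtype="uint32")
segmentation[4:12, 10:30, 10:30] = 1  # one IHC with label id 1
detections = pd.DataFrame({"spot_id": [1, 2], "z": [8.0, 8.0], "y": [20.0, 20.0], "x": [25.0, 60.0]})

# Keep only detections within 5 micrometer of an IHC, mirroring the validation script below.
filtered = map_and_filter_detections(segmentation, detections, max_distance=5)
print(filtered)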

scripts/synapse_marker_detection/extract_training_data.py

Lines changed: 91 additions & 16 deletions
@@ -3,6 +3,7 @@
 from pathlib import Path

 import h5py
+import imageio.v3 as imageio
 import napari
 import numpy as np
 import pandas as pd
@@ -19,34 +20,61 @@ def get_voxel_size(imaris_file):
     return vsize


-def extract_training_data(imaris_file, output_folder, crop=True, scale=True):
+def get_transformation(imaris_file):
+    with h5py.File(imaris_file) as f:
+        info = f["DataSetInfo"]["Image"].attrs
+        ext_min = np.array([float(b"".join(info[f"ExtMin{i}"]).decode()) for i in range(3)])
+        ext_max = np.array([float(b"".join(info[f"ExtMax{i}"]).decode()) for i in range(3)])
+        size = [int(b"".join(info[dim]).decode()) for dim in ["X", "Y", "Z"]]
+        spacing = (ext_max - ext_min) / size  # µm / voxel
+
+    # build 4×4 affine: world → index
+    T = np.eye(4)
+    T[:3, :3] = np.diag(1/spacing)  # scale
+    T[:3, 3] = -ext_min/spacing  # translate
+
+    return T
+
+
+def extract_training_data(imaris_file, output_folder, tif_file=None, crop=True):
     point_key = "/Scene/Content/Points0/CoordsXYZR"
     with h5py.File(imaris_file, "r") as f:
         if point_key not in f:
             print("Skipping", imaris_file, "due to missing annotations")
             return
-        data = f["/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data"][:]
         points = f[point_key][:]
         points = points[:, :-1]
-        points = points[:, ::-1]

-    # TODO crop the data to the original shape.
-    # Can we just crop the zero-padding ?!
+        g = f["/DataSet/ResolutionLevel 0/TimePoint 0"]
+        # The first channel is ctbp2 / the synapse marker channel.
+        data = g["Channel 0/Data"][:]
+        # The second channel is vglut / the ihc channel.
+        if "Channel 1" in g:
+            ihc_data = g["Channel 1/Data"][:]
+        else:
+            ihc_data = None
+
+    T = get_transformation(imaris_file)
+    points = (T @ np.c_[points, np.ones(len(points))].T).T[:, :3]
+    points = points[:, ::-1]
+
     if crop:
         crop_box = np.where(data != 0)
         crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box)
         data = data[crop_box]

-    # Scale the points to match the image dimensions.
-    voxel_size = get_voxel_size(imaris_file)
-    if scale:
-        points /= voxel_size[None]
-
-    print(data.shape, voxel_size)
+    if tif_file is None:
+        original_data = None
+    else:
+        original_data = imageio.imread(tif_file)

     if output_folder is None:
         v = napari.Viewer()
         v.add_image(data)
+        if ihc_data is not None:
+            v.add_image(ihc_data)
+        if original_data is not None:
+            v.add_image(original_data, visible=False)
         v.add_points(points)
         v.title = os.path.basename(imaris_file)
         napari.run()
@@ -66,6 +94,8 @@ def extract_training_data(imaris_file, output_folder, crop=True, scale=True):

     f = zarr.open(image_file, "a")
     f.create_dataset("raw", data=data)
+    if ihc_data is not None:
+        f.create_dataset("raw_ihc", data=ihc_data)


 # Files that look good for training:
@@ -82,6 +112,21 @@ def process_training_data_v1():
         extract_training_data(ff, output_folder="./training_data")


+def _match_tif(imaris_file):
+    folder = os.path.split(imaris_file)[0]
+
+    fname = os.path.basename(imaris_file)
+    parts = fname.split("_")
+    cochlea = parts[0].upper()
+    region = parts[1]
+
+    tif_name = f"{cochlea}_{region}_CTBP2.tif"
+    tif_path = os.path.join(folder, tif_name)
+    assert os.path.exists(tif_path), tif_path
+
+    return tif_path
+
+
 def process_training_data_v2(visualize=True):
     input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses"

@@ -110,16 +155,46 @@ def process_training_data_v2(visualize=True):

         imaris_files = sorted(glob(os.path.join(input_root, folder, "*.ims")))
         for imaris_file in imaris_files:
-            fname = os.path.basename(imaris_file)
-            if fname not in valid_files:
+            if os.path.basename(imaris_file) not in valid_files:
+                continue
+            extract_training_data(imaris_file, output_folder, tif_file=None, crop=True, scale=True)
+
+
+# We have fixed the imaris data extraction problem and can use all the crops!
+def process_training_data_v3(visualize=True):
+    input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses"
+
+    train_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v3"  # noqa
+    test_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3"  # noqa
+
+    train_folders = ["synapse_stains", "M78L_IHC-synapse_crops", "M226R_IHC-synapsecrops"]
+    test_folders = ["M226L_IHC-synapse_crops"]
+
+    exclude_names = ["220824_Ex3IL_rbCAST1635_mCtBP2580_chCR488_cell1_CtBP2spots.ims"]
+
+    for folder in train_folders + test_folders:
+
+        if visualize:
+            output_folder = None
+        elif folder in train_folders:
+            output_folder = train_output
+            os.makedirs(output_folder, exist_ok=True)
+        else:
+            output_folder = test_output
+            os.makedirs(output_folder, exist_ok=True)
+
+        imaris_files = sorted(glob(os.path.join(input_root, folder, "*.ims")))
+        for imaris_file in imaris_files:
+            if os.path.basename(imaris_file) in exclude_names:
+                print("Skipping", imaris_file)
                 continue
-            print(fname)
-            extract_training_data(imaris_file, output_folder, crop=True, scale=True)
+            extract_training_data(imaris_file, output_folder, tif_file=None, crop=True)


 def main():
     # process_training_data_v1()
-    process_training_data_v2(visualize=False)
+    # process_training_data_v2(visualize=True)
+    process_training_data_v3(visualize=False)


 if __name__ == "__main__":
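Note: get_transformation replaces the old voxel-size scaling. It reads the world extent (ExtMin/ExtMax, in micrometer) and the image size from the Imaris metadata and builds a 4×4 world-to-index affine, which extract_training_data then applies to the XYZR point annotations. A self-contained sketch of the same arithmetic with made-up extents (illustrative numbers only, not taken from the data):

import numpy as np

# Illustrative Imaris metadata: world extent in micrometer and image size in voxels.
ext_min = np.array([0.0, 0.0, 0.0])
ext_max = np.array([116.0, 116.0, 40.0])
size = np.array([1024, 1024, 200])
spacing = (ext_max - ext_min) / size  # micrometer per voxel

# world -> index affine, built as in get_transformation
T = np.eye(4)
T[:3, :3] = np.diag(1 / spacing)  # scale
T[:3, 3] = -ext_min / spacing     # translate

# Map XYZ world coordinates (e.g. the CoordsXYZR points) to voxel indices.
points_world = np.array([[58.0, 58.0, 20.0]])
points_index = (T @ np.c_[points_world, np.ones(len(points_world))].T).T[:, :3]
print(points_index)  # [[512. 512. 100.]]
# The script then reverses the axis order (points[:, ::-1]) to go from XYZ to ZYX.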

scripts/synapse_marker_detection/train_synapse_detection.py

Lines changed: 17 additions & 24 deletions
@@ -1,49 +1,42 @@
 import os
 import sys
+from glob import glob

+from sklearn.model_selection import train_test_split
 from detection_dataset import DetectionDataset, MinPointSampler

 sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
 sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")

 from utils.training.training import supervised_training  # noqa

-ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v2"  # noqa
+ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v3"  # noqa
 TRAIN_ROOT = os.path.join(ROOT, "images")
 LABEL_ROOT = os.path.join(ROOT, "labels")


 def get_paths(split):
-    file_names = [
-        "4.1L_apex_IHCribboncount_Z",
-        "4.1L_base_IHCribbons_Z",
-        "4.1L_mid_IHCribboncount_Z",
-        "4.2R_apex_IHCribboncount_Z",
-        "4.2R_apex_IHCribboncount_Z",
-        "6.2R_apex_IHCribboncount_Z",
-        "m78l_apexp2718_cr-ctbp2",
-        "6.2R_base_IHCribbons_Z",
-    ]
-    image_paths = [os.path.join(TRAIN_ROOT, f"{fname}.zarr") for fname in file_names]
-    label_paths = [os.path.join(LABEL_ROOT, f"{fname}.csv") for fname in file_names]
+    image_paths = sorted(glob(os.path.join(TRAIN_ROOT, "*.zarr")))
+    label_paths = sorted(glob(os.path.join(LABEL_ROOT, "*.csv")))
+    assert len(image_paths) == len(label_paths)
+
+    train_images, val_images, train_labels, val_labels = train_test_split(
+        image_paths, label_paths, test_size=2, random_state=42
+    )

     if split == "train":
-        image_paths = image_paths[:-1]
-        label_paths = label_paths[:-1]
+        image_paths = train_images
+        label_paths = train_labels
     else:
-        image_paths = image_paths[-1:]
-        label_paths = label_paths[-1:]
-
-    for path in image_paths:
-        assert os.path.exists(path), path
+        image_paths = val_images
+        label_paths = val_labels

     return image_paths, label_paths


-# TODO maybe add a sampler for the label data
 def train():

-    model_name = "synapse_detection_v2"
+    model_name = "synapse_detection_v3"

     train_paths, train_label_paths = get_paths("train")
     val_paths, val_label_paths = get_paths("val")
@@ -56,7 +49,7 @@ def train():

     patch_shape = [40, 112, 112]
     batch_size = 32
-    check = True
+    check = False

     supervised_training(
         name=model_name,
@@ -81,7 +74,7 @@ def train():
         dataset_class=DetectionDataset,
         n_samples_train=3200,
         n_samples_val=160,
-        sampler=MinPointSampler(min_points=1, p_reject=0.6),
+        sampler=MinPointSampler(min_points=1, p_reject=0.8),
     )


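Note: get_paths now discovers all volumes in the v3 training folder and uses sklearn's train_test_split with an integer test_size, so exactly two volumes are held out for validation while the image/label pairing stays aligned. A small sketch with placeholder path lists (the file names are made up):

from sklearn.model_selection import train_test_split

# Placeholder stand-ins for the sorted .zarr / .csv path lists.
image_paths = [f"images/crop_{i}.zarr" for i in range(8)]
label_paths = [f"labels/crop_{i}.csv" for i in range(8)]

train_images, val_images, train_labels, val_labels = train_test_split(
    image_paths, label_paths, test_size=2, random_state=42
)
print(len(train_images), len(val_images))  # 6 2
# The same indices are held out from both lists, so each validation image keeps its label file.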

scripts/validation/synapses/prediction.py

Lines changed: 76 additions & 14 deletions
@@ -10,15 +10,16 @@
 from elf.parallel.local_maxima import find_local_maxima
 from flamingo_tools.segmentation.unet_prediction import prediction_impl, run_unet_prediction

-INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v2/images"  # noqa
+INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3/images"  # noqa
+GT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_data/v3/labels"
 OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/SynapseValidation"

 sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
 sys.path.append("../../synapse_marker_detection")


 def pred_synapse_impl(input_path, output_folder):
-    model_path = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/synapse_marker_detection/checkpoints/synapse_detection_v2"  # noqa
+    model_path = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/synapse_marker_detection/checkpoints/synapse_detection_v3"  # noqa
     input_key = "raw"

     block_shape = (32, 128, 128)
@@ -48,7 +49,7 @@ def pred_synapse_impl(input_path, output_folder):


 def predict_synapses():
-    files = glob(os.path.join(INPUT_ROOT, "*.zarr"))
+    files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
     for ff in files:
         print("Segmenting", ff)
         output_folder = os.path.join(OUTPUT_ROOT, Path(ff).stem)
@@ -59,34 +60,95 @@ def pred_ihc_impl(input_path, output_folder):
     model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/IHC/v2_cochlea_distance_unet_IHC_supervised_2025-05-21"  # noqa

     run_unet_prediction(
-        input_path, input_key=None, output_folder=output_folder, model_path=model_path, min_size=1000,
+        input_path, input_key="raw_ihc", output_folder=output_folder, model_path=model_path, min_size=1000,
         seg_class="ihc", center_distance_threshold=0.5, boundary_distance_threshold=0.5,
     )


 def predict_ihcs():
-    files = [
-        "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses/M226R_IHC-synapsecrops/M226R_base_p800_Vglut3.tif",  # noqa
-        "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses/M226R_IHC-synapsecrops/M226R_apex_p1268_Vglut3.tif",  # noqa
-    ]
+    files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
     for ff in files:
         print("Segmenting", ff)
-        output_folder = os.path.join(OUTPUT_ROOT, Path(ff).stem)
+        output_folder = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc")
         pred_ihc_impl(ff, output_folder)


-# TODO also filter GT
+def _filter_synapse_impl(detections, ihc_file, output_path):
+    from flamingo_tools.segmentation.synapse_detection import map_and_filter_detections
+
+    with open_file(ihc_file, mode="r") as f:
+        if "segmentation_filtered" in f:
+            print("Using filtered segmentation!")
+            segmentation = open_file(ihc_file)["segmentation_filtered"][:]
+        else:
+            segmentation = open_file(ihc_file)["segmentation"][:]
+
+    max_distance = 5  # 5 micrometer
+    filtered_detections = map_and_filter_detections(segmentation, detections, max_distance=max_distance)
+    filtered_detections.to_csv(output_path, index=False, sep="\t")
+
+
 def filter_synapses():
-    pass
+    input_files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    for ff in input_files:
+        ihc = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc", "segmentation.zarr")
+        output_folder = os.path.join(OUTPUT_ROOT, Path(ff).stem)
+        synapses = os.path.join(output_folder, "synapse_detection.tsv")
+        synapses = pd.read_csv(synapses, sep="\t")
+        output_path = os.path.join(output_folder, "filtered_synapse_detection.tsv")
+        _filter_synapse_impl(synapses, ihc, output_path)
+
+
+def filter_gt():
+    input_files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    gt_files = sorted(glob(os.path.join(GT_ROOT, "*.csv")))
+    for ff, gt in zip(input_files, gt_files):
+        ihc = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc", "segmentation.zarr")
+        output_folder, fname = os.path.split(gt)
+        output_path = os.path.join(output_folder, fname.replace(".csv", "_filtered.tsv"))
+
+        gt = pd.read_csv(gt)
+        gt = gt.rename(columns={"axis-0": "z", "axis-1": "y", "axis-2": "x"})
+        gt.insert(0, "spot_id", np.arange(1, len(gt) + 1))
+
+        _filter_synapse_impl(gt, ihc, output_path)
+
+
+def _check_prediction(input_file, ihc_file, detection_file):
+    import napari
+
+    synapses = pd.read_csv(detection_file, sep="\t")[["z", "y", "x"]].values
+
+    vglut = open_file(input_file)["raw_ihc"][:]
+    ctbp2 = open_file(input_file)["raw"][:]
+    ihcs = open_file(ihc_file)["segmentation"][:]
+
+    v = napari.Viewer()
+    v.add_image(vglut)
+    v.add_image(ctbp2)
+    v.add_labels(ihcs)
+    v.add_points(synapses)
+    napari.run()


 def check_predictions():
-    pass
+    input_files = sorted(glob(os.path.join(INPUT_ROOT, "*.zarr")))
+    for ff in input_files:
+        ihc = os.path.join(OUTPUT_ROOT, f"{Path(ff).stem}_ihc", "segmentation.zarr")
+        synapses = os.path.join(OUTPUT_ROOT, Path(ff).stem, "filtered_synapse_detection.tsv")
+        _check_prediction(ff, ihc, synapses)


-def main():
-    # predict_synapses()
+def process_everything():
+    predict_synapses()
     predict_ihcs()
+    filter_synapses()
+    filter_gt()
+
+
+def main():
+    process_everything()
+    # check_predictions()


 if __name__ == "__main__":
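Note: filter_gt brings the napari point annotations into the same table format as the synapse detections before the distance filtering: the axis-0/1/2 columns are renamed to z/y/x and a 1-based spot_id column is inserted. A short pandas sketch of just that preparation step (the coordinate values are made up):

import numpy as np
import pandas as pd

# Toy napari points export; axis-0/1/2 correspond to z/y/x in voxel coordinates.
gt = pd.DataFrame({"axis-0": [10.0, 12.0], "axis-1": [55.0, 60.0], "axis-2": [40.0, 41.0]})

# Same preparation as in filter_gt.
gt = gt.rename(columns={"axis-0": "z", "axis-1": "y", "axis-2": "x"})
gt.insert(0, "spot_id", np.arange(1, len(gt) + 1))
print(gt)  # columns: spot_id, z, y, x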
