44import numpy as np
55import pandas as pd
66
7+ from skimage .measure import label
8+ from skimage .segmentation import relabel_sequential
9+
710from synapse_net .distance_measurements import measure_segmentation_to_object_distances , load_distances
811from synapse_net .file_utils import read_mrc
912from synapse_net .inference .vesicles import segment_vesicles
1215
1316from common import STRUCTURE_NAMES , get_all_tomograms , get_seg_path , get_adapted_model
1417
18+ # These are tomograms for which the sophisticated membrane processing fails.
19+ # In this case, we just select the largest boundary piece.
20+ SIMPLE_MEM_POSTPROCESSING = [
21+ "Otof_TDAKO1blockA_GridN5_2_rec.mrc" , "Otof_TDAKO2blockC_GridF5_1_rec.mrc" , "Otof_TDAKO2blockC_GridF5_2_rec.mrc"
22+ ]
23+
1524
1625def _get_center_crop (input_ ):
1726 halo_xy = (600 , 600 )
@@ -55,6 +64,13 @@ def process_vesicles(mrc_path, output_path, process_center_crop):
5564 f .create_dataset (key , data = segmentation , compression = "gzip" )
5665
5766
67+ def _simple_membrane_postprocessing (membrane_prediction ):
68+ seg = label (membrane_prediction )
69+ ids , sizes = np .unique (seg , return_counts = True )
70+ ids , sizes = ids [1 :], sizes [1 :]
71+ return (seg == ids [np .argmax (sizes )]).astype ("uint8" )
72+
73+
5874def process_ribbon_structures (mrc_path , output_path , process_center_crop ):
5975 key = "segmentation/ribbon"
6076 with h5py .File (output_path , "r" ) as f :
@@ -78,6 +94,12 @@ def process_ribbon_structures(mrc_path, output_path, process_center_crop):
7894 return_predictions = True , n_slices_exclude = 5 ,
7995 )
8096
97+ # The distance based post-processing for membranes fails for some tomograms.
98+ # In these cases, just choose the largest membrane piece.
99+ fname = os .path .basename (mrc_path )
100+ if fname in SIMPLE_MEM_POSTPROCESSING :
101+ segmentations ["membrane" ] = _simple_membrane_postprocessing (predictions ["membrane" ])
102+
81103 if process_center_crop :
82104 for name , seg in segmentations .items ():
83105 full_seg = np .zeros (full_shape , dtype = seg .dtype )
@@ -94,6 +116,49 @@ def process_ribbon_structures(mrc_path, output_path, process_center_crop):
94116 f .create_dataset (f"prediction/{ name } " , data = predictions [name ], compression = "gzip" )
95117
96118
119+ def postprocess_vesicles (mrc_path , output_path , process_center_crop ):
120+ key = "segmentation/veiscles_postprocessed"
121+ with h5py .File (output_path , "r" ) as f :
122+ if key in f :
123+ return
124+ vesicles = f ["segmentation/vesicles" ][:]
125+ if process_center_crop :
126+ bb , full_shape = _get_center_crop (vesicles )
127+ vesicles = vesicles [bb ]
128+ else :
129+ bb = np .s_ [:]
130+
131+ ribbon = f ["segmentation/ribbon" ][bb ]
132+ membrane = f ["segmentation/membrane" ][bb ]
133+
134+ # Filter out small vesicle fragments.
135+ min_size = 5000
136+ ids , sizes = np .unique (vesicles , return_counts = True )
137+ ids , sizes = ids [1 :], sizes [1 :]
138+ filter_ids = ids [sizes < min_size ]
139+ vesicles [np .isin (vesicles , filter_ids )] = 0
140+
141+ input_ , voxel_size = read_mrc (mrc_path )
142+ voxel_size = tuple (voxel_size [ax ] for ax in "zyx" )
143+ input_ = input_ [bb ]
144+
145+ # Filter out all vesicles farther than 120 nm from the membrane or ribbon.
146+ max_dist = 120
147+ seg = (ribbon + membrane ) > 0
148+ distances , _ , _ , seg_ids = measure_segmentation_to_object_distances (vesicles , seg , resolution = voxel_size )
149+ filter_ids = seg_ids [distances > max_dist ]
150+ vesicles [np .isin (vesicles , filter_ids )] = 0
151+
152+ vesicles , _ , _ = relabel_sequential (vesicles )
153+
154+ if process_center_crop :
155+ full_seg = np .zeros (full_shape , dtype = vesicles .dtype )
156+ full_seg [bb ] = vesicles
157+ vesicles = full_seg
158+ with h5py .File (output_path , "a" ) as f :
159+ f .create_dataset (key , data = vesicles , compression = "gzip" )
160+
161+
97162def measure_distances (mrc_path , seg_path , output_folder ):
98163 result_folder = os .path .join (output_folder , "distances" )
99164 if os .path .exists (result_folder ):
@@ -171,20 +236,32 @@ def process_tomogram(mrc_path):
171236
172237 process_vesicles (mrc_path , output_path , process_center_crop )
173238 process_ribbon_structures (mrc_path , output_path , process_center_crop )
174- return
175- # TODO vesicle post-processing:
176- # snap to boundaries?
177- # remove vesicles in ribbon
239+ postprocess_vesicles (mrc_path , output_path , process_center_crop )
178240
179- measure_distances (mrc_path , output_path , output_folder )
180- assign_vesicle_pools (output_folder )
241+ # We don't need to do the analysis of the auto semgentation, it only
242+ # makes sense to do this after segmentation. I am leaving this here for
243+ # now, to move it to the files for analysis later.
244+
245+ # measure_distances(mrc_path, output_path, output_folder)
246+ # assign_vesicle_pools(output_folder)
181247
182248
183249def main ():
184250 tomograms = get_all_tomograms ()
185251 for tomogram in tqdm (tomograms , desc = "Process tomograms" ):
186252 process_tomogram (tomogram )
187253
254+ # Update the membrane postprocessing for the tomograms where this went wrong.
255+ # for tomo in tqdm(tomograms, desc="Fix membrame postprocesing"):
256+ # if os.path.basename(tomo) not in SIMPLE_MEM_POSTPROCESSING:
257+ # continue
258+ # seg_path = get_seg_path(tomo)
259+ # with h5py.File(seg_path, "r") as f:
260+ # pred = f["prediction/membrane"][:]
261+ # seg = _simple_membrane_postprocessing(pred)
262+ # with h5py.File(seg_path, "a") as f:
263+ # f["segmentation/membrane"][:] = seg
264+
188265
189266if __name__ :
190267 main ()
0 commit comments