Skip to content

Commit 451ff69

Browse files
Implement vesicle post-processing
1 parent 224ef6c commit 451ff69

File tree

4 files changed

+93
-11
lines changed

4 files changed

+93
-11
lines changed

scripts/otoferlin/automatic_processing.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import numpy as np
55
import pandas as pd
66

7+
from skimage.measure import label
8+
from skimage.segmentation import relabel_sequential
9+
710
from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances
811
from synapse_net.file_utils import read_mrc
912
from synapse_net.inference.vesicles import segment_vesicles
@@ -12,6 +15,12 @@
1215

1316
from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, get_adapted_model
1417

18+
# These are tomograms for which the sophisticated membrane processing fails.
19+
# In this case, we just select the largest boundary piece.
20+
SIMPLE_MEM_POSTPROCESSING = [
21+
"Otof_TDAKO1blockA_GridN5_2_rec.mrc", "Otof_TDAKO2blockC_GridF5_1_rec.mrc", "Otof_TDAKO2blockC_GridF5_2_rec.mrc"
22+
]
23+
1524

1625
def _get_center_crop(input_):
1726
halo_xy = (600, 600)
@@ -55,6 +64,13 @@ def process_vesicles(mrc_path, output_path, process_center_crop):
5564
f.create_dataset(key, data=segmentation, compression="gzip")
5665

5766

67+
def _simple_membrane_postprocessing(membrane_prediction):
68+
seg = label(membrane_prediction)
69+
ids, sizes = np.unique(seg, return_counts=True)
70+
ids, sizes = ids[1:], sizes[1:]
71+
return (seg == ids[np.argmax(sizes)]).astype("uint8")
72+
73+
5874
def process_ribbon_structures(mrc_path, output_path, process_center_crop):
5975
key = "segmentation/ribbon"
6076
with h5py.File(output_path, "r") as f:
@@ -78,6 +94,12 @@ def process_ribbon_structures(mrc_path, output_path, process_center_crop):
7894
return_predictions=True, n_slices_exclude=5,
7995
)
8096

97+
# The distance based post-processing for membranes fails for some tomograms.
98+
# In these cases, just choose the largest membrane piece.
99+
fname = os.path.basename(mrc_path)
100+
if fname in SIMPLE_MEM_POSTPROCESSING:
101+
segmentations["membrane"] = _simple_membrane_postprocessing(predictions["membrane"])
102+
81103
if process_center_crop:
82104
for name, seg in segmentations.items():
83105
full_seg = np.zeros(full_shape, dtype=seg.dtype)
@@ -94,6 +116,49 @@ def process_ribbon_structures(mrc_path, output_path, process_center_crop):
94116
f.create_dataset(f"prediction/{name}", data=predictions[name], compression="gzip")
95117

96118

119+
def postprocess_vesicles(mrc_path, output_path, process_center_crop):
120+
key = "segmentation/veiscles_postprocessed"
121+
with h5py.File(output_path, "r") as f:
122+
if key in f:
123+
return
124+
vesicles = f["segmentation/vesicles"][:]
125+
if process_center_crop:
126+
bb, full_shape = _get_center_crop(vesicles)
127+
vesicles = vesicles[bb]
128+
else:
129+
bb = np.s_[:]
130+
131+
ribbon = f["segmentation/ribbon"][bb]
132+
membrane = f["segmentation/membrane"][bb]
133+
134+
# Filter out small vesicle fragments.
135+
min_size = 5000
136+
ids, sizes = np.unique(vesicles, return_counts=True)
137+
ids, sizes = ids[1:], sizes[1:]
138+
filter_ids = ids[sizes < min_size]
139+
vesicles[np.isin(vesicles, filter_ids)] = 0
140+
141+
input_, voxel_size = read_mrc(mrc_path)
142+
voxel_size = tuple(voxel_size[ax] for ax in "zyx")
143+
input_ = input_[bb]
144+
145+
# Filter out all vesicles farther than 120 nm from the membrane or ribbon.
146+
max_dist = 120
147+
seg = (ribbon + membrane) > 0
148+
distances, _, _, seg_ids = measure_segmentation_to_object_distances(vesicles, seg, resolution=voxel_size)
149+
filter_ids = seg_ids[distances > max_dist]
150+
vesicles[np.isin(vesicles, filter_ids)] = 0
151+
152+
vesicles, _, _ = relabel_sequential(vesicles)
153+
154+
if process_center_crop:
155+
full_seg = np.zeros(full_shape, dtype=vesicles.dtype)
156+
full_seg[bb] = vesicles
157+
vesicles = full_seg
158+
with h5py.File(output_path, "a") as f:
159+
f.create_dataset(key, data=vesicles, compression="gzip")
160+
161+
97162
def measure_distances(mrc_path, seg_path, output_folder):
98163
result_folder = os.path.join(output_folder, "distances")
99164
if os.path.exists(result_folder):
@@ -171,20 +236,32 @@ def process_tomogram(mrc_path):
171236

172237
process_vesicles(mrc_path, output_path, process_center_crop)
173238
process_ribbon_structures(mrc_path, output_path, process_center_crop)
174-
return
175-
# TODO vesicle post-processing:
176-
# snap to boundaries?
177-
# remove vesicles in ribbon
239+
postprocess_vesicles(mrc_path, output_path, process_center_crop)
178240

179-
measure_distances(mrc_path, output_path, output_folder)
180-
assign_vesicle_pools(output_folder)
241+
# We don't need to do the analysis of the auto semgentation, it only
242+
# makes sense to do this after segmentation. I am leaving this here for
243+
# now, to move it to the files for analysis later.
244+
245+
# measure_distances(mrc_path, output_path, output_folder)
246+
# assign_vesicle_pools(output_folder)
181247

182248

183249
def main():
184250
tomograms = get_all_tomograms()
185251
for tomogram in tqdm(tomograms, desc="Process tomograms"):
186252
process_tomogram(tomogram)
187253

254+
# Update the membrane postprocessing for the tomograms where this went wrong.
255+
# for tomo in tqdm(tomograms, desc="Fix membrame postprocesing"):
256+
# if os.path.basename(tomo) not in SIMPLE_MEM_POSTPROCESSING:
257+
# continue
258+
# seg_path = get_seg_path(tomo)
259+
# with h5py.File(seg_path, "r") as f:
260+
# pred = f["prediction/membrane"][:]
261+
# seg = _simple_membrane_postprocessing(pred)
262+
# with h5py.File(seg_path, "a") as f:
263+
# f["segmentation/membrane"][:] = seg
264+
188265

189266
if __name__:
190267
main()

scripts/otoferlin/check_automatic_result.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ def main():
7373
enumerate(tomograms), total=len(tomograms), desc="Visualize automatic segmentation results"
7474
):
7575
print("Checking tomogram", tomogram)
76+
check_automatic_result(tomogram, version)
7677
# check_automatic_result(tomogram, version, segmentation_group="vesicles")
77-
check_automatic_result(tomogram, version, segmentation_group="prediction")
78+
# check_automatic_result(tomogram, version, segmentation_group="prediction")
7879

7980

8081
if __name__:

scripts/otoferlin/check_structure_postprocessing.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,18 @@ def check_structure_postprocessing(mrc_path, center_crop=True):
2929
g = f["segmentation"]
3030
for name in STRUCTURE_NAMES:
3131
segmentations[f"seg/{name}"] = g[name][bb]
32+
colormaps[name] = get_colormaps().get(name, None)
33+
3234
g = f["prediction"]
3335
for name in STRUCTURE_NAMES:
3436
predictions[f"pred/{name}"] = g[name][bb]
35-
colormaps[name] = get_colormaps().get(name, None)
3637

3738
v = napari.Viewer()
3839
v.add_image(tomogram)
3940
for name, seg in segmentations.items():
4041
v.add_labels(seg, name=name, colormap=colormaps.get(name.split("/")[1]))
41-
for name, seg in predictions.items():
42-
v.add_labels(seg, name=name, colormap=colormaps.get(name.split("/")[1]), visible=False)
42+
for name, pred in predictions.items():
43+
v.add_labels(pred, name=name, colormap=colormaps.get(name.split("/")[1]), visible=False)
4344
v.title = os.path.basename(mrc_path)
4445
napari.run()
4546

synapse_net/ground_truth/shape_refinement.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def refine_individual_vesicle_shapes(
203203
edge_map: np.ndarray,
204204
foreground_erosion: int = 4,
205205
background_erosion: int = 8,
206+
compactness: float = 0.5,
206207
) -> np.ndarray:
207208
"""Refine vesicle shapes by fitting vesicles to a boundary map.
208209
@@ -215,6 +216,8 @@ def refine_individual_vesicle_shapes(
215216
You can use `edge_filter` to compute this based on the tomogram.
216217
foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
217218
background_erosion: By how many pixels the background should be eroded in the seeds.
219+
compactness: The compactness parameter passed to the watershed function.
220+
Higher compactness leads to more regular sized vesicles.
218221
Returns:
219222
The refined vesicles.
220223
"""
@@ -250,7 +253,7 @@ def fit_vesicle(prop):
250253

251254
# Run seeded watershed to fit the shapes.
252255
seeds = fg_seed + 2 * bg_seed
253-
seg[z] = watershed(hmap[z], seeds) == 1
256+
seg[z] = watershed(hmap[z], seeds, compactness=compactness) == 1
254257

255258
# import napari
256259
# v = napari.Viewer()

0 commit comments

Comments
 (0)