Skip to content

Commit cab37c0

Browse files
Implement splitting of non-convex objects
1 parent d748003 commit cab37c0

File tree

2 files changed

+247
-1
lines changed

2 files changed

+247
-1
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import math
22
import multiprocessing as mp
3+
import threading
34
from concurrent import futures
4-
from typing import Callable, List, Optional, Tuple
5+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
56

67
import elf.parallel as parallel
78
import numpy as np
@@ -15,6 +16,9 @@
1516
from scipy.spatial import distance
1617
from scipy.spatial import cKDTree, ConvexHull
1718
from skimage import measure
19+
from skimage.filters import gaussian
20+
from skimage.feature import peak_local_max
21+
from skimage.segmentation import find_boundaries, watershed
1822
from sklearn.neighbors import NearestNeighbors
1923
from tqdm import tqdm
2024

@@ -734,3 +738,116 @@ def filter_cochlea_volume(
734738
combined_dilated[combined_dilated > 0] = 1
735739

736740
return combined_dilated
741+
742+
743+
def split_nonconvex_objects(
744+
segmentation: np.typing.ArrayLike,
745+
output: np.typing.ArrayLike,
746+
segmentation_table: pd.DataFrame,
747+
min_size: int,
748+
resolution: Union[float, Sequence[float]],
749+
height_map: Optional[np.typing.ArrayLike] = None,
750+
component_labels: Optional[List[int]] = None,
751+
n_threads: Optional[int] = None,
752+
) -> Dict[int, List[int]]:
753+
"""Split noncovex objects into multiple parts inplace.
754+
755+
Args:
756+
segmentation:
757+
output:
758+
segmentation_table:
759+
min_size:
760+
resolution:
761+
height_map:
762+
component_labels:
763+
n_threads:
764+
"""
765+
if isinstance(resolution, float):
766+
resolution = [resolution] * 3
767+
assert len(resolution) == 3
768+
resolution = np.array(resolution)
769+
770+
lock = threading.Lock()
771+
offset = len(segmentation_table)
772+
773+
def split_object(object_id):
774+
nonlocal offset
775+
776+
row = segmentation_table[segmentation_table.label_id == object_id]
777+
if min_size and row.n_pixels.values[0] < min_size:
778+
return [object_id]
779+
780+
bb_min = np.array([
781+
row.bb_min_z.values[0], row.bb_min_y.values[0], row.bb_min_x.values[0],
782+
]) / resolution
783+
bb_max = np.array([
784+
row.bb_max_z.values[0], row.bb_max_y.values[0], row.bb_max_x.values[0],
785+
]) / resolution
786+
787+
bb_min = np.maximum(bb_min.astype(int) - 1, np.array([0, 0, 0]))
788+
bb_max = np.minimum(bb_max.astype(int) + 1, np.array(list(segmentation.shape)))
789+
bb = tuple(slice(mi, ma) for mi, ma in zip(bb_min, bb_max))
790+
791+
seg = segmentation[bb]
792+
mask = ~find_boundaries(seg)
793+
dist = distance_transform_edt(mask, sampling=resolution)
794+
795+
seg_mask = seg == object_id
796+
dist[~seg_mask] = 0
797+
dist = gaussian(dist, (0.6, 1.2, 1.2))
798+
maxima = peak_local_max(dist, min_distance=3, exclude_border=True)
799+
800+
if len(maxima) == 1:
801+
return [object_id]
802+
803+
with lock:
804+
old_offset = offset
805+
offset += len(maxima)
806+
807+
seeds = np.zeros(seg.shape, dtype=int)
808+
for i, pos in enumerate(maxima, 1):
809+
seeds[tuple(pos)] = old_offset + i
810+
811+
if height_map is None:
812+
hmap = dist.max() - dist
813+
else:
814+
hmap = height_map[bb]
815+
new_seg = watershed(hmap, markers=seeds, mask=seg_mask)
816+
817+
seg_ids, sizes = np.unique(new_seg, return_counts=True)
818+
seg_ids, sizes = seg_ids[1:], sizes[1:]
819+
820+
keep_ids = seg_ids[sizes > min_size]
821+
if len(keep_ids) < 2:
822+
return [object_id]
823+
824+
elif len(keep_ids) != len(seg_ids):
825+
new_seg[~np.isin(new_seg, keep_ids)] = 0
826+
new_seg = watershed(hmap, markers=new_seg, mask=seg_mask)
827+
828+
output[bb][seg_mask] = new_seg[seg_mask]
829+
return seg_ids.tolist()
830+
831+
# import napari
832+
# v = napari.Viewer()
833+
# v.add_image(hmap)
834+
# v.add_labels(seg)
835+
# v.add_labels(new_seg)
836+
# v.add_points(maxima)
837+
# napari.run()
838+
839+
if component_labels is None:
840+
object_ids = segmentation_table.label_id.values
841+
else:
842+
object_ids = segmentation_table[segmentation_table.isin(component_labels)].label_id.values
843+
844+
if n_threads is None:
845+
n_threads = mp.cpu_count()
846+
847+
with futures.ThreadPoolExecutor(n_threads) as tp:
848+
new_id_mapping = list(
849+
tqdm(tp.map(split_object, object_ids), total=len(object_ids), desc="Split non-convex objects")
850+
)
851+
852+
new_id_mapping = {object_id: mapped_ids for object_id, mapped_ids in zip(object_ids, new_id_mapping)}
853+
return new_id_mapping
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import os
2+
from glob import glob
3+
4+
import napari
5+
import numpy as np
6+
import imageio.v3 as imageio
7+
import vigra
8+
9+
from skimage.filters import gaussian
10+
from skimage.segmentation import find_boundaries, watershed
11+
from scipy.ndimage import distance_transform_edt
12+
from skimage.feature import peak_local_max
13+
from skimage.measure import regionprops, label
14+
15+
16+
def _size_filter(segmentation, heightmap, min_size):
17+
ids, sizes = np.unique(segmentation, return_counts=True)
18+
discard_ids = ids[sizes < min_size]
19+
mask = segmentation > 0
20+
segmentation[np.isin(segmentation, discard_ids)] = 0
21+
return watershed(heightmap, markers=segmentation, mask=mask)
22+
23+
24+
def postproc(image, segmentation, view=False):
25+
# First get rid of small objects.
26+
min_size = 250
27+
heightmap = vigra.filters.laplacianOfGaussian(image.astype("float32"), 3)
28+
29+
segmentation = _size_filter(segmentation, heightmap, min_size)
30+
31+
mask = ~find_boundaries(segmentation)
32+
dist = distance_transform_edt(mask, sampling=(2, 1, 1))
33+
dist[segmentation == 0] = 0
34+
dist = gaussian(dist, (0.6, 1.2, 1.2))
35+
maxima = peak_local_max(dist, min_distance=3, exclude_border=False)
36+
37+
maxima_image = np.zeros(segmentation.shape, dtype="uint8")
38+
pos = tuple(maxima[:, i] for i in range(3))
39+
maxima_image[pos] = 1
40+
maxima_image = label(maxima_image)
41+
42+
def maxima_ids(seg, im):
43+
ids = np.unique(im[seg])
44+
return ids[1:]
45+
46+
seed_maxima_ids, keep_seg_ids, split_seg_ids = [], [], []
47+
props = regionprops(segmentation, maxima_image, extra_properties=[maxima_ids])
48+
for prop in props:
49+
this_maxima_ids = prop.maxima_ids
50+
if len(this_maxima_ids) == 1:
51+
keep_seg_ids.append(prop.label)
52+
continue
53+
seed_maxima_ids.extend(this_maxima_ids.tolist())
54+
split_seg_ids.append(prop.label)
55+
56+
split_mask = np.isin(segmentation, split_seg_ids)
57+
# segmentation[split_mask] = 0
58+
59+
new_seeds = maxima_image.copy()
60+
new_seeds[~np.isin(maxima_image, seed_maxima_ids)] = 0
61+
new_seg = watershed(heightmap, markers=new_seeds, mask=split_mask)
62+
63+
segmentation[split_mask] = 0
64+
offset = segmentation.max()
65+
new_seg[new_seg != 0] += offset
66+
segmentation[split_mask] = new_seg[split_mask]
67+
segmentation = label(segmentation)
68+
segmentation = _size_filter(segmentation, heightmap, min_size)
69+
70+
if view:
71+
v = napari.Viewer()
72+
v.add_image(image)
73+
v.add_labels(segmentation)
74+
# v.add_labels(new_seg)
75+
# v.add_image(heightmap)
76+
# v.add_image(dist)
77+
# v.add_points(maxima)
78+
# v.add_labels(split_mask)
79+
napari.run()
80+
81+
return segmentation
82+
83+
84+
def postprocess_volume(im_path, seg_path, out_root):
85+
image = imageio.imread(im_path)
86+
segmentation = imageio.imread(seg_path)
87+
segmentation = postproc(image, segmentation, view=True)
88+
89+
os.makedirs(out_root, exist_ok=True)
90+
fname = os.path.basename(im_path)
91+
imageio.imwrite(os.path.join(out_root, fname), segmentation, compression="zlib")
92+
93+
94+
def postprocess_volume_scalable(im_path, seg_path, out_root):
95+
from flamingo_tools.segmentation.postprocessing import split_nonconvex_objects, compute_table_on_the_fly
96+
97+
image = imageio.imread(im_path)
98+
segmentation = imageio.imread(seg_path)
99+
100+
# TODO aniso resolution
101+
resolution = 0.38
102+
table = compute_table_on_the_fly(segmentation, resolution)
103+
104+
out = np.zeros_like(segmentation)
105+
id_mapping = split_nonconvex_objects(segmentation, out, table, n_threads=1, resolution=resolution, min_size=250)
106+
n_prev = len(id_mapping)
107+
n_after = sum([len(v) for v in id_mapping.values()])
108+
print("Before splitting:", n_prev)
109+
print("After splitting:", n_after)
110+
111+
v = napari.Viewer()
112+
v.add_image(image)
113+
v.add_labels(segmentation, visible=False)
114+
v.add_labels(out)
115+
napari.run()
116+
117+
118+
def main():
119+
im_paths = sorted(glob("la-vision-sgn-new/images/*.tif"))
120+
seg_paths = sorted(glob("la-vision-sgn-new/segmentation/*.tif"))
121+
out_root = "la-vision-sgn-new/segmentation-postprocessed"
122+
for im_path, seg_path in zip(im_paths, seg_paths):
123+
# postprocess_volume(im_path, seg_path, out_root)
124+
postprocess_volume_scalable(im_path, seg_path, out_root)
125+
break
126+
127+
128+
if __name__ == "__main__":
129+
main()

0 commit comments

Comments
 (0)