Skip to content

Commit d91eccf

Browse files
Add script for overlap based merging of SGN segmentations
1 parent 93c7fd2 commit d91eccf

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import json
2+
import os
3+
from concurrent import futures
4+
5+
import numpy as np
6+
import zarr
7+
from elf.evaluation.matching import label_overlap, intersection_over_union
8+
from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path
9+
from nifty.tools import blocking
10+
from tqdm import tqdm
11+
12+
13+
def merge_segmentations(seg_a, seg_b, ids_b, offset, output_path):
14+
assert seg_a.shape == seg_b.shape
15+
16+
output_file = zarr.open(output_path, mode="a")
17+
output = output_file.create_dataset("segmentation", shape=seg_a.shape, dtype=seg_a.dtype, chunks=seg_a.chunks)
18+
blocks = blocking([0, 0, 0], seg_a.shape, seg_a.chunks)
19+
20+
def merge_block(block_id):
21+
block = blocks.getBlock(block_id)
22+
bb = tuple(slice(begin, end) for begin, end in zip(block.begin, block.end))
23+
24+
block_a = seg_a[bb]
25+
block_b = seg_b[bb]
26+
27+
insert_mask = np.isin(block_b, ids_b)
28+
if insert_mask.sum() > 0:
29+
block_b[insert_mask] += offset
30+
block_a[insert_mask] = block_b[insert_mask]
31+
32+
output[bb] = block_a
33+
34+
n_blocks = blocks.numberOfBlocks
35+
with futures.ThreadPoolExecutor(12) as tp:
36+
list(tqdm(tp.map(merge_block, range(n_blocks)), total=n_blocks, desc="Merge segmentation"))
37+
38+
39+
def get_segmentation(cochlea, seg_name, seg_key):
40+
print("Loading segmentation ...")
41+
s3 = create_s3_target()
42+
43+
content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
44+
info = json.loads(content.read())
45+
sources = info["sources"]
46+
47+
seg_source = sources[seg_name]
48+
seg_path = os.path.join(cochlea, seg_source["segmentation"]["imageData"]["ome.zarr"]["relativePath"])
49+
seg_store, _ = get_s3_path(seg_path)
50+
51+
return zarr.open(seg_store, mode="r")[seg_key]
52+
53+
54+
def merge_sgns(cochlea, name_a, name_b, overlap_threshold=0.25):
55+
# Get the two segmentations at low resolution for computing the overlaps.
56+
seg_a = get_segmentation(cochlea, seg_name=name_a, seg_key="s2")[:]
57+
seg_b = get_segmentation(cochlea, seg_name=name_b, seg_key="s2")[:]
58+
59+
# Compute the overlaps and determine which SGNs to add from SegB based on the overlap threshold.
60+
print("Compute label overlaps ...")
61+
overlap, ignore_label = label_overlap(seg_a, seg_b)
62+
overlap = intersection_over_union(overlap)
63+
cumulative_overlap = overlap[1:, :].sum(axis=0)
64+
all_ids_b = np.unique(seg_b)
65+
ids_b = all_ids_b[cumulative_overlap < overlap_threshold]
66+
offset = seg_a.max()
67+
68+
# Get the segmentations at full resolution to merge them.
69+
seg_a = get_segmentation(cochlea, seg_name=name_a, seg_key="s2")
70+
seg_b = get_segmentation(cochlea, seg_name=name_b, seg_key="s2")
71+
72+
# Write out the merged segmentations.
73+
output_folder = f"./data/{cochlea}"
74+
os.makedirs(output_folder, exist_ok=True)
75+
output_path = os.path.join(output_folder, "SGN_merged.zarr")
76+
merge_segmentations(seg_a, seg_b, ids_b, offset, output_path)
77+
78+
79+
def main():
80+
# merge_sgns(cochlea="M_AMD_N180_L", name_a="CR_SGN_v2", name_b="Ntng1_SGN_v2")
81+
merge_sgns(cochlea="M_AMD_N180_R", name_a="CR_SGN_v2", name_b="Ntng1_SGN_v2")
82+
83+
84+
if __name__ == "__main__":
85+
main()

0 commit comments

Comments
 (0)