Skip to content

Commit 03ec955

Browse files
Add segmentation and measurement scripts for SGN stainings
1 parent 4cf9e41 commit 03ec955

File tree

4 files changed

+136
-1
lines changed

4 files changed

+136
-1
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def run_unet_prediction(
324324
scale: Optional[float] = None,
325325
block_shape: Optional[Tuple[int, int, int]] = None,
326326
halo: Optional[Tuple[int, int, int]] = None,
327+
use_mask: bool = True,
327328
) -> None:
328329
"""Run prediction and segmentation with a distance U-Net.
329330
@@ -337,10 +338,12 @@ def run_unet_prediction(
337338
By default the data will not be rescaled.
338339
block_shape: The block-shape for running the prediction.
339340
halo: The halo (= block overlap) to use for prediction.
341+
use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
340342
"""
341343
os.makedirs(output_folder, exist_ok=True)
342344

343-
find_mask(input_path, input_key, output_folder)
345+
if use_mask:
346+
find_mask(input_path, input_key, output_folder)
344347

345348
original_shape = prediction_impl(
346349
input_path, input_key, output_folder, model_path, scale, block_shape, halo
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
from glob import glob
3+
4+
import imageio.v3 as imageio
5+
import napari
6+
7+
8+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
9+
SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations"
10+
11+
12+
def main():
13+
files = sorted(glob(os.path.join(ROOT, "**/*.tif")))
14+
for ff in files:
15+
if "segmentations" in ff:
16+
return
17+
print("Visualizing", ff)
18+
rel_path = os.path.relpath(ff, ROOT)
19+
seg_path = os.path.join(SAVE_ROOT, rel_path)
20+
21+
image = imageio.imread(ff)
22+
if os.path.exists(seg_path):
23+
seg = imageio.imread(seg_path)
24+
else:
25+
seg = None
26+
27+
v = napari.Viewer()
28+
v.add_image(image)
29+
if seg is not None:
30+
v.add_labels(seg)
31+
napari.run()
32+
33+
34+
if __name__ == "__main__":
35+
main()
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
from glob import glob
3+
4+
import tifffile
5+
from flamingo_tools.measurements import compute_object_measures_impl
6+
7+
8+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
9+
SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations"
10+
11+
12+
def measure_intensities(ff):
13+
rel_path = os.path.relpath(ff, ROOT)
14+
out_path = os.path.join("./measurements", rel_path.replace(".tif", ".xlsx"))
15+
if os.path.exists(out_path):
16+
return
17+
18+
print("Computing measurements for", rel_path)
19+
seg_path = os.path.join(SAVE_ROOT, rel_path)
20+
21+
image_data = tifffile.memmap(ff)
22+
seg_data = tifffile.memmap(seg_path)
23+
24+
table = compute_object_measures_impl(image_data, seg_data, n_threads=8)
25+
26+
os.makedirs(os.path.split(out_path)[0], exist_ok=True)
27+
table.to_excel(out_path, index=False)
28+
29+
30+
def main():
31+
files = sorted(glob(os.path.join(ROOT, "**/*.tif")))
32+
for ff in files:
33+
if "segmentations" in ff:
34+
return
35+
measure_intensities(ff)
36+
37+
38+
if __name__ == "__main__":
39+
main()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import tempfile
3+
from glob import glob
4+
5+
import tifffile
6+
from elf.io import open_file
7+
from flamingo_tools.segmentation import run_unet_prediction
8+
9+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops"
10+
MODEL_PATH = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_SGN_March2025Model" # noqa
11+
12+
SAVE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/LS_sampleprepcomparison_crops/segmentations"
13+
14+
15+
def check_data():
16+
files = glob(os.path.join(ROOT, "**/*.tif"), recursive=True)
17+
for ff in files:
18+
rel_path = sorted(os.path.relpath(ff, ROOT))
19+
shape = tifffile.memmap(ff).shape
20+
print(rel_path, shape)
21+
22+
23+
def segment_crop(input_file):
24+
fname = os.path.relpath(input_file, ROOT)
25+
out_file = os.path.join(SAVE_ROOT, fname)
26+
if "segmentations" in input_file:
27+
return
28+
if os.path.exists(out_file):
29+
return
30+
31+
print("Run prediction for", input_file)
32+
os.makedirs(os.path.split(out_file)[0], exist_ok=True)
33+
with tempfile.TemporaryDirectory() as tmp_folder:
34+
run_unet_prediction(
35+
input_file, input_key=None, output_folder=tmp_folder,
36+
model_path=MODEL_PATH, min_size=1000, use_mask=False,
37+
)
38+
seg_path = os.path.join(tmp_folder, "segmentation.zarr")
39+
with open_file(seg_path, mode="r") as f:
40+
seg = f["segmentation"][:]
41+
42+
print("Writing output to", out_file)
43+
tifffile.imwrite(out_file, seg, bigtiff=True)
44+
45+
46+
def segment_all():
47+
files = sorted(glob(os.path.join(ROOT, "**/*.tif"), recursive=True))
48+
for ff in files:
49+
segment_crop(ff)
50+
51+
52+
def main():
53+
# check_data()
54+
segment_all()
55+
56+
57+
if __name__ == "__main__":
58+
main()

0 commit comments

Comments
 (0)