Skip to content

Commit 4473d9d

Browse files
Merge branch 'intensity-masking' of https://github.com/computational-cell-analytics/lightsheet-moser into intensity-masking
2 parents a6b2de8 + 0c8b97f commit 4473d9d

File tree

3 files changed

+185
-4
lines changed

3 files changed

+185
-4
lines changed
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import multiprocessing
2+
import os
3+
import threading
4+
from concurrent import futures
5+
from threadpoolctl import threadpool_limits
6+
from typing import Optional, Tuple, Union
7+
8+
import numpy as np
9+
from numpy.typing import ArrayLike
10+
import pandas as pd
11+
from scipy.ndimage import distance_transform_edt
12+
from skimage.segmentation import watershed
13+
import zarr
14+
15+
from elf.io import open_file
16+
from elf.parallel.local_maxima import find_local_maxima
17+
from flamingo_tools.segmentation.unet_prediction import prediction_impl
18+
from tqdm import tqdm
19+
20+
from elf.parallel.common import get_blocking
21+
22+
23+
def distance_based_marker_extension(
24+
markers: np.ndarray,
25+
output: ArrayLike,
26+
extension_distance: float,
27+
sampling: Union[float, Tuple[float, ...]],
28+
block_shape: Tuple[int, ...],
29+
n_threads: Optional[int] = None,
30+
verbose: bool = False,
31+
roi: Optional[Tuple[slice, ...]] = None,
32+
):
33+
"""
34+
Extend SGN detection to emulate shape of SGNs for better visualization.
35+
36+
Args:
37+
markers: Array of coordinates for seeding watershed.
38+
output: Output for watershed.
39+
extension_distance: Distance in micrometer for extension.
40+
sampling: Resolution in micrometer.
41+
block_shape:
42+
n_threads:
43+
verbose:
44+
roi:
45+
"""
46+
n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
47+
blocking = get_blocking(output, block_shape, roi, n_threads)
48+
49+
lock = threading.Lock()
50+
51+
# determine the correct halo in pixels based on the sampling and the extension distance.
52+
halo = [round(extension_distance / s) + 2 for s in sampling]
53+
54+
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
55+
def extend_block(block_id):
56+
block = blocking.getBlockWithHalo(block_id, halo)
57+
outer_block = block.outerBlock
58+
inner_block = block.innerBlock
59+
60+
# get the indices and coordinates of the markers in the INNER block
61+
mask = (
62+
(inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] <= inner_block.end[0]) &
63+
(inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] <= inner_block.end[1]) &
64+
(inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] <= inner_block.end[2])
65+
)
66+
markers_in_block_ids = np.where(mask)[0]
67+
markers_in_block_coords = markers[markers_in_block_ids]
68+
69+
# proceed if detections fall within inner block
70+
if len(markers_in_block_coords) > 0:
71+
markers_in_block_coords = [coord - outer_block.begin for coord in markers_in_block_coords]
72+
markers_in_block_coords = [[round(c) for c in coord] for coord in markers_in_block_coords]
73+
74+
markers_in_block_coords = np.array(markers_in_block_coords, dtype=int)
75+
z, y, x = markers_in_block_coords.T
76+
77+
# Shift index by one so that zero is reserved for background id
78+
markers_in_block_ids += 1
79+
80+
# Create the seed volume.
81+
outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))
82+
seeds = np.zeros(outer_block_shape, dtype="uint32")
83+
seeds[z, y, x] = markers_in_block_ids
84+
85+
# Compute the distance map.
86+
distance = distance_transform_edt(seeds == 0, sampling=sampling)
87+
88+
# And extend the seeds
89+
mask = distance < extension_distance
90+
segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)
91+
92+
# Write the segmentation. Note: we need to lock here because we write outside of our inner block
93+
bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
94+
with lock:
95+
this_output = output[bb]
96+
this_output[mask] = segmentation[mask]
97+
output[bb] = this_output
98+
99+
n_blocks = blocking.numberOfBlocks
100+
with futures.ThreadPoolExecutor(n_threads) as tp:
101+
list(tqdm(
102+
tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
103+
))
104+
105+
106+
def sgn_detection(
107+
input_path: str,
108+
input_key: str,
109+
output_folder: str,
110+
model_path: str,
111+
extension_distance: float,
112+
sampling: Union[float, Tuple[float, ...]],
113+
block_shape: Optional[Tuple[int, int, int]] = None,
114+
halo: Optional[Tuple[int, int, int]] = None,
115+
n_threads: Optional[int] = None,
116+
):
117+
"""Run prediction for SGN detection.
118+
119+
Args:
120+
input_path: Input path to image channel for SGN detection.
121+
input_key: Input key for resolution of image channel and mask channel.
122+
output_folder: Output folder for SGN segmentation.
123+
model_path: Path to model for SGN detection.
124+
block_shape: The block-shape for running the prediction.
125+
halo: The halo (= block overlap) to use for prediction.
126+
spot_radius: Radius in pixel to convert spot detection of SGNs into a volume.
127+
"""
128+
if block_shape is None:
129+
block_shape = (12, 128, 128)
130+
if halo is None:
131+
halo = (10, 64, 64)
132+
133+
# Skip existing prediction, which is saved in output_folder/predictions.zarr
134+
skip_prediction = False
135+
output_path = os.path.join(output_folder, "predictions.zarr")
136+
prediction_key = "prediction"
137+
if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
138+
skip_prediction = True
139+
140+
if not skip_prediction:
141+
prediction_impl(
142+
input_path, input_key, output_folder, model_path,
143+
scale=None, block_shape=block_shape, halo=halo,
144+
apply_postprocessing=False, output_channels=1,
145+
)
146+
147+
detection_path = os.path.join(output_folder, "SGN_detection.tsv")
148+
if not os.path.exists(detection_path):
149+
input_ = zarr.open(output_path, "r")[prediction_key]
150+
detections_maxima = find_local_maxima(
151+
input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True, n_threads=16,
152+
)
153+
154+
# Save the result in mobie compatible format.
155+
detections = np.concatenate(
156+
[np.arange(1, len(detections_maxima) + 1)[:, None], detections_maxima[:, ::-1]], axis=1
157+
)
158+
detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
159+
detections.to_csv(detection_path, index=False, sep="\t")
160+
161+
# extend detection
162+
shape = input_.shape
163+
chunks = (128, 128, 128)
164+
segmentation_path = os.path.join(output_folder, "segmentation.zarr")
165+
output = open_file(segmentation_path, mode="a")
166+
segmentation_key = "segmentation"
167+
output_dataset = output.create_dataset(
168+
segmentation_key, shape=shape, dtype=np.uint64,
169+
chunks=chunks, compression="gzip"
170+
)
171+
172+
distance_based_marker_extension(
173+
markers=detections_maxima,
174+
output=output_dataset,
175+
extension_distance=extension_distance,
176+
sampling=sampling,
177+
block_shape=(128, 128, 128),
178+
n_threads=n_threads,
179+
verbose=True,
180+
)

reproducibility/label_components/repro_label_components.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ def repro_label_components(
4545
# table_name = "PV_SGN_V2_DA"
4646
# table_name = "CR_SGN_v2"
4747
# table_name = "Ntng1_SGN_v2"
48+
table_name = "SGN_detect-v6b"
4849

49-
table_name = f"{cell_type.upper()}_{unet_version}"
50+
# table_name = f"{cell_type.upper()}_{unet_version}"
5051

5152
s3_path = os.path.join(f"{cochlea}", "tables", table_name, "default.tsv")
5253
tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name,

scripts/la-vision/train_sgn_detection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from utils.training.training import supervised_training # noqa
1313
from detection_dataset import DetectionDataset, MinPointSampler # noqa
1414

15-
ROOT = "./la-vision-sgn-new/train/sgn-detection" # noqa
16-
# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection"
15+
# ROOT = "./la-vision-sgn-new/train/sgn-detection" # noqa
16+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection"
1717

1818
TRAIN = os.path.join(ROOT, "images")
1919
TRAIN_EMPTY = os.path.join(ROOT, "empty_images")
@@ -52,7 +52,7 @@ def get_paths(split):
5252

5353
def train():
5454

55-
model_name = "sgn-low-res-detection-v6"
55+
model_name = "sgn-low-res-detection-v7"
5656

5757
train_paths, train_label_paths = get_paths("train")
5858
val_paths, val_label_paths = get_paths("val")

0 commit comments

Comments
 (0)