Skip to content

Commit caba462

Browse files
committed
Make segmentation thread-safe
1 parent 90087c4 commit caba462

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

flamingo_tools/segmentation/sgn_detection.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import multiprocessing as mp
2+
import threading
23
from concurrent import futures
34
import os
45
from typing import Optional, Tuple
@@ -20,7 +21,7 @@ def sgn_detection(
2021
model_path: str,
2122
block_shape: Optional[Tuple[int, int, int]] = None,
2223
halo: Optional[Tuple[int, int, int]] = None,
23-
spot_radius: int = 4,
24+
spot_radius: int = 2,
2425
):
2526
"""Run prediction for sgn detection.
2627
@@ -52,16 +53,13 @@ def sgn_detection(
5253
apply_postprocessing=False, output_channels=1,
5354
)
5455

55-
detection_path = os.path.join(output_folder, "SGN_detection.tsv")
5656
detection_path = os.path.join(output_folder, "SGN_detection.tsv")
5757
if not os.path.exists(detection_path):
5858
input_ = zarr.open(output_path, "r")[prediction_key]
5959
detections = find_local_maxima(
6060
input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True, n_threads=16,
6161
)
6262

63-
print(detections.shape)
64-
6563
shape = input_.shape
6664
chunks = (128, 128, 128)
6765
segmentation_path = os.path.join(output_folder, "segmentation.zarr")
@@ -72,14 +70,17 @@ def sgn_detection(
7270
chunks=chunks, compression="gzip"
7371
)
7472

73+
lock = threading.Lock()
74+
7575
def add_halo_segm(detection_index):
7676
"""Create a segmentation volume around all detected spots.
7777
"""
78-
coord = detections[detection_index]
78+
coord = list(detections[detection_index])
7979
block_begin = [round(c) - spot_radius for c in coord]
8080
block_end = [round(c) + spot_radius for c in coord]
8181
volume_index = tuple(slice(beg, end) for beg, end in zip(block_begin, block_end))
82-
output_dataset[volume_index] = detection_index + 1
82+
with lock:
83+
output_dataset[volume_index] = detection_index + 1
8384

8485
# Limit the number of cores for parallelization.
8586
n_threads = min(16, mp.cpu_count())

0 commit comments

Comments
 (0)