11import multiprocessing as mp
2+ import threading
23from concurrent import futures
34import os
45from typing import Optional , Tuple
@@ -20,7 +21,7 @@ def sgn_detection(
2021 model_path : str ,
2122 block_shape : Optional [Tuple [int , int , int ]] = None ,
2223 halo : Optional [Tuple [int , int , int ]] = None ,
23- spot_radius : int = 4 ,
24+ spot_radius : int = 2 ,
2425):
2526 """Run prediction for sgn detection.
2627
@@ -52,16 +53,13 @@ def sgn_detection(
5253 apply_postprocessing = False , output_channels = 1 ,
5354 )
5455
55- detection_path = os .path .join (output_folder , "SGN_detection.tsv" )
5656 detection_path = os .path .join (output_folder , "SGN_detection.tsv" )
5757 if not os .path .exists (detection_path ):
5858 input_ = zarr .open (output_path , "r" )[prediction_key ]
5959 detections = find_local_maxima (
6060 input_ , block_shape = block_shape , min_distance = 4 , threshold_abs = 0.5 , verbose = True , n_threads = 16 ,
6161 )
6262
63- print (detections .shape )
64-
6563 shape = input_ .shape
6664 chunks = (128 , 128 , 128 )
6765 segmentation_path = os .path .join (output_folder , "segmentation.zarr" )
@@ -72,14 +70,17 @@ def sgn_detection(
7270 chunks = chunks , compression = "gzip"
7371 )
7472
73+ lock = threading .Lock ()
74+
7575 def add_halo_segm (detection_index ):
7676 """Create a segmentation volume around all detected spots.
7777 """
78- coord = detections [detection_index ]
78+ coord = list ( detections [detection_index ])
7979 block_begin = [round (c ) - spot_radius for c in coord ]
8080 block_end = [round (c ) + spot_radius for c in coord ]
8181 volume_index = tuple (slice (beg , end ) for beg , end in zip (block_begin , block_end ))
82- output_dataset [volume_index ] = detection_index + 1
82+ with lock :
83+ output_dataset [volume_index ] = detection_index + 1
8384
8485 # Limit the number of cores for parallelization.
8586 n_threads = min (16 , mp .cpu_count ())
0 commit comments