import os
import tempfile
-from typing import Dict, Optional
+from typing import Dict, List, Optional

import elf.parallel as parallel
import numpy as np
import torch

from elf.io import open_file
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
-from synapse_net.inference.util import get_prediction
+from elf.wrapper.resized_volume import ResizedVolume
from numpy.typing import ArrayLike
+from synapse_net.inference.util import get_prediction


class SelectChannel(SimpleTransformationWrapper):
@@ -37,7 +38,7 @@ def ndim(self):
        return self._volume.ndim - 1


-def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose):
+def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
    # Create wrappers for selecting the foreground and the boundary channel.
    foreground = SelectChannel(pred, 0)
    boundaries = SelectChannel(pred, 1)
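    # Both wrappers are lazy views that read the selected channel on demand,
    # so the full two-channel prediction is never loaded into memory at once.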
@@ -51,6 +52,13 @@ def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, ver

    # Run watershed to extend back from the seeds to the boundaries.
    mask = ThresholdWrapper(foreground, 0.5)
+
+    # Resize if necessary.
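+    # Use linear interpolation (order=1) for the float boundary map and
+    # nearest-neighbor (order=0) for seeds and mask, so that label values are not mixed.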
+    if original_shape is not None:
+        boundaries = ResizedVolume(boundaries, original_shape, order=1)
+        seeds = ResizedVolume(seeds, original_shape, order=0)
+        mask = ResizedVolume(mask, original_shape, order=0)
+
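    # The watershed runs blockwise; the halo of 16 voxels per axis reduces
    # seam artifacts at block boundaries.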
    parallel.seeded_watershed(
        boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
    )
@@ -60,44 +68,67 @@ def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, ver
    parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)


-# TODO support resizing via the wrapper
def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
+    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
+    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
) -> None:
-    """Lorem ipsum
+    """Run segmentation based on a prediction with a foreground and a boundary channel.
+
+    This function first subtracts the boundary prediction from the foreground prediction,
+    then applies a threshold, connected components, and a watershed to fit the components
+    back to the foreground. All processing steps are implemented in a scalable fashion,
+    so that the function can be run on large input volumes.

    Args:
-        input_:
-        output:
-        model: The model.
+        input_: The input data.
+        output: The array for storing the output segmentation.
+            Can be a numpy array, a zarr array, or similar.
+        model: The model for prediction.
        tiling: The tiling configuration for the prediction.
+        scale: The scale factor to use for rescaling the input volume before prediction.
+        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a vesicle to be considered.
+        prediction: The array for storing the prediction.
+            If given, this can be a numpy array, a zarr array, or similar.
+            If not given, the prediction will be stored in a temporary n5 array.
        verbose: Whether to print timing information.
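+
+    Example:
+        A minimal usage sketch (not from the PR; the file paths and the model file
+        are placeholders). A scale of 0.5 per axis halves the volume before
+        prediction; the segmentation is resized back to the input shape.
+
+        >>> import zarr
+        >>> model = torch.load("unet.pt", weights_only=False)
+        >>> input_ = zarr.open("tomogram.zarr", mode="r")["raw"]
+        >>> output = zarr.open("segmentation.zarr", mode="a").create_dataset(
+        ...     "seg", shape=input_.shape, dtype="uint64", chunks=(128,) * 3
+        ... )
+        >>> scalable_segmentation(input_, output, model, scale=[0.5, 0.5, 0.5])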
    """
    assert model.out_channels == 2

    # Create a temporary directory for storing the predictions.
+    chunks = (128,) * 3
    with tempfile.TemporaryDirectory() as tmp_dir:
-        tmp_pred = os.path.join(tmp_dir, "prediction.n5")
-        f = open_file(tmp_pred, mode="a")

-        # Create the dataset for storing the prediction.
-        chunks = (128,) * 3
-        pred_shape = (2,) + input_.shape
-        pred_chunks = (1,) + chunks
-        pred = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
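+        # Rescale the input before prediction if a scale factor is given; the
+        # segmentation is resized back to the original shape in _run_segmentation.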
+        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
+            original_shape = None
+        else:
+            original_shape = input_.shape
+            new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
+            input_ = ResizedVolume(input_, shape=new_shape, order=1)
+
+        if prediction is None:
+            # Create the dataset for storing the prediction.
+            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
+            f = open_file(tmp_pred, mode="a")
+            pred_shape = (2,) + input_.shape
+            pred_chunks = (1,) + chunks
+            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
+        else:
+            assert prediction.shape[0] == 2
+            assert prediction.shape[1:] == input_.shape
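+            # A user-supplied prediction must provide two channels at the (possibly rescaled) input shape.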

        # Create temporary storage for the seeds.
        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
        f = open_file(tmp_seeds, mode="a")
        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)
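        # The seeds dataset is filled by the connected-components step in _run_segmentation.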

        # Run prediction and segmentation.
-        get_prediction(input_, prediction=pred, tiling=tiling, model=model, verbose=verbose)
-        _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose)
+        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
+        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
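
# Design note (not from the PR): the elf wrappers used above (SelectChannel,
# ThresholdWrapper, ResizedVolume) are lazy views, so apart from the user-provided
# output only the prediction and the seeds are materialized, by default in
# temporary n5 storage, while all other steps stream over the data blockwise.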