66
77sys .path .append ("../.." )
88
9+ """
10+ Prediction using distance U-Net.
11+ Parallelization using multiple GPUs is currently only possible by calling functions located in segmentation/unet_prediction.py directly.
12+ Functions for the parallelization end with '_slurm' and divide the process into preprocessing, prediction, and segmentation.
13+ """
914
1015def main ():
11- from flamingo_tools .segmentation import run_unet_prediction , run_unet_prediction_slurm
16+ from flamingo_tools .segmentation import run_unet_prediction
1217
1318 parser = argparse .ArgumentParser ()
1419 parser .add_argument ("-i" , "--input" , required = True )
1520 parser .add_argument ("-o" , "--output_folder" , required = True )
1621 parser .add_argument ("-m" , "--model" , required = True )
1722 parser .add_argument ("-k" , "--input_key" , default = None )
1823 parser .add_argument ("-s" , "--scale" , default = None , type = float , help = "Downscale the image by the given factor." )
19- parser .add_argument ("-n" , "--number_gpu" , default = 1 , type = int , help = "Number of GPUs to use in parallel." )
2024
2125 args = parser .parse_args ()
2226
@@ -37,24 +41,11 @@ def main():
3741 block_shape = tuple ([2 * ch for ch in chunks ]) if have_cuda else tuple (chunks )
3842 halo = (16 , 64 , 64 ) if have_cuda else (8 , 32 , 32 )
3943
40- prediction_instances = args .number_gpu if have_cuda else 1
41-
42- if 1 > prediction_instances :
43- # FIXME: only does prediction part, no segmentation yet
44- # FIXME: implement array job
45- run_unet_prediction_slurm (
46- args .input , args .input_key , args .output_folder , args .model ,
47- scale = scale ,
48- block_shape = block_shape , halo = halo ,
49- prediction_instances = prediction_instances ,
50- )
51- else :
52-
53- run_unet_prediction (
54- args .input , args .input_key , args .output_folder , args .model ,
55- scale = scale , min_size = min_size ,
56- block_shape = block_shape , halo = halo ,
57- )
44+ run_unet_prediction (
45+ args .input , args .input_key , args .output_folder , args .model ,
46+ scale = scale , min_size = min_size ,
47+ block_shape = block_shape , halo = halo ,
48+ )
5849
5950
6051if __name__ == "__main__" :
0 commit comments