88import numpy as np
99import nifty .tools as nt
1010import vigra
11+ import tifffile
1112import torch
1213import z5py
1314
@@ -37,7 +38,10 @@ def ndim(self):
3738 return self ._volume .ndim - 1
3839
3940
40- def prediction_impl (input_path , input_key , output_folder , model_path , scale , block_shape , halo ):
41+ def prediction_impl (
42+ input_path , input_key , output_folder , model_path , scale , block_shape , halo ,
43+ output_channels = 3 , apply_postprocessing = True ,
44+ ):
4145 with warnings .catch_warnings ():
4246 warnings .simplefilter ("ignore" )
4347 if os .path .isdir (model_path ):
@@ -46,10 +50,16 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
4650 model = torch .load (model_path )
4751
4852 mask_path = os .path .join (output_folder , "mask.zarr" )
49- image_mask = z5py .File (mask_path , "r" )["mask" ]
53+ if os .path .exists (mask_path ):
54+ image_mask = z5py .File (mask_path , "r" )["mask" ]
55+ else :
56+ image_mask = None
5057
5158 if input_key is None :
52- input_ = imageio .imread (input_path )
59+ try :
60+ input_ = tifffile .memmap (input_path )
61+ except Exception :
62+ input_ = imageio .imread (input_path )
5363 else :
5464 input_ = open_file (input_path , "r" )[input_key ]
5565
@@ -93,17 +103,27 @@ def preprocess(raw):
93103 raw /= std
94104 return raw
95105
96- # Smooth the distance prediction channel.
97- def postprocess (x ):
98- x [1 ] = vigra .filters .gaussianSmoothing (x [1 ], sigma = 2.0 )
99- return x
106+ if apply_postprocessing :
107+ # Smooth the distance prediction channel.
108+ def postprocess (x ):
109+ x [1 ] = vigra .filters .gaussianSmoothing (x [1 ], sigma = 2.0 )
110+ return x
111+ else :
112+ postprocess = None if output_channels > 1 else lambda x : x .squeeze ()
113+
114+ if output_channels > 1 :
115+ output_shape = (output_channels ,) + input_ .shape
116+ output_chunks = (1 ,) + block_shape
117+ else :
118+ output_shape = input_ .shape
119+ output_chunks = block_shape
100120
101121 output_path = os .path .join (output_folder , "predictions.zarr" )
102122 with open_file (output_path , "a" ) as f :
103123 output = f .require_dataset (
104124 "prediction" ,
105- shape = ( 3 ,) + input_ . shape ,
106- chunks = ( 1 ,) + block_shape ,
125+ shape = output_shape ,
126+ chunks = output_chunks ,
107127 compression = "gzip" ,
108128 dtype = "float32" ,
109129 )
0 commit comments