1818# import xarray 
1919
2020from  elf .io  import  open_file 
21+ from  numpy .typing  import  ArrayLike 
2122from  scipy .ndimage  import  binary_closing 
2223from  skimage .measure  import  regionprops 
2324from  skimage .morphology  import  remove_small_holes 
@@ -99,16 +100,32 @@ def rescale_output(self, output, is_segmentation):
99100        return  output 
100101
101102
103+ def  _preprocess (input_volume , with_channels , channels_to_standardize ):
104+     # We standardize the data for the whole volume beforehand. 
105+     # If we have channels then the standardization is done independently per channel. 
106+     if  with_channels :
107+         input_volume  =  input_volume .astype (np .float32 , copy = False )
108+         # TODO Check that this is the correct axis. 
109+         if  channels_to_standardize  is  None :  # assume all channels 
110+             channels_to_standardize  =  range (input_volume .shape [0 ])
111+         for  ch  in  channels_to_standardize :
112+             input_volume [ch ] =  torch_em .transform .raw .standardize (input_volume [ch ])
113+     else :
114+         input_volume  =  torch_em .transform .raw .standardize (input_volume )
115+     return  input_volume 
116+ 
117+ 
102118def  get_prediction (
103-     input_volume : np . ndarray ,  # [z, y, x] 
119+     input_volume : ArrayLike ,  # [z, y, x] 
104120    tiling : Optional [Dict [str , Dict [str , int ]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}} 
105121    model_path : Optional [str ] =  None ,
106122    model : Optional [torch .nn .Module ] =  None ,
107123    verbose : bool  =  True ,
108124    with_channels : bool  =  False ,
109125    channels_to_standardize : Optional [List [int ]] =  None ,
110-     mask : Optional [np .ndarray ] =  None ,
111- ) ->  np .ndarray :
126+     mask : Optional [ArrayLike ] =  None ,
127+     prediction : Optional [ArrayLike ] =  None ,
128+ ) ->  ArrayLike :
112129    """Run prediction on a given volume. 
113130
114131    This function will automatically choose the correct prediction implementation, 
@@ -124,6 +141,8 @@ def get_prediction(
124141        channels_to_standardize: List of channels to standardize. Defaults to None. 
125142        mask: Optional binary mask. If given, the prediction will only be run in 
126143            the foreground region of the mask. 
144+         prediction: An array like object for writing the prediction. 
145+             If not given, the prediction will be computed in moemory. 
127146
128147    Returns: 
129148        The predicted volume. 
@@ -140,17 +159,11 @@ def get_prediction(
140159    if  tiling  is  None :
141160        tiling  =  get_default_tiling ()
142161
143-     # We standardize the data for the whole volume beforehand. 
144-     # If we have channels then the standardization is done independently per channel. 
145-     if  with_channels :
146-         input_volume  =  input_volume .astype (np .float32 , copy = False )
147-         # TODO Check that this is the correct axis. 
148-         if  channels_to_standardize  is  None :  # assume all channels 
149-             channels_to_standardize  =  range (input_volume .shape [0 ])
150-         for  ch  in  channels_to_standardize :
151-             input_volume [ch ] =  torch_em .transform .raw .standardize (input_volume [ch ])
152-     else :
153-         input_volume  =  torch_em .transform .raw .standardize (input_volume )
162+     # Normalize the whole input volume if it is a numpy array. 
163+     # Otherwise we have a zarr array or similar as input, and can't normalize it en-block. 
164+     # Normalization will be applied later per block in this case. 
165+     if  isinstance (input_volume , np .ndarray ):
166+         input_volume  =  _preprocess (input_volume , with_channels , channels_to_standardize )
154167
155168    # Run prediction with the bioimage.io library. 
156169    if  is_bioimageio :
@@ -174,21 +187,23 @@ def get_prediction(
174187        for  dim  in  tiling ["tile" ]:
175188            updated_tiling ["tile" ][dim ] =  tiling ["tile" ][dim ] -  2  *  tiling ["halo" ][dim ]
176189        # print(f"updated_tiling {updated_tiling}") 
177-         pred  =  get_prediction_torch_em (
178-             input_volume , updated_tiling , model_path , model , verbose , with_channels , mask = mask 
190+         prediction  =  get_prediction_torch_em (
191+             input_volume , updated_tiling , model_path , model , verbose , with_channels ,
192+             mask = mask , prediction = prediction ,
179193        )
180194
181-     return  pred 
195+     return  prediction 
182196
183197
184198def  get_prediction_torch_em (
185-     input_volume : np . ndarray ,  # [z, y, x] 
199+     input_volume : ArrayLike ,  # [z, y, x] 
186200    tiling : Dict [str , Dict [str , int ]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}} 
187201    model_path : Optional [str ] =  None ,
188202    model : Optional [torch .nn .Module ] =  None ,
189203    verbose : bool  =  True ,
190204    with_channels : bool  =  False ,
191-     mask : Optional [np .ndarray ] =  None ,
205+     mask : Optional [ArrayLike ] =  None ,
206+     prediction : Optional [ArrayLike ] =  None ,
192207) ->  np .ndarray :
193208    """Run prediction using torch-em on a given volume. 
194209
@@ -201,6 +216,8 @@ def get_prediction_torch_em(
201216        with_channels: Whether to predict with channels. 
202217        mask: Optional binary mask. If given, the prediction will only be run in 
203218            the foreground region of the mask. 
219+         prediction: An array like object for writing the prediction. 
220+             If not given, the prediction will be computed in moemory. 
204221
205222    Returns: 
206223        The predicted volume. 
@@ -234,14 +251,16 @@ def get_prediction_torch_em(
234251                print ("Run prediction with mask." )
235252            mask  =  mask .astype ("bool" )
236253
237-         pred  =  predict_with_halo (
254+         preprocess  =  None  if  isinstance (input_volume , np .ndarray ) else  torch_em .transform .raw .standardize 
255+         prediction  =  predict_with_halo (
238256            input_volume , model , gpu_ids = [device ],
239257            block_shape = block_shape , halo = halo ,
240-             preprocess = None , with_channels = with_channels , mask = mask ,
258+             preprocess = preprocess , with_channels = with_channels , mask = mask ,
259+             output = prediction ,
241260        )
242261    if  verbose :
243262        print ("Prediction time in" , time .time () -  t0 , "s" )
244-     return  pred 
263+     return  prediction 
245264
246265
247266def  _get_file_paths (input_path , ext = ".mrc" ):
@@ -325,6 +344,7 @@ def inference_helper(
325344    output_key : Optional [str ] =  None ,
326345    model_resolution : Optional [Tuple [float , float , float ]] =  None ,
327346    scale : Optional [Tuple [float , float , float ]] =  None ,
347+     allocate_output : bool  =  False ,
328348) ->  None :
329349    """Helper function to run segmentation for mrc files. 
330350
@@ -347,6 +367,7 @@ def inference_helper(
347367        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. 
348368            If given, the scaling factor will automatically be determined based on the voxel_size of the input data. 
349369        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'. 
370+         allocate_output: Whether to allocate the output for the segmentation function. 
350371    """ 
351372    if  (scale  is  not   None ) and  (model_resolution  is  not   None ):
352373        raise  ValueError ("You must not provide both 'scale' and 'model_resolution' arguments." )
@@ -412,7 +433,11 @@ def inference_helper(
412433            this_scale  =  _derive_scale (img_path , model_resolution )
413434
414435        # Run the segmentation. 
415-         segmentation  =  segmentation_function (input_volume , mask = mask , scale = this_scale )
436+         if  allocate_output :
437+             segmentation  =  np .zeros (input_volume .shape , dtype = "uint32" )
438+             segmentation_function (input_volume , output = segmentation , mask = mask , scale = this_scale )
439+         else :
440+             segmentation  =  segmentation_function (input_volume , mask = mask , scale = this_scale )
416441
417442        # Write the result to tif or h5. 
418443        os .makedirs (os .path .split (output_path )[0 ], exist_ok = True )
0 commit comments