1818# import xarray
1919
2020from elf .io import open_file
21+ from numpy .typing import ArrayLike
2122from scipy .ndimage import binary_closing
2223from skimage .measure import regionprops
2324from skimage .morphology import remove_small_holes
@@ -100,15 +101,16 @@ def rescale_output(self, output, is_segmentation):
100101
101102
102103def get_prediction (
103- input_volume : np . ndarray , # [z, y, x]
104+ input_volume : ArrayLike , # [z, y, x]
104105 tiling : Optional [Dict [str , Dict [str , int ]]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
105106 model_path : Optional [str ] = None ,
106107 model : Optional [torch .nn .Module ] = None ,
107108 verbose : bool = True ,
108109 with_channels : bool = False ,
109110 channels_to_standardize : Optional [List [int ]] = None ,
110- mask : Optional [np .ndarray ] = None ,
111- ) -> np .ndarray :
111+ mask : Optional [ArrayLike ] = None ,
112+ prediction : Optional [ArrayLike ] = None ,
113+ ) -> ArrayLike :
112114 """Run prediction on a given volume.
113115
114116 This function will automatically choose the correct prediction implementation,
@@ -124,6 +126,8 @@ def get_prediction(
124126 channels_to_standardize: List of channels to standardize. Defaults to None.
125127 mask: Optional binary mask. If given, the prediction will only be run in
126128 the foreground region of the mask.
129+ prediction: An array like object for writing the prediction.
130+ If not given, the prediction will be computed in moemory.
127131
128132 Returns:
129133 The predicted volume.
@@ -174,21 +178,23 @@ def get_prediction(
174178 for dim in tiling ["tile" ]:
175179 updated_tiling ["tile" ][dim ] = tiling ["tile" ][dim ] - 2 * tiling ["halo" ][dim ]
176180 # print(f"updated_tiling {updated_tiling}")
177- pred = get_prediction_torch_em (
178- input_volume , updated_tiling , model_path , model , verbose , with_channels , mask = mask
181+ prediction = get_prediction_torch_em (
182+ input_volume , updated_tiling , model_path , model , verbose , with_channels ,
183+ mask = mask , prediction = prediction ,
179184 )
180185
181- return pred
186+ return prediction
182187
183188
184189def get_prediction_torch_em (
185- input_volume : np . ndarray , # [z, y, x]
190+ input_volume : ArrayLike , # [z, y, x]
186191 tiling : Dict [str , Dict [str , int ]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
187192 model_path : Optional [str ] = None ,
188193 model : Optional [torch .nn .Module ] = None ,
189194 verbose : bool = True ,
190195 with_channels : bool = False ,
191- mask : Optional [np .ndarray ] = None ,
196+ mask : Optional [ArrayLike ] = None ,
197+ prediction : Optional [ArrayLike ] = None ,
192198) -> np .ndarray :
193199 """Run prediction using torch-em on a given volume.
194200
@@ -201,6 +207,8 @@ def get_prediction_torch_em(
201207 with_channels: Whether to predict with channels.
202208 mask: Optional binary mask. If given, the prediction will only be run in
203209 the foreground region of the mask.
210+ prediction: An array like object for writing the prediction.
211+ If not given, the prediction will be computed in moemory.
204212
205213 Returns:
206214 The predicted volume.
@@ -234,14 +242,15 @@ def get_prediction_torch_em(
234242 print ("Run prediction with mask." )
235243 mask = mask .astype ("bool" )
236244
237- pred = predict_with_halo (
245+ prediction = predict_with_halo (
238246 input_volume , model , gpu_ids = [device ],
239247 block_shape = block_shape , halo = halo ,
240248 preprocess = None , with_channels = with_channels , mask = mask ,
249+ output = prediction ,
241250 )
242251 if verbose :
243252 print ("Prediction time in" , time .time () - t0 , "s" )
244- return pred
253+ return prediction
245254
246255
247256def _get_file_paths (input_path , ext = ".mrc" ):
0 commit comments