@@ -126,6 +126,7 @@ def get_prediction(
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
     devices: Optional[List[str]] = None,
+    preprocess: Optional[callable] = None,
 ) -> ArrayLike:
     """Run prediction on a given volume.

@@ -192,7 +193,7 @@ def get_prediction(
         # print(f"updated_tiling {updated_tiling}")
         prediction = get_prediction_torch_em(
             input_volume, updated_tiling, model_path, model, verbose, with_channels,
-            mask=mask, prediction=prediction, devices=devices,
+            mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
         )

     return prediction
@@ -208,6 +209,7 @@ def get_prediction_torch_em(
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
     devices: Optional[List[str]] = None,
+    preprocess: Optional[callable] = None,
 ) -> np.ndarray:
     """Run prediction using torch-em on a given volume.

@@ -258,7 +260,10 @@ def get_prediction_torch_em(
         print("Run prediction with mask.")
         mask = mask.astype("bool")

-    preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
+    if preprocess is None:
+        preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
+    else:
+        preprocess = preprocess
     prediction = predict_with_halo(
         input_volume, model, gpu_ids=devices,
         block_shape=block_shape, halo=halo,
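
Taken together, these hunks thread a new optional preprocess callable from get_prediction through to get_prediction_torch_em. Leaving it as None keeps the previous default: in-memory numpy volumes receive no preprocessing, while other array-like inputs are standardized with torch_em.transform.raw.standardize; a user-supplied callable replaces that default. Below is a minimal, hypothetical usage sketch; only the devices and preprocess keywords come from this diff, so the other argument names (tiling, model_path) are assumptions for illustration.

import numpy as np

def scale_to_unit_range(volume: np.ndarray) -> np.ndarray:
    # Example preprocess callable: rescale intensities to [0, 1] per volume.
    volume = volume.astype("float32")
    vmin, vmax = float(volume.min()), float(volume.max())
    return (volume - vmin) / (vmax - vmin + 1e-6)

# Hypothetical call: argument names other than devices/preprocess are not shown in the diff.
prediction = get_prediction(
    input_volume,
    tiling=tiling,
    model_path=model_path,
    devices=["cuda:0"],
    preprocess=scale_to_unit_range,  # overrides the default standardization
)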