@@ -100,6 +100,21 @@ def rescale_output(self, output, is_segmentation):
100100 return output
101101
102102
103+ def _preprocess (input_volume , with_channels , channels_to_standardize ):
104+ # We standardize the data for the whole volume beforehand.
105+ # If we have channels then the standardization is done independently per channel.
106+ if with_channels :
107+ input_volume = input_volume .astype (np .float32 , copy = False )
108+ # TODO Check that this is the correct axis.
109+ if channels_to_standardize is None : # assume all channels
110+ channels_to_standardize = range (input_volume .shape [0 ])
111+ for ch in channels_to_standardize :
112+ input_volume [ch ] = torch_em .transform .raw .standardize (input_volume [ch ])
113+ else :
114+ input_volume = torch_em .transform .raw .standardize (input_volume )
115+ return input_volume
116+
117+
103118def get_prediction (
104119 input_volume : ArrayLike , # [z, y, x]
105120 tiling : Optional [Dict [str , Dict [str , int ]]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
@@ -144,17 +159,11 @@ def get_prediction(
144159 if tiling is None :
145160 tiling = get_default_tiling ()
146161
147- # We standardize the data for the whole volume beforehand.
148- # If we have channels then the standardization is done independently per channel.
149- if with_channels :
150- input_volume = input_volume .astype (np .float32 , copy = False )
151- # TODO Check that this is the correct axis.
152- if channels_to_standardize is None : # assume all channels
153- channels_to_standardize = range (input_volume .shape [0 ])
154- for ch in channels_to_standardize :
155- input_volume [ch ] = torch_em .transform .raw .standardize (input_volume [ch ])
156- else :
157- input_volume = torch_em .transform .raw .standardize (input_volume )
162+ # Normalize the whole input volume if it is a numpy array.
163+ # Otherwise we have a zarr array or similar as input, and can't normalize it en-block.
164+ # Normalization will be applied later per block in this case.
165+ if isinstance (input_volume , np .ndarray ):
166+ input_volume = _preprocess (input_volume , with_channels , channels_to_standardize )
158167
159168 # Run prediction with the bioimage.io library.
160169 if is_bioimageio :
@@ -242,10 +251,11 @@ def get_prediction_torch_em(
242251 print ("Run prediction with mask." )
243252 mask = mask .astype ("bool" )
244253
254+ preprocess = None if isinstance (input_volume , np .ndarray ) else torch_em .transform .raw .standardize
245255 prediction = predict_with_halo (
246256 input_volume , model , gpu_ids = [device ],
247257 block_shape = block_shape , halo = halo ,
248- preprocess = None , with_channels = with_channels , mask = mask ,
258+ preprocess = preprocess , with_channels = with_channels , mask = mask ,
249259 output = prediction ,
250260 )
251261 if verbose :
0 commit comments