@@ -85,7 +85,7 @@ def get_prediction(
8585 model : Optional [torch .nn .Module ] = None ,
8686 verbose : bool = True ,
8787 with_channels : bool = False ,
88- channels_to_normalize : Optional [List [int ]] = [ 0 ] ,
88+ channels_to_standardize : Optional [List [int ]] = None ,
8989 mask : Optional [np .ndarray ] = None ,
9090) -> np .ndarray :
9191 """Run prediction on a given volume.
@@ -100,7 +100,7 @@ def get_prediction(
100100 tiling: The tiling configuration for the prediction.
101101 verbose: Whether to print timing information.
102102 with_channels: Whether to predict with channels.
103- channels_to_normalize : List of channels to normalize . Defaults to 0 .
103+ channels_to_standardize : List of channels to standardize . Defaults to None .
104104 mask: Optional binary mask. If given, the prediction will only be run in
105105 the foreground region of the mask.
106106
@@ -123,9 +123,12 @@ def get_prediction(
123123 # If we have channels then the standardization is done independently per channel.
124124 if with_channels :
125125 input_volume = input_volume .astype (np .float32 , copy = False )
126+ channels_to_standardize = None
126127 # TODO Check that this is the correct axis.
127- for ch in channels_to_normalize :
128- input_volume [ch ] = torch_em .transform .raw .normalize (input_volume [ch ])
128+ if channels_to_standardize is None : # assume all channels
129+ channels_to_standardize = range (input_volume .shape [0 ])
130+ for ch in channels_to_standardize :
131+ input_volume [ch ] = torch_em .transform .raw .standardize (input_volume [ch ])
129132 else :
130133 input_volume = torch_em .transform .raw .standardize (input_volume )
131134
0 commit comments