22import time
33import warnings
44from glob import glob
5- from typing import Dict , Optional , Tuple , Union
5+ from typing import Dict , List , Optional , Tuple , Union
66
77# # Suppress annoying import warnings.
88# with warnings.catch_warnings():
@@ -85,6 +85,7 @@ def get_prediction(
8585 model : Optional [torch .nn .Module ] = None ,
8686 verbose : bool = True ,
8787 with_channels : bool = False ,
88+ channels_to_normalize : Optional [List [int ]] = [0 ],
8889 mask : Optional [np .ndarray ] = None ,
8990) -> np .ndarray :
9091 """Run prediction on a given volume.
@@ -99,6 +100,7 @@ def get_prediction(
99100 tiling: The tiling configuration for the prediction.
100101 verbose: Whether to print timing information.
101102 with_channels: Whether to predict with channels.
103+ channels_to_normalize: List of channels to normalize. Defaults to 0.
102104 mask: Optional binary mask. If given, the prediction will only be run in
103105 the foreground region of the mask.
104106
@@ -120,8 +122,10 @@ def get_prediction(
120122 # We standardize the data for the whole volume beforehand.
121123 # If we have channels then the standardization is done independently per channel.
122124 if with_channels :
125+ input_volume = input_volume .astype (np .float32 , copy = False )
123126 # TODO Check that this is the correct axis.
124- input_volume = np .stack ([torch_em .transform .raw .normalize (input_volume [0 ]), input_volume [1 ]], axis = 0 )
127+ for ch in channels_to_normalize :
128+ input_volume [ch ] = torch_em .transform .raw .normalize (input_volume [ch ])
125129 else :
126130 input_volume = torch_em .transform .raw .standardize (input_volume )
127131
0 commit comments