Skip to content

Commit 4334d7d

Browse files
committed
modified normalization function so it handles different channel sizes
1 parent 466c733 commit 4334d7d

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

synapse_net/inference/util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import warnings
44
from glob import glob
5-
from typing import Dict, Optional, Tuple, Union
5+
from typing import Dict, List, Optional, Tuple, Union
66

77
# # Suppress annoying import warnings.
88
# with warnings.catch_warnings():
@@ -85,6 +85,7 @@ def get_prediction(
8585
model: Optional[torch.nn.Module] = None,
8686
verbose: bool = True,
8787
with_channels: bool = False,
88+
channels_to_normalize: Optional[List[int]] = [0],
8889
mask: Optional[np.ndarray] = None,
8990
) -> np.ndarray:
9091
"""Run prediction on a given volume.
@@ -99,6 +100,7 @@ def get_prediction(
99100
tiling: The tiling configuration for the prediction.
100101
verbose: Whether to print timing information.
101102
with_channels: Whether to predict with channels.
103+
channels_to_normalize: List of channels to normalize. Defaults to 0.
102104
mask: Optional binary mask. If given, the prediction will only be run in
103105
the foreground region of the mask.
104106
@@ -120,8 +122,10 @@ def get_prediction(
120122
# We standardize the data for the whole volume beforehand.
121123
# If we have channels then the standardization is done independently per channel.
122124
if with_channels:
125+
input_volume = input_volume.astype(np.float32, copy=False)
123126
# TODO Check that this is the correct axis.
124-
input_volume = np.stack([torch_em.transform.raw.normalize(input_volume[0]), input_volume[1]], axis=0)
127+
for ch in channels_to_normalize:
128+
input_volume[ch] = torch_em.transform.raw.normalize(input_volume[ch])
125129
else:
126130
input_volume = torch_em.transform.raw.standardize(input_volume)
127131

0 commit comments

Comments
 (0)