Skip to content

Commit 09c4d40

Browse files
Update normalization to support non numpy inputs
1 parent 48a889c commit 09c4d40

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

synapse_net/inference/util.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ def rescale_output(self, output, is_segmentation):
100100
return output
101101

102102

103+
def _preprocess(input_volume, with_channels, channels_to_standardize):
104+
# We standardize the data for the whole volume beforehand.
105+
# If we have channels then the standardization is done independently per channel.
106+
if with_channels:
107+
input_volume = input_volume.astype(np.float32, copy=False)
108+
# TODO Check that this is the correct axis.
109+
if channels_to_standardize is None: # assume all channels
110+
channels_to_standardize = range(input_volume.shape[0])
111+
for ch in channels_to_standardize:
112+
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
113+
else:
114+
input_volume = torch_em.transform.raw.standardize(input_volume)
115+
return input_volume
116+
117+
103118
def get_prediction(
104119
input_volume: ArrayLike, # [z, y, x]
105120
tiling: Optional[Dict[str, Dict[str, int]]], # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
@@ -144,17 +159,11 @@ def get_prediction(
144159
if tiling is None:
145160
tiling = get_default_tiling()
146161

147-
# We standardize the data for the whole volume beforehand.
148-
# If we have channels then the standardization is done independently per channel.
149-
if with_channels:
150-
input_volume = input_volume.astype(np.float32, copy=False)
151-
# TODO Check that this is the correct axis.
152-
if channels_to_standardize is None: # assume all channels
153-
channels_to_standardize = range(input_volume.shape[0])
154-
for ch in channels_to_standardize:
155-
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
156-
else:
157-
input_volume = torch_em.transform.raw.standardize(input_volume)
162+
# Normalize the whole input volume if it is a numpy array.
163+
# Otherwise we have a zarr array or similar as input, and can't normalize it en-block.
164+
# Normalization will be applied later per block in this case.
165+
if isinstance(input_volume, np.ndarray):
166+
input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)
158167

159168
# Run prediction with the bioimage.io library.
160169
if is_bioimageio:
@@ -242,10 +251,11 @@ def get_prediction_torch_em(
242251
print("Run prediction with mask.")
243252
mask = mask.astype("bool")
244253

254+
preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
245255
prediction = predict_with_halo(
246256
input_volume, model, gpu_ids=[device],
247257
block_shape=block_shape, halo=halo,
248-
preprocess=None, with_channels=with_channels, mask=mask,
258+
preprocess=preprocess, with_channels=with_channels, mask=mask,
249259
output=prediction,
250260
)
251261
if verbose:

0 commit comments

Comments
 (0)