Skip to content

Commit 3cfeda4

Browse files
committed
changed to torch_em standardize
1 parent 4334d7d commit 3cfeda4

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

synapse_net/inference/cristae.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def segment_cristae(
4242
return_predictions: bool = False,
4343
scale: Optional[List[float]] = None,
4444
mask: Optional[np.ndarray] = None,
45+
**kwargs
4546
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
4647
"""Segment cristae in an input volume.
4748
@@ -61,6 +62,8 @@ def segment_cristae(
6162
The segmentation mask as a numpy array, or a tuple containing the segmentation mask
6263
and the predictions if return_predictions is True.
6364
"""
65+
with_channels = kwargs.pop("with_channels", True)
66+
channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
6467
if verbose:
6568
print("Segmenting cristae in volume of shape", input_volume.shape)
6669
# Create the scaler to handle prediction with a different scaling factor.
@@ -72,7 +75,7 @@ def segment_cristae(
7275
mask = scaler.scale_input(mask, is_segmentation=True)
7376
pred = get_prediction(
7477
input_volume, model_path=model_path, model=model, mask=mask,
75-
tiling=tiling, with_channels=True, verbose=verbose
78+
tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
7679
)
7780
foreground, boundaries = pred[:2]
7881
seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)

synapse_net/inference/util.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def get_prediction(
8585
model: Optional[torch.nn.Module] = None,
8686
verbose: bool = True,
8787
with_channels: bool = False,
88-
channels_to_normalize: Optional[List[int]] = [0],
88+
channels_to_standardize: Optional[List[int]] = None,
8989
mask: Optional[np.ndarray] = None,
9090
) -> np.ndarray:
9191
"""Run prediction on a given volume.
@@ -100,7 +100,7 @@ def get_prediction(
100100
tiling: The tiling configuration for the prediction.
101101
verbose: Whether to print timing information.
102102
with_channels: Whether to predict with channels.
103-
channels_to_normalize: List of channels to normalize. Defaults to 0.
103+
channels_to_standardize: List of channels to standardize. Defaults to None.
104104
mask: Optional binary mask. If given, the prediction will only be run in
105105
the foreground region of the mask.
106106
@@ -123,9 +123,12 @@ def get_prediction(
123123
# If we have channels then the standardization is done independently per channel.
124124
if with_channels:
125125
input_volume = input_volume.astype(np.float32, copy=False)
126+
channels_to_standardize = None
126127
# TODO Check that this is the correct axis.
127-
for ch in channels_to_normalize:
128-
input_volume[ch] = torch_em.transform.raw.normalize(input_volume[ch])
128+
if channels_to_standardize is None: # assume all channels
129+
channels_to_standardize = range(input_volume.shape[0])
130+
for ch in channels_to_standardize:
131+
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
129132
else:
130133
input_volume = torch_em.transform.raw.standardize(input_volume)
131134

0 commit comments

Comments
 (0)