Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion synapse_net/inference/cristae.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def segment_cristae(
return_predictions: bool = False,
scale: Optional[List[float]] = None,
mask: Optional[np.ndarray] = None,
**kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""Segment cristae in an input volume.

Expand All @@ -61,6 +62,8 @@ def segment_cristae(
The segmentation mask as a numpy array, or a tuple containing the segmentation mask
and the predictions if return_predictions is True.
"""
with_channels = kwargs.pop("with_channels", True)
channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
if verbose:
print("Segmenting cristae in volume of shape", input_volume.shape)
# Create the scaler to handle prediction with a different scaling factor.
Expand All @@ -72,7 +75,7 @@ def segment_cristae(
mask = scaler.scale_input(mask, is_segmentation=True)
pred = get_prediction(
input_volume, model_path=model_path, model=model, mask=mask,
tiling=tiling, with_channels=True, verbose=verbose
tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
)
foreground, boundaries = pred[:2]
seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
Expand Down
11 changes: 7 additions & 4 deletions synapse_net/inference/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_prediction(
model: Optional[torch.nn.Module] = None,
verbose: bool = True,
with_channels: bool = False,
channels_to_normalize: Optional[List[int]] = [0],
channels_to_standardize: Optional[List[int]] = None,
mask: Optional[np.ndarray] = None,
) -> np.ndarray:
"""Run prediction on a given volume.
Expand All @@ -100,7 +100,7 @@ def get_prediction(
tiling: The tiling configuration for the prediction.
verbose: Whether to print timing information.
with_channels: Whether to predict with channels.
channels_to_normalize: List of channels to normalize. Defaults to 0.
channels_to_standardize: List of channels to standardize. Defaults to None.
mask: Optional binary mask. If given, the prediction will only be run in
the foreground region of the mask.

Expand All @@ -123,9 +123,12 @@ def get_prediction(
# If we have channels then the standardization is done independently per channel.
if with_channels:
input_volume = input_volume.astype(np.float32, copy=False)
channels_to_standardize = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now you are over-riding your input so that it's always None. If you remove this line then everything should be fine.

# TODO Check that this is the correct axis.
for ch in channels_to_normalize:
input_volume[ch] = torch_em.transform.raw.normalize(input_volume[ch])
if channels_to_standardize is None: # assume all channels
channels_to_standardize = range(input_volume.shape[0])
for ch in channels_to_standardize:
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
else:
input_volume = torch_em.transform.raw.standardize(input_volume)

Expand Down