diff --git a/synapse_net/inference/scalable_segmentation.py b/synapse_net/inference/scalable_segmentation.py
new file mode 100644
index 00000000..156fef55
--- /dev/null
+++ b/synapse_net/inference/scalable_segmentation.py
@@ -0,0 +1,137 @@
+import os
+import tempfile
+from typing import Dict, List, Optional
+
+import elf.parallel as parallel
+import numpy as np
+import torch
+
+from elf.io import open_file
+from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
+from elf.wrapper.base import MultiTransformationWrapper
+from elf.wrapper.resized_volume import ResizedVolume
+from numpy.typing import ArrayLike
+from synapse_net.inference.util import get_prediction
+
+
+class SelectChannel(SimpleTransformationWrapper):
+    """Wrapper to select a channel from an array-like dataset object.
+
+    Args:
+        volume: The array-like input dataset.
+        channel: The channel that will be selected.
+    """
+    def __init__(self, volume: ArrayLike, channel: int):
+        self.channel = channel
+        super().__init__(volume, lambda x: x[self.channel], with_channels=True)
+
+    @property
+    def shape(self):
+        return self._volume.shape[1:]
+
+    @property
+    def chunks(self):
+        return self._volume.chunks[1:]
+
+    @property
+    def ndim(self):
+        return self._volume.ndim - 1
+
+
+def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
+    # Create wrappers for selecting the foreground and the boundary channel.
+    foreground = SelectChannel(pred, 0)
+    boundaries = SelectChannel(pred, 1)
+
+    # Create wrappers that subtract the boundaries from the foreground and threshold the result.
+    # Then compute the seeds from this via connected components.
+    seed_input = ThresholdWrapper(
+        MultiTransformationWrapper(np.subtract, foreground, boundaries), seed_threshold
+    )
+    parallel.label(seed_input, seeds, verbose=verbose, block_shape=chunks)
+
+    # Create the mask for the watershed by thresholding the foreground.
+    mask = ThresholdWrapper(foreground, 0.5)
+
+    # Resize if necessary.
+    if original_shape is not None:
+        boundaries = ResizedVolume(boundaries, original_shape, order=1)
+        seeds = ResizedVolume(seeds, original_shape, order=0)
+        mask = ResizedVolume(mask, original_shape, order=0)
+
+    # Run the watershed to extend the seeds back to the boundaries.
+    parallel.seeded_watershed(
+        boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
+    )
+
+    # Run the size filter.
+    if min_size > 0:
+        parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)
+
+
+def scalable_segmentation(
+    input_: ArrayLike,
+    output: ArrayLike,
+    model: torch.nn.Module,
+    tiling: Optional[Dict[str, Dict[str, int]]] = None,
+    scale: Optional[List[float]] = None,
+    seed_threshold: float = 0.5,
+    min_size: int = 500,
+    prediction: Optional[ArrayLike] = None,
+    verbose: bool = True,
+    mask: Optional[ArrayLike] = None,
+) -> None:
+    """Run segmentation based on a prediction with a foreground and a boundary channel.
+
+    This function first subtracts the boundary prediction from the foreground prediction,
+    then applies a threshold, connected components, and a watershed to fit the components
+    back to the foreground. All processing steps are implemented in a scalable fashion,
+    so that the function also works for large input volumes.
+
+    Args:
+        input_: The input data.
+        output: The array for storing the output segmentation.
+            Can be a numpy array, a zarr array, or similar.
+        model: The model for prediction.
+        tiling: The tiling configuration for the prediction.
+        scale: The scale factor to use for rescaling the input volume before prediction.
+        seed_threshold: The threshold applied before computing connected components.
+        min_size: The minimum size of an object in the segmentation.
+        prediction: The array for storing the prediction.
+            If given, this can be a numpy array, a zarr array, or similar.
+            If not given, the prediction will be stored in a temporary n5 array.
+        verbose: Whether to print timing information.
+        mask: An optional mask. Not implemented yet; must be None.
+    """
+    if mask is not None:
+        raise NotImplementedError
+    assert model.out_channels == 2
+
+    chunks = (128,) * 3
+    # Create a temporary directory for storing the prediction and seeds.
+    with tempfile.TemporaryDirectory() as tmp_dir:
+
+        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
+            original_shape = None
+        else:
+            original_shape = input_.shape
+            new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
+            input_ = ResizedVolume(input_, shape=new_shape, order=1)
+
+        if prediction is None:
+            # Create the dataset for storing the prediction.
+            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
+            f = open_file(tmp_pred, mode="a")
+            pred_shape = (2,) + input_.shape
+            pred_chunks = (1,) + chunks
+            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
+        else:
+            assert prediction.shape[0] == 2
+            assert prediction.shape[1:] == input_.shape
+
+        # Create temporary storage for the seeds.
+        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
+        f = open_file(tmp_seeds, mode="a")
+        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)
+
+        # Run prediction and segmentation.
+        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
+        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
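For context, a minimal usage sketch of the new function. It assumes a two-channel (foreground + boundary) model loaded via `get_model` and zarr containers for input and output; the paths, dataset names, and the model name are illustrative and not part of this diff:

```python
import zarr
from synapse_net.inference.inference import get_model
from synapse_net.inference.scalable_segmentation import scalable_segmentation

# Hypothetical zarr container holding the tomogram data.
input_ = zarr.open("tomogram.zarr", mode="r")["raw"]

# Pre-allocate a chunked output dataset, so that the segmentation is written out of core.
f_out = zarr.open("segmentation.zarr", mode="a")
output = f_out.create_dataset("seg", shape=input_.shape, dtype="uint64", chunks=(128, 128, 128))

# Any model that predicts a foreground and a boundary channel works here.
model = get_model("vesicles_3d")  # Example model name.
scalable_segmentation(input_, output, model)
```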
diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py
index 09c771da..6bec9bf4 100644
--- a/synapse_net/inference/util.py
+++ b/synapse_net/inference/util.py
@@ -18,6 +18,7 @@
 # import xarray
 from elf.io import open_file
+from numpy.typing import ArrayLike
 from scipy.ndimage import binary_closing
 from skimage.measure import regionprops
 from skimage.morphology import remove_small_holes
@@ -99,16 +100,32 @@ def rescale_output(self, output, is_segmentation):
         return output
 
 
+def _preprocess(input_volume, with_channels, channels_to_standardize):
+    # We standardize the data for the whole volume beforehand.
+    # If we have channels, the standardization is done independently per channel.
+    if with_channels:
+        input_volume = input_volume.astype(np.float32, copy=False)
+        # TODO Check that this is the correct axis.
+        if channels_to_standardize is None:  # Assume all channels.
+            channels_to_standardize = range(input_volume.shape[0])
+        for ch in channels_to_standardize:
+            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
+    else:
+        input_volume = torch_em.transform.raw.standardize(input_volume)
+    return input_volume
+
+
 def get_prediction(
-    input_volume: np.ndarray,  # [z, y, x]
+    input_volume: ArrayLike,  # [z, y, x]
     tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
     model_path: Optional[str] = None,
     model: Optional[torch.nn.Module] = None,
     verbose: bool = True,
     with_channels: bool = False,
     channels_to_standardize: Optional[List[int]] = None,
-    mask: Optional[np.ndarray] = None,
-) -> np.ndarray:
+    mask: Optional[ArrayLike] = None,
+    prediction: Optional[ArrayLike] = None,
+) -> ArrayLike:
     """Run prediction on a given volume.
 
     This function will automatically choose the correct prediction implementation,
@@ -124,6 +141,8 @@ def get_prediction(
         channels_to_standardize: List of channels to standardize. Defaults to None.
         mask: Optional binary mask. If given, the prediction will only be run in
             the foreground region of the mask.
+        prediction: An array-like object for writing the prediction.
+            If not given, the prediction will be computed in memory.
 
     Returns:
         The predicted volume.
@@ -140,17 +159,11 @@ def get_prediction(
     if tiling is None:
         tiling = get_default_tiling()
 
-    # We standardize the data for the whole volume beforehand.
-    # If we have channels then the standardization is done independently per channel.
-    if with_channels:
-        input_volume = input_volume.astype(np.float32, copy=False)
-        # TODO Check that this is the correct axis.
-        if channels_to_standardize is None:  # assume all channels
-            channels_to_standardize = range(input_volume.shape[0])
-        for ch in channels_to_standardize:
-            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
-    else:
-        input_volume = torch_em.transform.raw.standardize(input_volume)
+    # Normalize the whole input volume if it is a numpy array.
+    # Otherwise we have a zarr array or similar as input, and can't normalize it en bloc.
+    # In this case, normalization will be applied later per block.
+    if isinstance(input_volume, np.ndarray):
+        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)
 
     # Run prediction with the bioimage.io library.
     if is_bioimageio:
@@ -174,21 +187,23 @@ def get_prediction(
         for dim in tiling["tile"]:
             updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
         # print(f"updated_tiling {updated_tiling}")
-        pred = get_prediction_torch_em(
-            input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
+        prediction = get_prediction_torch_em(
+            input_volume, updated_tiling, model_path, model, verbose, with_channels,
+            mask=mask, prediction=prediction,
         )
-    return pred
+    return prediction
 
 
 def get_prediction_torch_em(
-    input_volume: np.ndarray,  # [z, y, x]
+    input_volume: ArrayLike,  # [z, y, x]
     tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
     model_path: Optional[str] = None,
     model: Optional[torch.nn.Module] = None,
     verbose: bool = True,
     with_channels: bool = False,
-    mask: Optional[np.ndarray] = None,
+    mask: Optional[ArrayLike] = None,
+    prediction: Optional[ArrayLike] = None,
 ) -> np.ndarray:
     """Run prediction using torch-em on a given volume.
 
@@ -201,6 +216,8 @@ def get_prediction_torch_em(
         with_channels: Whether to predict with channels.
         mask: Optional binary mask. If given, the prediction will only be run in
             the foreground region of the mask.
+        prediction: An array-like object for writing the prediction.
+            If not given, the prediction will be computed in memory.
 
     Returns:
         The predicted volume.
@@ -234,14 +251,16 @@ def get_prediction_torch_em(
             print("Run prediction with mask.")
         mask = mask.astype("bool")
 
-    pred = predict_with_halo(
+    preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
+    prediction = predict_with_halo(
         input_volume, model, gpu_ids=[device],
         block_shape=block_shape, halo=halo,
-        preprocess=None, with_channels=with_channels, mask=mask,
+        preprocess=preprocess, with_channels=with_channels, mask=mask,
+        output=prediction,
     )
     if verbose:
         print("Prediction time in", time.time() - t0, "s")
-    return pred
+    return prediction
 
 
 def _get_file_paths(input_path, ext=".mrc"):
@@ -325,6 +344,7 @@ def inference_helper(
     output_key: Optional[str] = None,
     model_resolution: Optional[Tuple[float, float, float]] = None,
     scale: Optional[Tuple[float, float, float]] = None,
+    allocate_output: bool = False,
 ) -> None:
     """Helper function to run segmentation for mrc files.
 
@@ -347,6 +367,7 @@ def inference_helper(
         model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
             If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
         scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
+        allocate_output: Whether to pre-allocate the output array and pass it to the segmentation function.
     """
     if (scale is not None) and (model_resolution is not None):
         raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
@@ -412,7 +433,11 @@ def inference_helper(
            this_scale = _derive_scale(img_path, model_resolution)
 
         # Run the segmentation.
-        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
+        if allocate_output:
+            segmentation = np.zeros(input_volume.shape, dtype="uint32")
+            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
+        else:
+            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
 
         # Write the result to tif or h5.
         os.makedirs(os.path.split(output_path)[0], exist_ok=True)
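To make the new out-of-core path through `get_prediction` concrete: when the input is not a numpy array, the en-bloc standardization is skipped and `torch_em.transform.raw.standardize` is instead applied per block inside `predict_with_halo`. A sketch with zarr-backed input and prediction arrays (paths, dataset names, and the model name are hypothetical):

```python
import zarr
from synapse_net.inference.inference import get_model
from synapse_net.inference.util import get_prediction

# Zarr-backed input: standardization now happens per block during prediction.
input_volume = zarr.open("tomogram.zarr", mode="r")["raw"]

# Pre-allocate the prediction with a leading channel axis (foreground + boundaries).
f = zarr.open("prediction.zarr", mode="a")
prediction = f.create_dataset(
    "pred", shape=(2,) + input_volume.shape, dtype="float32", chunks=(1, 128, 128, 128)
)

# With tiling=None the default tiling is used; the result is written to 'prediction'.
model = get_model("vesicles_3d")  # Example model name.
get_prediction(input_volume, tiling=None, model=model, prediction=prediction)
```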
diff --git a/synapse_net/tools/cli.py b/synapse_net/tools/cli.py
index 6b4e44f1..d2565b72 100644
--- a/synapse_net/tools/cli.py
+++ b/synapse_net/tools/cli.py
@@ -6,6 +6,7 @@
 import torch_em
 
 from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
 from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
+from ..inference.scalable_segmentation import scalable_segmentation
 from ..inference.util import inference_helper, parse_tiling
 
@@ -152,6 +153,10 @@ def segmentation_cli():
         "--verbose", "-v", action="store_true",
         help="Whether to print verbose information about the segmentation progress."
     )
+    parser.add_argument(
+        "--scalable", action="store_true", help="Use the scalable segmentation implementation. "
+        "Currently this only works for vesicles, mitochondria, or active zones."
+    )
     args = parser.parse_args()
 
     if args.checkpoint is None:
@@ -181,11 +186,26 @@ def segmentation_cli():
         model_resolution = None
         scale = (2 if is_2d else 3) * (args.scale,)
 
-    segmentation_function = partial(
-        run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
-    )
+    if args.scalable:
+        if not args.model.startswith(("vesicle", "mito", "active")):
+            raise ValueError(
+                "The scalable segmentation implementation is currently only supported for "
+                f"vesicles, mitochondria, or active zones, not for {args.model}."
+            )
+        segmentation_function = partial(
+            scalable_segmentation, model=model, tiling=tiling, verbose=args.verbose
+        )
+        allocate_output = True
+
+    else:
+        segmentation_function = partial(
+            run_segmentation, model=model, model_type=args.model, verbose=args.verbose, tiling=tiling,
+        )
+        allocate_output = False
+
     inference_helper(
         args.input_path, args.output_path, segmentation_function,
         mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
         output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
+        allocate_output=allocate_output
     )
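Running the segmentation CLI with the new `--scalable` flag is then equivalent to the following sketch; the model name and the input/output folders are placeholders, and `tiling=None` stands in for the tiling the CLI parses from its arguments:

```python
from functools import partial

from synapse_net.inference.inference import get_model
from synapse_net.inference.scalable_segmentation import scalable_segmentation
from synapse_net.inference.util import inference_helper

# What the CLI assembles when --scalable is passed: scalable_segmentation is plugged
# into inference_helper, which pre-allocates the per-tomogram output arrays.
model = get_model("vesicles_3d")  # Example model name; must start with vesicle/mito/active.
segmentation_function = partial(scalable_segmentation, model=model, tiling=None, verbose=True)
inference_helper(
    "/path/to/tomograms", "/path/to/segmentations", segmentation_function, allocate_output=True
)
```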