diff --git a/synapse_net/inference/cristae.py b/synapse_net/inference/cristae.py
index 1f61cc99..a37f9d9a 100644
--- a/synapse_net/inference/cristae.py
+++ b/synapse_net/inference/cristae.py
@@ -42,6 +42,7 @@ def segment_cristae(
     return_predictions: bool = False,
     scale: Optional[List[float]] = None,
     mask: Optional[np.ndarray] = None,
+    **kwargs
 ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
     """Segment cristae in an input volume.
 
@@ -61,6 +62,8 @@
         The segmentation mask as a numpy array, or a tuple containing the segmentation mask
         and the predictions if return_predictions is True.
     """
+    with_channels = kwargs.pop("with_channels", True)
+    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
     if verbose:
         print("Segmenting cristae in volume of shape", input_volume.shape)
     # Create the scaler to handle prediction with a different scaling factor.
@@ -72,7 +75,7 @@
         mask = scaler.scale_input(mask, is_segmentation=True)
     pred = get_prediction(
         input_volume, model_path=model_path, model=model, mask=mask,
-        tiling=tiling, with_channels=True, verbose=verbose
+        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
     )
     foreground, boundaries = pred[:2]
     seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
diff --git a/synapse_net/inference/inference.py b/synapse_net/inference/inference.py
index 97ed55bb..9236df69 100644
--- a/synapse_net/inference/inference.py
+++ b/synapse_net/inference/inference.py
@@ -10,6 +10,7 @@
 from .mitochondria import segment_mitochondria
 from .ribbon_synapse import segment_ribbon_synapse_structures
 from .vesicles import segment_vesicles
+from .cristae import segment_cristae
 from .util import get_device
 from ..file_utils import get_cache_dir
 
@@ -25,6 +26,7 @@ def _get_model_registry():
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
+       "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
@@ -35,6 +37,7 @@
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
+       "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
@@ -214,7 +217,7 @@
     """
     if model_type.startswith("vesicles"):
         segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
-    elif model_type == "mitochondria":
+    elif model_type == "mitochondria" or model_type == "mitochondria2":
         segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
     elif model_type == "active_zone":
         segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
@@ -222,6 +225,8 @@
         segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
     elif model_type == "ribbon":
         segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
+    elif model_type == "cristae":
+        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
     else:
         raise ValueError(f"Unknown model type: {model_type}")
     return segmentation
diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py
index d6a6abc1..37776dde 100644
--- a/synapse_net/inference/util.py
+++ b/synapse_net/inference/util.py
@@ -2,7 +2,7 @@
 import time
 import warnings
 from glob import glob
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 # # Suppress annoying import warnings.
 # with warnings.catch_warnings():
@@ -85,6 +85,7 @@ def get_prediction(
     model: Optional[torch.nn.Module] = None,
     verbose: bool = True,
     with_channels: bool = False,
+    channels_to_standardize: Optional[List[int]] = None,
     mask: Optional[np.ndarray] = None,
 ) -> np.ndarray:
     """Run prediction on a given volume.
@@ -99,6 +100,7 @@
         tiling: The tiling configuration for the prediction.
         verbose: Whether to print timing information.
         with_channels: Whether to predict with channels.
+        channels_to_standardize: List of channels to standardize. Defaults to None.
         mask: Optional binary mask. If given, the prediction will only be run in
             the foreground region of the mask.
 
@@ -120,8 +122,12 @@
     # We standardize the data for the whole volume beforehand.
     # If we have channels then the standardization is done independently per channel.
     if with_channels:
+        input_volume = input_volume.astype(np.float32, copy=False)
         # TODO Check that this is the correct axis.
-        input_volume = torch_em.transform.raw.standardize(input_volume, axis=(1, 2, 3))
+        if channels_to_standardize is None:  # assume all channels
+            channels_to_standardize = range(input_volume.shape[0])
+        for ch in channels_to_standardize:
+            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
     else:
         input_volume = torch_em.transform.raw.standardize(input_volume)
 
diff --git a/synapse_net/tools/segmentation_widget.py b/synapse_net/tools/segmentation_widget.py
index 0fa5005d..cacab1a0 100644
--- a/synapse_net/tools/segmentation_widget.py
+++ b/synapse_net/tools/segmentation_widget.py
@@ -178,6 +178,9 @@ def on_predict(self):
         if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
             extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
             kwargs = {"extra_segmentation": extra_seg}
+        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
+            image = np.stack([image, self._get_layer_selector_data(self.extra_seg_selector_name)], axis=0)
+            kwargs = {}
         else:
             kwargs = {}
         segmentation = run_segmentation(