diff --git a/synapse_net/inference/cristae.py b/synapse_net/inference/cristae.py index a37f9d9a..f8b28aae 100644 --- a/synapse_net/inference/cristae.py +++ b/synapse_net/inference/cristae.py @@ -62,13 +62,17 @@ def segment_cristae( The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True. """ + mitochondria = kwargs.pop("extra_segmentation") with_channels = kwargs.pop("with_channels", True) channels_to_standardize = kwargs.pop("channels_to_standardize", [0]) if verbose: print("Segmenting cristae in volume of shape", input_volume.shape) # Create the scaler to handle prediction with a different scaling factor. scaler = _Scaler(scale, verbose) - input_volume = scaler.scale_input(input_volume) + # rescale each channel + volume = scaler.scale_input(input_volume) + mito_seg = scaler.scale_input(mitochondria, is_segmentation=True) + input_volume = np.stack([volume, mito_seg], axis=0) # Run prediction and segmentation. if mask is not None: diff --git a/synapse_net/inference/inference.py b/synapse_net/inference/inference.py index 32e3f700..9cbdd143 100644 --- a/synapse_net/inference/inference.py +++ b/synapse_net/inference/inference.py @@ -112,6 +112,7 @@ def get_model_training_resolution(model_type: str) -> Dict[str, float]: "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44}, "compartments": {"x": 3.47, "y": 3.47, "z": 3.47}, "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07}, + "cristae": {"x": 1.44, "y": 1.44, "z": 1.44}, "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188}, "vesicles_2d": {"x": 1.35, "y": 1.35}, "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35}, diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py index 454df943..2afcd6ea 100644 --- a/synapse_net/inference/util.py +++ b/synapse_net/inference/util.py @@ -59,7 +59,7 @@ def scale_input(self, input_volume, is_segmentation=False): if self._original_shape is None: self._original_shape = input_volume.shape - elif self._oringal_shape != input_volume.shape: + elif self._original_shape != input_volume.shape: raise RuntimeError( "Scaler was called with different input shapes. " "This is not supported, please create a new instance of the class for it." diff --git a/synapse_net/tools/segmentation_widget.py b/synapse_net/tools/segmentation_widget.py index 87b47571..50fd16e5 100644 --- a/synapse_net/tools/segmentation_widget.py +++ b/synapse_net/tools/segmentation_widget.py @@ -189,8 +189,11 @@ def on_predict(self): extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name) kwargs = {"extra_segmentation": extra_seg} elif model_type == "cristae": # Cristae model expects 2 3D volumes - image = np.stack([image, self._get_layer_selector_data(self.extra_seg_selector_name)], axis=0) - kwargs = {} + kwargs = { + "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name), + "with_channels": True, + "channels_to_standardize": [0] + } else: kwargs = {} segmentation = run_segmentation(