| 
 | 1 | +import copy  | 
 | 2 | + | 
1 | 3 | import napari  | 
 | 4 | +import numpy as np  | 
 | 5 | + | 
2 | 6 | from napari.utils.notifications import show_info  | 
3 | 7 | from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox  | 
4 | 8 | 
 
  | 
5 | 9 | from .base_widget import BaseWidget  | 
6 | 10 | from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,  | 
7 | 11 |                    get_current_tiling, compute_scale_from_voxel_size, load_custom_model)  | 
8 |  | -from synaptic_reconstruction.inference.util import get_default_tiling  | 
9 |  | -import copy  | 
 | 12 | +from ..inference.util import get_default_tiling  | 
10 | 13 | 
 
  | 
11 | 14 | 
 
  | 
12 | 15 | class SegmentationWidget(BaseWidget):  | 
@@ -79,37 +82,41 @@ def on_predict(self):  | 
79 | 82 |             show_info("INFO: Please choose an image.")  | 
80 | 83 |             return  | 
81 | 84 | 
 
  | 
82 |  | -        # load current tiling  | 
 | 85 | +        # Get the current tiling.  | 
83 | 86 |         self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)  | 
84 | 87 | 
 
  | 
 | 88 | +        # Get the voxel size.  | 
85 | 89 |         metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)  | 
86 |  | -        voxel_size = metadata.get("voxel_size", None)  | 
87 |  | -        scale = None  | 
 | 90 | +        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)  | 
88 | 91 | 
 
  | 
89 |  | -        if self.voxel_size_param.value() != 0.0:  # changed from default  | 
90 |  | -            voxel_size = {}  | 
91 |  | -            # override voxel size with user input  | 
92 |  | -            if len(image.shape) == 3:  | 
93 |  | -                voxel_size["x"] = self.voxel_size_param.value()  | 
94 |  | -                voxel_size["y"] = self.voxel_size_param.value()  | 
95 |  | -                voxel_size["z"] = self.voxel_size_param.value()  | 
96 |  | -            else:  | 
97 |  | -                voxel_size["x"] = self.voxel_size_param.value()  | 
98 |  | -                voxel_size["y"] = self.voxel_size_param.value()  | 
 | 92 | +        # Determine the scaling based on the voxel size.  | 
 | 93 | +        scale = None  | 
99 | 94 |         if voxel_size:  | 
100 | 95 |             if model_type == "custom":  | 
101 | 96 |                 show_info("INFO: The image is not rescaled for a custom model.")  | 
102 | 97 |             else:  | 
103 | 98 |                 # calculate scale so voxel_size is the same as in training  | 
104 | 99 |                 scale = compute_scale_from_voxel_size(voxel_size, model_type)  | 
105 |  | -                show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")  | 
106 |  | - | 
 | 100 | +                scale_info = list(map(lambda x: np.round(x, 2), scale))  | 
 | 101 | +                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")  | 
 | 102 | + | 
 | 103 | +        # Some models require an additional segmentation for inference or postprocessing.  | 
 | 104 | +        # For these models we read out the 'Extra Segmentation' widget.  | 
 | 105 | +        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.  | 
 | 106 | +            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)  | 
 | 107 | +            kwargs = {"extra_segmentation": extra_seg}  | 
 | 108 | +        else:  | 
 | 109 | +            kwargs = {}  | 
107 | 110 |         segmentation = run_segmentation(  | 
108 |  | -            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale  | 
 | 111 | +            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs  | 
109 | 112 |         )  | 
110 | 113 | 
 
  | 
111 |  | -        # Add the segmentation layer  | 
112 |  | -        self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)  | 
 | 114 | +        # Add the segmentation layer(s).  | 
 | 115 | +        if isinstance(segmentation, dict):  | 
 | 116 | +            for name, seg in segmentation.items():  | 
 | 117 | +                self.viewer.add_labels(seg, name=name, metadata=metadata)  | 
 | 118 | +        else:  | 
 | 119 | +            self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)  | 
113 | 120 |         show_info(f"INFO: Segmentation of {model_type} added to layers.")  | 
114 | 121 | 
 
  | 
115 | 122 |     def _create_settings_widget(self):  | 
@@ -156,5 +163,10 @@ def _create_settings_widget(self):  | 
156 | 163 |         )  | 
157 | 164 |         setting_values.layout().addLayout(layout)  | 
158 | 165 | 
 
  | 
 | 166 | +        # Add selection UI for additional segmentation, which some models require for inference or postproc.  | 
 | 167 | +        self.extra_seg_selector_name = "Extra Segmentation"  | 
 | 168 | +        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")  | 
 | 169 | +        setting_values.layout().addWidget(self.extra_selector_widget)  | 
 | 170 | + | 
159 | 171 |         settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")  | 
160 | 172 |         return settings  | 
0 commit comments