Skip to content

Commit 453a805

Browse files
authored
Update model combination tooltips and adapt other GUIs (#964)
Update model combination tooltips and adapt image series annotator model and training GUI
1 parent de4d894 commit 453a805

File tree

6 files changed

+177
-154
lines changed

6 files changed

+177
-154
lines changed

micro_sam/napari.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name: micro-sam
2-
display_name: SegmentAnything for Microscopy
2+
display_name: Segment Anything for Microscopy
33
# see https://napari.org/stable/plugins/manifest.html for valid categories
44
categories: ["Segmentation", "Annotation"]
55
contributions:

micro_sam/sam_annotator/_tooltips.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"embeddings_save_path": "Select path to save or load the computed image embeddings.",
99
"halo": "Enter overlap values for computing tiled embeddings. Enter only x-value for quadratic size.\n Only active when tiling is used.", # noqa
1010
"image": "Select the napari image layer.",
11-
"model": "Select the segment anything model.",
11+
"model_family": "Select the segment anything model family.",
12+
"model_size": "Select the image encoder size of the segment anything model.",
1213
"automatic_segmentation_mode": "Select the automatic segmentation mode.",
1314
"run_button": "Compute embeddings or load embeddings if embedding_save_path is specified.",
1415
"tiling": "Enter tile size for computing tiled embeddings. Enter only x-value for quadratic size or both for non-quadratic.", # noqa

micro_sam/sam_annotator/_widgets.py

Lines changed: 126 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,127 @@ def _get_file_path(self, name, textbox, tooltip=None):
232232
# Handle the case where the selected path is not a file
233233
print("Invalid file selected. Please try again.")
234234

235+
def _get_model_size_options(self):
236+
# We store the actual model names mapped to UI labels.
237+
self.model_size_mapping = {}
238+
if self.model_family == "Natural Images (SAM)":
239+
self.model_size_options = list(self._model_size_map .values())
240+
self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
241+
else:
242+
model_suffix = self.supported_dropdown_maps[self.model_family]
243+
self.model_size_options = []
244+
245+
for option in self.model_options:
246+
if option.endswith(model_suffix):
247+
# Extract model size character on-the-fly.
248+
key = next((k for k in self._model_size_map .keys() if f"vit_{k}" in option), None)
249+
if key:
250+
size_label = self._model_size_map[key]
251+
self.model_size_options.append(size_label)
252+
self.model_size_mapping[size_label] = option # Store the actual model name.
253+
254+
# We ensure an assorted order of model sizes ('tiny' to 'huge')
255+
self.model_size_options.sort(key=lambda x: ["tiny", "base", "large", "huge"].index(x))
256+
257+
def _update_model_type(self):
258+
# Get currently selected model size (before clearing dropdown)
259+
current_selection = self.model_size_dropdown.currentText()
260+
self._get_model_size_options() # Update model size options dynamically
261+
262+
# NOTE: We need to prevent recursive updates for this step temporarily.
263+
self.model_size_dropdown.blockSignals(True)
264+
265+
# Let's clear and recreate the dropdown.
266+
self.model_size_dropdown.clear()
267+
self.model_size_dropdown.addItems(self.model_size_options)
268+
269+
# We restore the previous selection, if still valid.
270+
if current_selection in self.model_size_options:
271+
self.model_size = current_selection
272+
else:
273+
if self.model_size_options: # Default to the first available model size
274+
self.model_size = self.model_size_options[0]
275+
276+
# Let's map the selection to the correct model type (eg. "tiny" -> "vit_t")
277+
size_key = next(
278+
(k for k, v in self._model_size_map.items() if v == self.model_size), "b"
279+
)
280+
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
281+
282+
self.model_size_dropdown.setCurrentText(self.model_size) # Apply the selected text to the dropdown
283+
284+
# We force a refresh for UI here.
285+
self.model_size_dropdown.update()
286+
287+
# NOTE: And finally, we should re-enable signals again.
288+
self.model_size_dropdown.blockSignals(False)
289+
290+
def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create_layout: bool = True):
291+
292+
# Create a list of support dropdown values and correspond them to suffixes.
293+
self.supported_dropdown_maps = {
294+
"Natural Images (SAM)": "",
295+
"Light Microscopy": "_lm",
296+
"Electron Microscopy": "_em_organelles",
297+
"Medical Imaging": "_medical_imaging",
298+
"Histopathology": "_histopathology",
299+
}
300+
301+
# NOTE: The available options for all are either 'tiny', 'base', 'large' or 'huge'.
302+
self._model_size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
303+
304+
self._default_model_choice = default_model
305+
# Let's set the literally default model choice depending on 'micro-sam'.
306+
self.model_family = {v: k for k, v in self.supported_dropdown_maps.items()}[self._default_model_choice[5:]]
307+
308+
kwargs = {}
309+
if create_layout:
310+
layout = QtWidgets.QVBoxLayout()
311+
kwargs["layout"] = layout
312+
313+
# NOTE: We stick to the base variant for each model family.
314+
# i.e. 'Natural Images (SAM)', 'Light Microscopy', 'Electron Microscopy', 'Medical_Imaging', 'Histopathology'.
315+
self.model_family_dropdown, layout = self._add_choice_param(
316+
"model_family", self.model_family, list(self.supported_dropdown_maps.keys()),
317+
title="Model:", tooltip=get_tooltip("embedding", "model_family"), **kwargs,
318+
)
319+
self.model_family_dropdown.currentTextChanged.connect(self._update_model_type)
320+
return layout
321+
322+
def _create_model_size_section(self):
323+
324+
# Create UI for the model size.
325+
# This would combine with the chosen 'self.model_family' and depend on 'self._default_model_choice'.
326+
self.model_size = self._model_size_map[self._default_model_choice[4]]
327+
328+
# Get all model options.
329+
self.model_options = list(util.models().urls.keys())
330+
# Filter out the decoders from the model list.
331+
self.model_options = [model for model in self.model_options if not model.endswith("decoder")]
332+
333+
# Now, we get the available sizes per model family.
334+
self._get_model_size_options()
335+
336+
self.model_size_dropdown, layout = self._add_choice_param(
337+
"model_size", self.model_size, self.model_size_options,
338+
title="model size:", tooltip=get_tooltip("embedding", "model_size"),
339+
)
340+
self.model_size_dropdown.currentTextChanged.connect(self._update_model_type)
341+
return layout
342+
343+
def _validate_model_type_and_custom_weights(self):
344+
# Let's get all model combination stuff into the desired `model_type` structure.
345+
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
346+
347+
# For 'custom_weights', we remove the displayed text on top of the drop-down menu.
348+
if self.custom_weights:
349+
# NOTE: We prevent recursive updates for this step temporarily.
350+
self.model_family_dropdown.blockSignals(True)
351+
self.model_family_dropdown.setCurrentIndex(-1) # This removes the displayed text.
352+
self.model_family_dropdown.update()
353+
# NOTE: And re-enable signals again.
354+
self.model_family_dropdown.blockSignals(False)
355+
235356

236357
# Custom signals for managing progress updates.
237358
class PBarSignals(QObject):
@@ -1016,7 +1137,7 @@ def __init__(self, parent=None):
10161137
# Section 1: Image and Model.
10171138
section1_layout = QtWidgets.QHBoxLayout()
10181139
section1_layout.addLayout(self._create_image_section())
1019-
section1_layout.addLayout(self._create_model_section())
1140+
section1_layout.addLayout(self._create_model_section()) # Creates the model family widget section.
10201141
self.layout().addLayout(section1_layout)
10211142

10221143
# Section 2: Settings (collapsible).
@@ -1103,116 +1224,15 @@ def _update_model(self, state):
11031224
if "segment_nd" in state.widgets:
11041225
vutil._sync_ndsegment_widget(state.widgets["segment_nd"], _model_type, self.custom_weights)
11051226

1106-
def _update_model_type(self):
1107-
# Get currently selected model size (before clearing dropdown)
1108-
current_selection = self.model_size_dropdown.currentText()
1109-
self._get_model_size_options() # Update model size options dynamically
1110-
1111-
# NOTE: We need to prevent recursive updates for this step temporarily.
1112-
self.model_size_dropdown.blockSignals(True)
1113-
1114-
# Let's clear and recreate the dropdown.
1115-
self.model_size_dropdown.clear()
1116-
self.model_size_dropdown.addItems(self.model_size_options)
1117-
1118-
# We restore the previous selection, if still valid.
1119-
if current_selection in self.model_size_options:
1120-
self.model_size = current_selection
1121-
else:
1122-
if self.model_size_options: # Default to the first available model size
1123-
self.model_size = self.model_size_options[0]
1124-
1125-
# Let's map the selection to the correct model type (eg. "tiny" -> "vit_t")
1126-
size_key = next(
1127-
(k for k, v in self._model_size_map.items() if v == self.model_size), "b"
1128-
)
1129-
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
1130-
1131-
self.model_size_dropdown.setCurrentText(self.model_size) # Apply the selected text to the dropdown
1132-
1133-
# We force a refresh for UI here.
1134-
self.model_size_dropdown.update()
1135-
1136-
# NOTE: And finally, we should re-enable signals again.
1137-
self.model_size_dropdown.blockSignals(False)
1138-
1139-
def _get_model_size_options(self):
1140-
# We store the actual model names mapped to UI labels.
1141-
self.model_size_mapping = {}
1142-
if self.model_family == "Natural Images (SAM)":
1143-
self.model_size_options = list(self._model_size_map .values())
1144-
self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
1145-
else:
1146-
model_suffix = self.supported_dropdown_maps[self.model_family]
1147-
self.model_size_options = []
1148-
1149-
for option in self.model_options:
1150-
if option.endswith(model_suffix):
1151-
# Extract model size character on-the-fly.
1152-
key = next((k for k in self._model_size_map .keys() if f"vit_{k}" in option), None)
1153-
if key:
1154-
size_label = self._model_size_map[key]
1155-
self.model_size_options.append(size_label)
1156-
self.model_size_mapping[size_label] = option # Store the actual model name.
1157-
1158-
# We ensure an assorted order of model sizes ('tiny' to 'huge')
1159-
self.model_size_options.sort(key=lambda x: ["tiny", "base", "large", "huge"].index(x))
1160-
1161-
def _create_model_section(self):
1162-
# Create a list of support dropdown values and correspond them to suffixes.
1163-
self.supported_dropdown_maps = {
1164-
"Natural Images (SAM)": "",
1165-
"Light Microscopy": "_lm",
1166-
"Electron Microscopy": "_em_organelles",
1167-
"Medical Imaging": "_medical_imaging",
1168-
"Histopathology": "_histopathology",
1169-
}
1170-
1171-
# NOTE: The available options for all are either 'tiny', 'base', 'large' or 'huge'.
1172-
self._model_size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
1173-
1174-
self._default_model_choice = util._DEFAULT_MODEL
1175-
# Let's set the literally default model choice depending on 'micro-sam'.
1176-
self.model_family = {v: k for k, v in self.supported_dropdown_maps.items()}[self._default_model_choice[5:]]
1177-
1178-
layout = QtWidgets.QVBoxLayout()
1179-
1180-
# NOTE: We stick to the base variant for each model family.
1181-
# i.e. 'Natural Images (SAM)', 'Light Microscopy', 'Electron Microscopy', 'Medical_Imaging', 'Histopathology'.
1182-
self.model_family_dropdown, layout = self._add_choice_param(
1183-
"model_family", self.model_family, list(self.supported_dropdown_maps.keys()),
1184-
title="Model:", layout=layout, tooltip=get_tooltip("embedding", "model")
1185-
)
1186-
self.model_family_dropdown.currentTextChanged.connect(self._update_model_type)
1187-
return layout
1188-
11891227
def _create_settings_widget(self):
11901228
setting_values = QtWidgets.QWidget()
11911229
setting_values.setToolTip(get_tooltip("embedding", "settings"))
11921230
setting_values.setLayout(QtWidgets.QVBoxLayout())
11931231

1194-
# Create UI for the model size.
1195-
# This would combine with the chosen 'self.model_family' and depend on 'self._default_model_choice'.
1196-
self.model_size = self._model_size_map[self._default_model_choice[4]]
1197-
1198-
# Get all model options.
1199-
self.model_options = list(util.models().urls.keys())
1200-
# Filter out the decoders from the model list.
1201-
self.model_options = [model for model in self.model_options if not model.endswith("decoder")]
1202-
1203-
# Now, we get the available sizes per model family.
1204-
self._get_model_size_options()
1205-
1206-
self.model_size_dropdown, layout = self._add_choice_param(
1207-
"model_size", self.model_size, self.model_size_options,
1208-
title="model size:", tooltip=get_tooltip("embedding", "model"),
1209-
)
1210-
self.model_size_dropdown.currentTextChanged.connect(self._update_model_type)
1232+
# Add the model size widget section.
1233+
layout = self._create_model_size_section()
12111234
setting_values.layout().addLayout(layout)
12121235

1213-
# Now that all parameters in place, let's get them all into one `model_type`.
1214-
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
1215-
12161236
# Create UI for the device.
12171237
self.device = "auto"
12181238
device_options = ["auto"] + util._available_devices()
@@ -1356,19 +1376,12 @@ def _validate_existing_embeddings(self, state):
13561376
return _generate_message(val_results["message_type"], val_results["message"])
13571377

13581378
def __call__(self, skip_validate=False):
1379+
self._validate_model_type_and_custom_weights()
1380+
13591381
# Validate user inputs.
13601382
if not skip_validate and self._validate_inputs():
13611383
return
13621384

1363-
# For 'custom_weights', we remove the displayed text on top of the drop-down menu.
1364-
if self.custom_weights:
1365-
# NOTE: We prevent recursive updates for this step temporarily.
1366-
self.model_family_dropdown.blockSignals(True)
1367-
self.model_family_dropdown.setCurrentIndex(-1) # This removes the displayed text.
1368-
self.model_family_dropdown.update()
1369-
# NOTE: And re-enable signals again.
1370-
self.model_family_dropdown.blockSignals(False)
1371-
13721385
# Get the image.
13731386
image = self.image_selection.get_value()
13741387

micro_sam/sam_annotator/image_series_annotator.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,6 @@ def __init__(self, viewer: napari.Viewer, parent=None):
370370
self.run_button.clicked.connect(self.__call__)
371371
self.layout().addWidget(self.run_button)
372372

373-
# model_type: str = util._DEFAULT_MODEL,
374373
def _create_options(self):
375374
self.folder = None
376375
_, layout = self._add_path_param(
@@ -388,19 +387,18 @@ def _create_options(self):
388387
)
389388
self.layout().addLayout(layout)
390389

391-
self.model_type = util._DEFAULT_MODEL
392-
model_options = list(util.models().urls.keys())
393-
model_options = [model for model in model_options if not model.endswith("decoder")]
394-
_, layout = self._add_choice_param(
395-
"model_type", self.model_type, model_options, title="Model:",
396-
tooltip=get_tooltip("embedding", "model")
397-
)
390+
# Add the model family widget section.
391+
layout = self._create_model_section(create_layout=False)
398392
self.layout().addLayout(layout)
399393

400394
def _create_settings(self):
401395
setting_values = QtWidgets.QWidget()
402396
setting_values.setLayout(QtWidgets.QVBoxLayout())
403397

398+
# Add the model size widget section.
399+
layout = self._create_model_size_section()
400+
setting_values.layout().addLayout(layout)
401+
404402
self.pattern = "*"
405403
_, layout = self._add_string_param(
406404
"pattern", self.pattern, tooltip=get_tooltip("image_series_annotator", "pattern")
@@ -463,12 +461,16 @@ def _validate_inputs(self):
463461
return False
464462

465463
def __call__(self, skip_validate=False):
464+
self._validate_model_type_and_custom_weights()
465+
466466
if not skip_validate and self._validate_inputs():
467467
return
468468
tile_shape, halo = widgets._process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
469469

470470
image_folder_annotator(
471-
self.folder, self.output_folder, self.pattern,
471+
input_folder=self.folder,
472+
output_folder=self.output_folder,
473+
pattern=self.pattern,
472474
model_type=self.model_type,
473475
embedding_path=self.embeddings_save_path,
474476
tile_shape=tile_shape, halo=halo, checkpoint_path=self.custom_weights,

0 commit comments

Comments
 (0)