Skip to content

Commit cfe0be7

Browse files
committed
Rewrite logic of keypoint menus
1 parent 5a5709d commit cfe0be7

File tree

2 files changed

+76
-37
lines changed

2 files changed

+76
-37
lines changed

src/napari_deeplabcut/_widgets.py

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ def __init__(self, napari_viewer):
203203

204204
self._radio_group = self._form_mode_radio_buttons()
205205

206+
self._display = ColorSchemeDisplay(parent=self)
206207
self._color_scheme_display = self._form_color_scheme_display(self.viewer)
207-
208208
self._view_scheme_cb.toggled.connect(self._show_color_scheme)
209209
self._view_scheme_cb.toggle()
210210

@@ -349,29 +349,23 @@ def _func():
349349
return group
350350

351351
def _form_color_scheme_display(self, viewer):
352-
display = ColorSchemeDisplay(parent=self)
353-
self._update_color_scheme(display)
354-
355-
self.viewer.layers.events.inserted.connect(
356-
partial(self._update_color_scheme, display)
357-
)
358-
352+
self.viewer.layers.events.inserted.connect(self._update_color_scheme)
359353
return viewer.window.add_dock_widget(
360-
display, name="Color scheme reference", area="left"
354+
self._display, name="Color scheme reference", area="left"
361355
)
362356

363-
def _update_color_scheme(self, display):
357+
def _update_color_scheme(self):
358+
def to_hex(nparray):
359+
a = np.array(nparray * 255, dtype=int)
360+
rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}"
361+
res = rgb2hex(*a)
362+
return res
363+
364+
self._display.reset()
364365
for layer in self.viewer.layers:
365366
if isinstance(layer, Points) and layer.metadata:
366-
367-
def to_hex(nparray):
368-
a = np.array(nparray * 255, dtype=int)
369-
rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}"
370-
res = rgb2hex(*a)
371-
return res
372-
373367
[
374-
display.add_entry(name, to_hex(color))
368+
self._display.add_entry(name, to_hex(color))
375369
for name, color in layer.metadata["face_color_cycles"][
376370
"label"
377371
].items()
@@ -434,6 +428,22 @@ def on_insert(self, event):
434428
10, partial(self._move_image_layer_to_bottom, event.index)
435429
)
436430
elif isinstance(layer, Points):
431+
# If the Points layer comes from a config file and some Points layers
432+
# were already added, then we only update existing store's metadata.
433+
if layer.metadata.get("project", "") and self._stores:
434+
for _layer, store in self._stores.items():
435+
_layer.metadata["header"] = layer.metadata["header"]
436+
_layer.metadata["face_color_cycles"] = layer.metadata["face_color_cycles"]
437+
_layer.face_color_cycle = layer.face_color_cycle
438+
store.layer = _layer
439+
440+
for menu in self._menus:
441+
menu._map_individuals_to_bodyparts()
442+
menu._update_items()
443+
444+
self._update_color_scheme()
445+
return
446+
437447
store = keypoints.KeypointStore(self.viewer, layer)
438448
self._stores[layer] = store
439449
# TODO Set default dir of the save file dialog
@@ -469,7 +479,8 @@ def on_insert(self, event):
469479

470480
def on_remove(self, event):
471481
layer = event.value
472-
if isinstance(layer, Points):
482+
n_points_layer = sum(isinstance(l, Points) for l in self.viewer.layers)
483+
if isinstance(layer, Points) and n_points_layer == 0:
473484
if self._color_scheme_display is not None:
474485
self.viewer.window.remove_dock_widget(self._color_scheme_display)
475486
self._stores.pop(layer, None)
@@ -528,7 +539,7 @@ def toggle_edge_color(layer):
528539
class DropdownMenu(QComboBox):
529540
def __init__(self, labels: Sequence[str], parent: Optional[QWidget] = None):
530541
super().__init__(parent)
531-
self.addItems(labels)
542+
self.update_items(labels)
532543

533544
def update_to(self, text: str):
534545
index = self.findText(text)
@@ -538,6 +549,10 @@ def update_to(self, text: str):
538549
def reset(self):
539550
self.setCurrentIndex(0)
540551

552+
def update_items(self, items):
553+
self.clear()
554+
self.addItems(items)
555+
541556

542557
class KeypointsDropdownMenu(QWidget):
543558
def __init__(
@@ -549,22 +564,11 @@ def __init__(
549564
self.store = store
550565
self.store.layer.events.current_properties.connect(self.update_menus)
551566

552-
# Map individuals to their respective bodyparts
553567
self.id2label = defaultdict(list)
554-
for keypoint in store._keypoints:
555-
label = keypoint.label
556-
id_ = keypoint.id
557-
if label not in self.id2label[id_]:
558-
self.id2label[id_].append(label)
559-
560568
self.menus = dict()
561-
if store.ids[0]:
562-
menu = create_dropdown_menu(store, list(self.id2label), "id")
563-
menu.currentTextChanged.connect(self.refresh_label_menu)
564-
self.menus["id"] = menu
565-
self.menus["label"] = create_dropdown_menu(
566-
store, self.id2label[store.ids[0]], "label"
567-
)
569+
self._map_individuals_to_bodyparts()
570+
self._populate_menus()
571+
568572
layout1 = QVBoxLayout()
569573
layout1.addStretch(1)
570574
group_box = QGroupBox("Keypoint selection")
@@ -575,6 +579,29 @@ def __init__(
575579
layout1.addWidget(group_box)
576580
self.setLayout(layout1)
577581

582+
def _map_individuals_to_bodyparts(self):
583+
for keypoint in self.store._keypoints:
584+
label = keypoint.label
585+
id_ = keypoint.id
586+
if label not in self.id2label[id_]:
587+
self.id2label[id_].append(label)
588+
589+
def _populate_menus(self):
590+
id_ = self.store.ids[0]
591+
if id_:
592+
menu = create_dropdown_menu(self.store, list(self.id2label), "id")
593+
menu.currentTextChanged.connect(self.refresh_label_menu)
594+
self.menus["id"] = menu
595+
self.menus["label"] = create_dropdown_menu(
596+
self.store, self.id2label[id_], "label",
597+
)
598+
599+
def _update_items(self):
600+
id_ = self.store.ids[0]
601+
if id_:
602+
self.menus["id"].update_items(list(self.id2label))
603+
self.menus["label"].update_items(self.id2label[id_])
604+
578605
def update_menus(self, event):
579606
keypoint = self.store.current_keypoint
580607
for attr, menu in self.menus.items():
@@ -805,5 +832,8 @@ def add_entry(self, name, color):
805832
self._layout.addWidget(
806833
LabelPair(color, name, self), alignment=Qt.AlignmentFlag.AlignLeft
807834
)
808-
self._container.setLayout(self._layout)
809-
self._container.update()
835+
836+
def reset(self):
837+
self.scheme_dict = {}
838+
for i in reversed(range(self._layout.count())):
839+
self._layout.itemAt(i).widget().deleteLater()

src/napari_deeplabcut/keypoints.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,21 @@ def default(cls):
5959
class KeypointStore:
6060
def __init__(self, viewer, layer: Points):
6161
self.viewer = viewer
62+
self._keypoints = []
6263
self.layer = layer
64+
self.viewer.dims.set_current_step(0, 0)
65+
66+
@property
67+
def layer(self):
68+
return self._layer
69+
70+
@layer.setter
71+
def layer(self, layer):
72+
self._layer = layer
6373
all_pairs = self.layer.metadata["header"].form_individual_bodypart_pairs()
6474
self._keypoints = [
6575
Keypoint(label, id_) for id_, label in all_pairs
6676
] # Ordered references to all possible keypoints
67-
self.viewer.dims.set_current_step(0, 0)
6877

6978
@property
7079
def current_step(self):

0 commit comments

Comments
 (0)