Skip to content

Commit c05d4a8

Browse files
authored
Merge pull request #51 from DeepLabCut/keypoint_diff
Support the addition of new body parts after having started labeling
2 parents eabe0a5 + 4e915f2 commit c05d4a8

File tree

3 files changed

+96
-39
lines changed

3 files changed

+96
-39
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ Suggested workflows, depending on the image folder contents:
101101

102102
Saving works as described in *1*.
103103

104+
***Note that if a new body part has been added to the `config.yaml` file after having started to label, loading the config in the GUI is necessary to update the dropdown menus and other metadata.***
105+
104106
3. **Refining labels** – the image folder contains a `machinelabels-iter<#>.h5` file.
105107

106108
The process is analog to *2*.

src/napari_deeplabcut/_widgets.py

Lines changed: 84 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ def __init__(self, napari_viewer):
203203

204204
self._radio_group = self._form_mode_radio_buttons()
205205

206+
self._display = ColorSchemeDisplay(parent=self)
206207
self._color_scheme_display = self._form_color_scheme_display(self.viewer)
207-
208208
self._view_scheme_cb.toggled.connect(self._show_color_scheme)
209209
self._view_scheme_cb.toggle()
210210

@@ -354,29 +354,23 @@ def _func():
354354
return group
355355

356356
def _form_color_scheme_display(self, viewer):
357-
display = ColorSchemeDisplay(parent=self)
358-
self._update_color_scheme(display)
359-
360-
self.viewer.layers.events.inserted.connect(
361-
partial(self._update_color_scheme, display)
362-
)
363-
357+
self.viewer.layers.events.inserted.connect(self._update_color_scheme)
364358
return viewer.window.add_dock_widget(
365-
display, name="Color scheme reference", area="left"
359+
self._display, name="Color scheme reference", area="left"
366360
)
367361

368-
def _update_color_scheme(self, display):
362+
def _update_color_scheme(self):
363+
def to_hex(nparray):
364+
a = np.array(nparray * 255, dtype=int)
365+
rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}"
366+
res = rgb2hex(*a)
367+
return res
368+
369+
self._display.reset()
369370
for layer in self.viewer.layers:
370371
if isinstance(layer, Points) and layer.metadata:
371-
372-
def to_hex(nparray):
373-
a = np.array(nparray * 255, dtype=int)
374-
rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}"
375-
res = rgb2hex(*a)
376-
return res
377-
378372
[
379-
display.add_entry(name, to_hex(color))
373+
self._display.add_entry(name, to_hex(color))
380374
for name, color in layer.metadata["face_color_cycles"][
381375
"label"
382376
].items()
@@ -439,6 +433,38 @@ def on_insert(self, event):
439433
10, partial(self._move_image_layer_to_bottom, event.index)
440434
)
441435
elif isinstance(layer, Points):
436+
# If the current Points layer comes from a config file, some have already
437+
# been added and the body part names are different from the existing ones,
438+
# then we update store's metadata and menus.
439+
if layer.metadata.get("project", "") and self._stores:
440+
keypoints_menu = self._menus[0].menus["label"]
441+
current_keypoint_set = set(
442+
keypoints_menu.itemText(i) for i in range(keypoints_menu.count())
443+
)
444+
new_keypoint_set = set(layer.metadata["header"].bodyparts)
445+
diff = new_keypoint_set.difference(current_keypoint_set)
446+
if diff:
447+
answer = QMessageBox.question(self, "", "Do you want to display the new keypoints only?")
448+
if answer == QMessageBox.Yes:
449+
self.viewer.layers[-2].shown = False
450+
451+
self.viewer.status = f"New keypoint{'s' if len(diff) > 1 else ''} {', '.join(diff)} found."
452+
for _layer, store in self._stores.items():
453+
_layer.metadata["header"] = layer.metadata["header"]
454+
_layer.metadata["face_color_cycles"] = layer.metadata["face_color_cycles"]
455+
_layer.face_color_cycle = layer.face_color_cycle
456+
store.layer = _layer
457+
458+
for menu in self._menus:
459+
menu._map_individuals_to_bodyparts()
460+
menu._update_items()
461+
462+
self._update_color_scheme()
463+
464+
# Remove the unnecessary layer newly added
465+
QTimer.singleShot(10, self.viewer.layers.pop)
466+
return
467+
442468
store = keypoints.KeypointStore(self.viewer, layer)
443469
self._stores[layer] = store
444470
# TODO Set default dir of the save file dialog
@@ -470,14 +496,15 @@ def on_insert(self, event):
470496

471497
def on_remove(self, event):
472498
layer = event.value
473-
if isinstance(layer, Points):
499+
n_points_layer = sum(isinstance(l, Points) for l in self.viewer.layers)
500+
if isinstance(layer, Points) and n_points_layer == 0:
474501
if self._color_scheme_display is not None:
475-
self.viewer.window.remove_dock_widget(self._color_scheme_display)
502+
self._display.reset()
476503
self._stores.pop(layer, None)
477504
while self._menus:
478505
menu = self._menus.pop()
479506
self._layout.removeWidget(menu)
480-
menu.setParent(None)
507+
menu.deleteLater()
481508
menu.destroy()
482509
self._trail_cb.setEnabled(False)
483510
self.last_saved_label.hide()
@@ -529,7 +556,7 @@ def toggle_edge_color(layer):
529556
class DropdownMenu(QComboBox):
530557
def __init__(self, labels: Sequence[str], parent: Optional[QWidget] = None):
531558
super().__init__(parent)
532-
self.addItems(labels)
559+
self.update_items(labels)
533560

534561
def update_to(self, text: str):
535562
index = self.findText(text)
@@ -539,6 +566,10 @@ def update_to(self, text: str):
539566
def reset(self):
540567
self.setCurrentIndex(0)
541568

569+
def update_items(self, items):
570+
self.clear()
571+
self.addItems(items)
572+
542573

543574
class KeypointsDropdownMenu(QWidget):
544575
def __init__(
@@ -551,22 +582,11 @@ def __init__(
551582
self.store.layer.events.current_properties.connect(self.update_menus)
552583
self._locked = False
553584

554-
# Map individuals to their respective bodyparts
555585
self.id2label = defaultdict(list)
556-
for keypoint in store._keypoints:
557-
label = keypoint.label
558-
id_ = keypoint.id
559-
if label not in self.id2label[id_]:
560-
self.id2label[id_].append(label)
561-
562586
self.menus = dict()
563-
if store.ids[0]:
564-
menu = create_dropdown_menu(store, list(self.id2label), "id")
565-
menu.currentTextChanged.connect(self.refresh_label_menu)
566-
self.menus["id"] = menu
567-
self.menus["label"] = create_dropdown_menu(
568-
store, self.id2label[store.ids[0]], "label"
569-
)
587+
self._map_individuals_to_bodyparts()
588+
self._populate_menus()
589+
570590
layout1 = QVBoxLayout()
571591
layout1.addStretch(1)
572592
group_box = QGroupBox("Keypoint selection")
@@ -582,6 +602,29 @@ def __init__(
582602
layout1.addWidget(group_box)
583603
self.setLayout(layout1)
584604

605+
def _map_individuals_to_bodyparts(self):
606+
for keypoint in self.store._keypoints:
607+
label = keypoint.label
608+
id_ = keypoint.id
609+
if label not in self.id2label[id_]:
610+
self.id2label[id_].append(label)
611+
612+
def _populate_menus(self):
613+
id_ = self.store.ids[0]
614+
if id_:
615+
menu = create_dropdown_menu(self.store, list(self.id2label), "id")
616+
menu.currentTextChanged.connect(self.refresh_label_menu)
617+
self.menus["id"] = menu
618+
self.menus["label"] = create_dropdown_menu(
619+
self.store, self.id2label[id_], "label",
620+
)
621+
622+
def _update_items(self):
623+
id_ = self.store.ids[0]
624+
if id_:
625+
self.menus["id"].update_items(list(self.id2label))
626+
self.menus["label"].update_items(self.id2label[id_])
627+
585628
def _lock_current_keypoint(self):
586629
self._locked = not self._locked
587630
if self._locked:
@@ -833,5 +876,8 @@ def add_entry(self, name, color):
833876
self._layout.addWidget(
834877
LabelPair(color, name, self), alignment=Qt.AlignmentFlag.AlignLeft
835878
)
836-
self._container.setLayout(self._layout)
837-
self._container.update()
879+
880+
def reset(self):
881+
self.scheme_dict = {}
882+
for i in reversed(range(self._layout.count())):
883+
self._layout.itemAt(i).widget().deleteLater()

src/napari_deeplabcut/keypoints.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,21 @@ def default(cls):
6565
class KeypointStore:
6666
def __init__(self, viewer, layer: Points):
6767
self.viewer = viewer
68+
self._keypoints = []
6869
self.layer = layer
70+
self.viewer.dims.set_current_step(0, 0)
71+
72+
@property
73+
def layer(self):
74+
return self._layer
75+
76+
@layer.setter
77+
def layer(self, layer):
78+
self._layer = layer
6979
all_pairs = self.layer.metadata["header"].form_individual_bodypart_pairs()
7080
self._keypoints = [
7181
Keypoint(label, id_) for id_, label in all_pairs
7282
] # Ordered references to all possible keypoints
73-
self.viewer.dims.set_current_step(0, 0)
7483

7584
@property
7685
def current_step(self):

0 commit comments

Comments
 (0)