Skip to content

Commit 95102e4

Browse files
authored
Merge pull request #97 from DeepLabCut/colormap_select
2 parents 0e04f2f + e3267ac commit 95102e4

File tree

8 files changed

+75
-43
lines changed

8 files changed

+75
-43
lines changed

MANIFEST.in

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
include LICENSE
22
include README.md
3-
include src/napari_deeplabcut/assets/*.svg
43
include src/napari_deeplabcut/styles/*.mplstyle
54

65
recursive-exclude * __pycache__

src/napari_deeplabcut/_reader.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,7 @@ def _populate_metadata(
124124
ids = header.individuals
125125
if likelihood is None:
126126
likelihood = np.ones(len(labels))
127-
label_colors = misc.build_color_cycle(len(header.bodyparts), colormap)
128-
id_colors = misc.build_color_cycle(len(header.individuals), colormap)
129-
face_color_cycle_maps = {
130-
"label": dict(zip(header.bodyparts, label_colors)),
131-
"id": dict(zip(header.individuals, id_colors)),
132-
}
127+
face_color_cycle_maps = misc.build_color_cycles(header, colormap)
133128
face_color_prop = "id" if ids[0] else "label"
134129
return {
135130
"name": "keypoints",
@@ -151,6 +146,7 @@ def _populate_metadata(
151146
"metadata": {
152147
"header": header,
153148
"face_color_cycles": face_color_cycle_maps,
149+
"colormap_name": colormap,
154150
"paths": paths or [],
155151
},
156152
}

src/napari_deeplabcut/_tests/test_keypoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_point_resize(viewer, points):
4747
layer = viewer.layers[0]
4848
controls = keypoints.QtPointsControls(layer)
4949
new_size = 10
50-
controls.changeSize(new_size)
50+
controls.changeCurrentSize(new_size)
5151
np.testing.assert_array_equal(points.size, new_size)
5252

5353

src/napari_deeplabcut/_widgets.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import datetime
55
from functools import partial, cached_property
66
from math import ceil, log10
7+
import matplotlib.pyplot as plt
78
import matplotlib.style as mplstyle
89
import napari
910
import pandas as pd
@@ -21,8 +22,8 @@
2122
from napari.layers.utils.layer_utils import _features_to_properties
2223
from napari.utils.events import Event
2324
from napari.utils.history import get_save_history, update_save_history
24-
from qtpy.QtCore import Qt, QTimer, Signal, QSize, QPoint, QSettings
25-
from qtpy.QtGui import QPainter, QIcon, QAction, QCursor
25+
from qtpy.QtCore import Qt, QTimer, Signal, QPoint, QSettings
26+
from qtpy.QtGui import QPainter, QAction, QCursor
2627
from qtpy.QtWidgets import (
2728
QButtonGroup,
2829
QCheckBox,
@@ -45,15 +46,14 @@
4546
QWidget,
4647
)
4748

48-
ICON_FOLDER = os.path.join(os.path.dirname(__file__), "assets")
49-
5049
from napari_deeplabcut import keypoints
5150
from napari_deeplabcut._reader import _load_config
5251
from napari_deeplabcut._writer import _write_config, _write_image, _form_df
5352
from napari_deeplabcut.misc import (
5453
encode_categories,
5554
to_os_dir_sep,
5655
guarantee_multiindex_rows,
56+
build_color_cycles
5757
)
5858

5959

@@ -622,6 +622,13 @@ def __init__(self, napari_viewer):
622622
launch_tutorial.triggered.connect(self.start_tutorial)
623623
self.viewer.window.view_menu.addAction(launch_tutorial)
624624

625+
# Hide some unused viewer buttons
626+
self.viewer.window._qt_viewer.viewerButtons.gridViewButton.hide()
627+
self.viewer.window._qt_viewer.viewerButtons.rollDimsButton.hide()
628+
self.viewer.window._qt_viewer.viewerButtons.transposeDimsButton.hide()
629+
self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled(True)
630+
self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled(True)
631+
625632
if self.settings.value("first_launch", True) and not os.environ.get(
626633
"hide_tutorial", False
627634
):
@@ -650,13 +657,17 @@ def _show_trails(self, state):
650657
store = list(self._stores.values())[0]
651658
inds = encode_categories(store.layer.properties["label"])
652659
temp = np.c_[inds, store.layer.data]
660+
cmap = "viridis"
661+
for layer in self.viewer.layers:
662+
if isinstance(layer, Points) and layer.metadata:
663+
cmap = layer.metadata["colormap_name"]
653664
self._trails = self.viewer.add_tracks(
654665
temp,
655666
tail_length=50,
656667
head_length=50,
657668
tail_width=6,
658669
name="trails",
659-
colormap="viridis",
670+
colormap=cmap,
660671
)
661672
self._trails.visible = True
662673
elif self._trails is not None:
@@ -933,6 +944,11 @@ def on_insert(self, event):
933944
# Hide out of slice checkbox
934945
point_controls.outOfSliceCheckBox.hide()
935946
point_controls.layout().itemAt(15).widget().hide()
947+
# Add dropdown menu for colormap picking
948+
colormap_selector = DropdownMenu(plt.colormaps, self)
949+
colormap_selector.update_to(layer.metadata["colormap_name"])
950+
colormap_selector.currentTextChanged.connect(self._update_colormap)
951+
point_controls.layout().addRow("colormap", colormap_selector)
936952

937953
for layer_ in self.viewer.layers:
938954
if not isinstance(layer_, Image):
@@ -963,6 +979,20 @@ def on_remove(self, event):
963979
self._matplotlib_cb.setChecked(False)
964980
self._trails = None
965981

982+
def _update_colormap(self, colormap_name):
983+
for layer in self.viewer.layers:
984+
if isinstance(layer, Points) and layer.metadata:
985+
face_color_cycle_maps = build_color_cycles(
986+
layer.metadata["header"], colormap_name,
987+
)
988+
layer.metadata["face_color_cycles"] = face_color_cycle_maps
989+
face_color_prop = layer._face.color_properties.name
990+
layer.face_color = face_color_prop
991+
layer.face_color_cycle = face_color_cycle_maps[face_color_prop]
992+
layer.events.face_color()
993+
self._update_color_scheme()
994+
break
995+
966996
@register_points_action("Change labeling mode")
967997
def cycle_through_label_modes(self, *args):
968998
self.label_mode = next(keypoints.LabelMode)
@@ -975,8 +1005,15 @@ def label_mode(self):
9751005
def label_mode(self, mode: Union[str, keypoints.LabelMode]):
9761006
self._label_mode = keypoints.LabelMode(mode)
9771007
self.viewer.status = self.label_mode
1008+
mode_ = str(mode)
1009+
if mode_ == "loop":
1010+
for menu in self._menus:
1011+
menu._locked = True
1012+
else:
1013+
for menu in self._menus:
1014+
menu._locked = False
9781015
for btn in self._radio_group.buttons():
979-
if btn.text() == str(mode):
1016+
if btn.text() == mode_:
9801017
btn.setChecked(True)
9811018
break
9821019

@@ -1038,11 +1075,6 @@ def __init__(
10381075
layout2 = QVBoxLayout()
10391076
for menu in self.menus.values():
10401077
layout2.addWidget(menu)
1041-
self.lock_button = QPushButton("Lock selection")
1042-
self.lock_button.setIcon(QIcon(os.path.join(ICON_FOLDER, "unlock.svg")))
1043-
self.lock_button.setIconSize(QSize(24, 24))
1044-
self.lock_button.clicked.connect(self._lock_current_keypoint)
1045-
layout2.addWidget(self.lock_button)
10461078
group_box.setLayout(layout2)
10471079
layout1.addWidget(group_box)
10481080
self.setLayout(layout1)
@@ -1072,15 +1104,6 @@ def _update_items(self):
10721104
self.menus["id"].update_items(list(self.id2label))
10731105
self.menus["label"].update_items(self.id2label[id_])
10741106

1075-
def _lock_current_keypoint(self):
1076-
self._locked = not self._locked
1077-
if self._locked:
1078-
self.lock_button.setText("Unlock selection")
1079-
self.lock_button.setIcon(QIcon(os.path.join(ICON_FOLDER, "lock.svg")))
1080-
else:
1081-
self.lock_button.setText("Lock selection")
1082-
self.lock_button.setIcon(QIcon(os.path.join(ICON_FOLDER, "unlock.svg")))
1083-
10841107
def update_menus(self, event):
10851108
keypoint = self.store.current_keypoint
10861109
for attr, menu in self.menus.items():
@@ -1097,7 +1120,7 @@ def refresh_label_menu(self, text: str):
10971120

10981121
def smart_reset(self, event):
10991122
"""Set current keypoint to the first unlabeled one."""
1100-
if self._locked:
1123+
if self._locked: # The currently selected point is not updated
11011124
return
11021125
unannotated = ""
11031126
already_annotated = self.store.annotated_keypoints
@@ -1360,4 +1383,5 @@ def add_entry(self, name, color):
13601383
def reset(self):
13611384
self.scheme_dict = {}
13621385
for i in reversed(range(self._layout.count())):
1363-
self._layout.itemAt(i).widget().deleteLater()
1386+
w = self._layout.itemAt(i).widget()
1387+
self._layout.removeWidget(w)

src/napari_deeplabcut/assets/lock.svg

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/napari_deeplabcut/assets/unlock.svg

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/napari_deeplabcut/keypoints.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
from napari._qt.layer_controls.qt_points_controls import QtPointsControls
77
from napari.layers import Points
8+
from napari.layers.points._points_constants import SYMBOL_TRANSLATION_INVERTED
9+
from napari.layers.points._points_utils import coerce_symbols
810

911
from napari_deeplabcut.misc import CycleEnum
1012

@@ -19,7 +21,17 @@ def _change_size(self, value):
1921
self.layer.events.size()
2022

2123

22-
QtPointsControls.changeSize = _change_size
24+
def _change_symbol(self, text):
25+
symbol = coerce_symbols(np.array([SYMBOL_TRANSLATION_INVERTED[text]]))[0]
26+
self.layer._current_symbol = symbol
27+
if self.layer._update_properties:
28+
self.layer.symbol = symbol
29+
self.layer.events.symbol()
30+
self.layer.events.current_symbol()
31+
32+
33+
QtPointsControls.changeCurrentSize = _change_size
34+
QtPointsControls.changeCurrentSymbol = _change_symbol
2335

2436

2537
class LabelMode(CycleEnum):
@@ -29,11 +41,8 @@ class LabelMode(CycleEnum):
2941
clicking to add an already annotated point has no effect.
3042
QUICK: similar to SEQUENTIAL, but trying to add an already
3143
annotated point actually moves it to the cursor location.
32-
LOOP: the first point is placed frame by frame, then it wraps
33-
to the next label at the end and restart from frame 1, etc.
34-
Unless the keypoint selection is locked, the dropdown menu is
35-
automatically set to the first unlabeled keypoint of
36-
the current frame.
44+
LOOP: the currently selected point is placed frame after frame,
45+
before wrapping at the end to frame 1, etc.
3746
"""
3847

3948
SEQUENTIAL = auto()
@@ -51,11 +60,8 @@ def default(cls):
5160
"clicking to add an already annotated point has no effect.",
5261
"QUICK": "Similar to SEQUENTIAL, but trying to add an already\n"
5362
"annotated point actually moves it to the cursor location.",
54-
"LOOP": "The first point is placed frame by frame, then it wraps\n"
55-
"to the next label at the end and restart from frame 1, etc.\n"
56-
"Unless the keypoint selection is locked, the dropdown menu is\n"
57-
"automatically set to the first unlabeled keypoint of\n"
58-
"the current frame.",
63+
"LOOP": "The currently selected point is placed frame after frame,\n"
64+
"before wrapping at the end to frame 1, etc.",
5965
}
6066

6167

src/napari_deeplabcut/misc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def build_color_cycle(n_colors: int, colormap: Optional[str] = "viridis") -> np.
101101
return cmap.map(np.linspace(0, 1, n_colors))
102102

103103

104+
def build_color_cycles(header: DLCHeader, colormap: Optional[str] = "viridis"):
105+
label_colors = build_color_cycle(len(header.bodyparts), colormap)
106+
id_colors = build_color_cycle(len(header.individuals), colormap)
107+
return {
108+
"label": dict(zip(header.bodyparts, label_colors)),
109+
"id": dict(zip(header.individuals, id_colors)),
110+
}
111+
112+
104113
class DLCHeader:
105114
def __init__(self, columns: pd.MultiIndex):
106115
self.columns = columns

0 commit comments

Comments
 (0)