Skip to content

Commit 6c26d0e

Browse files
committed
Add colormap selector
1 parent e7ac73a commit 6c26d0e

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

src/napari_deeplabcut/_reader.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,7 @@ def _populate_metadata(
124124
ids = header.individuals
125125
if likelihood is None:
126126
likelihood = np.ones(len(labels))
127-
label_colors = misc.build_color_cycle(len(header.bodyparts), colormap)
128-
id_colors = misc.build_color_cycle(len(header.individuals), colormap)
129-
face_color_cycle_maps = {
130-
"label": dict(zip(header.bodyparts, label_colors)),
131-
"id": dict(zip(header.individuals, id_colors)),
132-
}
127+
face_color_cycle_maps = misc.build_color_cycles(header, colormap)
133128
face_color_prop = "id" if ids[0] else "label"
134129
return {
135130
"name": "keypoints",
@@ -151,6 +146,7 @@ def _populate_metadata(
151146
"metadata": {
152147
"header": header,
153148
"face_color_cycles": face_color_cycle_maps,
149+
"colormap_name": colormap,
154150
"paths": paths or [],
155151
},
156152
}

src/napari_deeplabcut/_widgets.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import datetime
55
from functools import partial, cached_property
66
from math import ceil, log10
7+
import matplotlib.pyplot as plt
78
import pandas as pd
89
from pathlib import Path
910
from types import MethodType
@@ -49,6 +50,7 @@
4950
encode_categories,
5051
to_os_dir_sep,
5152
guarantee_multiindex_rows,
53+
build_color_cycles
5254
)
5355

5456

@@ -701,6 +703,11 @@ def on_insert(self, event):
701703
# Hide out of slice checkbox
702704
point_controls.outOfSliceCheckBox.hide()
703705
point_controls.layout().itemAt(15).widget().hide()
706+
# Add dropdown menu for colormap picking
707+
colormap_selector = DropdownMenu(plt.colormaps, self)
708+
colormap_selector.update_to(layer.metadata["colormap_name"])
709+
colormap_selector.currentTextChanged.connect(self._update_colormap)
710+
point_controls.layout().addRow("colormap", colormap_selector)
704711

705712
for layer_ in self.viewer.layers:
706713
if not isinstance(layer_, Image):
@@ -732,6 +739,20 @@ def on_remove(self, event):
732739
self._trail_cb.setChecked(False)
733740
self._trails = None
734741

742+
def _update_colormap(self, colormap_name):
743+
for layer in self.viewer.layers:
744+
if isinstance(layer, Points) and layer.metadata:
745+
face_color_cycle_maps = build_color_cycles(
746+
layer.metadata["header"], colormap_name,
747+
)
748+
layer.metadata["face_color_cycles"] = face_color_cycle_maps
749+
face_color_prop = layer._face.color_properties.name
750+
layer.face_color = face_color_prop
751+
layer.face_color_cycle = face_color_cycle_maps[face_color_prop]
752+
layer.events.face_color()
753+
self._update_color_scheme()
754+
break
755+
735756
@register_points_action("Change labeling mode")
736757
def cycle_through_label_modes(self, *args):
737758
self.label_mode = next(keypoints.LabelMode)
@@ -1127,4 +1148,5 @@ def add_entry(self, name, color):
11271148
def reset(self):
11281149
self.scheme_dict = {}
11291150
for i in reversed(range(self._layout.count())):
1130-
self._layout.itemAt(i).widget().deleteLater()
1151+
w = self._layout.itemAt(i).widget()
1152+
self._layout.removeWidget(w)

src/napari_deeplabcut/misc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def build_color_cycle(n_colors: int, colormap: Optional[str] = "viridis") -> np.
101101
return cmap.map(np.linspace(0, 1, n_colors))
102102

103103

104+
def build_color_cycles(header: DLCHeader, colormap: Optional[str] = "viridis"):
105+
label_colors = build_color_cycle(len(header.bodyparts), colormap)
106+
id_colors = build_color_cycle(len(header.individuals), colormap)
107+
return {
108+
"label": dict(zip(header.bodyparts, label_colors)),
109+
"id": dict(zip(header.individuals, id_colors)),
110+
}
111+
112+
104113
class DLCHeader:
105114
def __init__(self, columns: pd.MultiIndex):
106115
self.columns = columns

0 commit comments

Comments
 (0)