Skip to content

Commit 516232b

Browse files
committed
Polish keypoint canvas widget
1 parent c814d00 commit 516232b

File tree

4 files changed

+115
-52
lines changed

4 files changed

+115
-52
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
include LICENSE
22
include README.md
33
include src/napari_deeplabcut/assets/*.svg
4+
include src/napari_deeplabcut/styles/*.mplstyle
45

56
recursive-exclude * __pycache__
67
recursive-exclude * *.py[co]

src/napari_deeplabcut/_widgets.py

Lines changed: 90 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from datetime import datetime
55
from functools import partial, cached_property
66
from math import ceil, log10
7+
import matplotlib.style as mplstyle
8+
import napari
79
import pandas as pd
810
from pathlib import Path
911
from types import MethodType
1012
from typing import Optional, Sequence, Union
1113

1214
from matplotlib.backends.backend_qtagg import FigureCanvas
13-
from matplotlib.figure import Figure
1415

1516
import numpy as np
1617
from napari._qt.widgets.qt_welcome import QtWelcomeLabel
@@ -305,9 +306,10 @@ def __init__(self, napari_viewer, parent=None):
305306
super().__init__(parent=parent)
306307

307308
self.viewer = napari_viewer
308-
self.figure = Figure()
309-
self.canvas = FigureCanvas(self.figure)
310-
self.ax = self.figure.add_subplot(111)
309+
with mplstyle.context(self.mpl_style_sheet_path):
310+
self.canvas = FigureCanvas()
311+
self.canvas.figure.set_layout_engine("constrained")
312+
self.ax = self.canvas.figure.subplots()
311313
self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--")
312314
self.ax.set_xlabel("Frame")
313315
self.ax.set_ylabel("Y position")
@@ -341,56 +343,92 @@ def __init__(self, napari_viewer, parent=None):
341343
self.viewer.dims.events.current_step.connect(self.update_plot_range)
342344

343345
# Run update plot range once to initialize the plot
346+
self._n = 0
344347
self.update_plot_range(
345348
Event(type_name="", value=[self.viewer.dims.current_step[0]])
346349
)
347350

348-
def set_window(self, value):
349-
self._window = value
350-
self.slider_value.setText(str(value))
351-
self.update_plot_range(
352-
Event(type_name="", value=[self.viewer.dims.current_step[0]])
353-
)
351+
self.viewer.layers.events.inserted.connect(self._load_dataframe)
352+
self._lines = {}
354353

355-
def update_plot_range(self, event):
356-
value = event.value[0]
357-
if self.df is None:
358-
points_layer = None
359-
for layer in self.viewer.layers:
360-
if isinstance(layer, Points):
361-
points_layer = layer
362-
break
354+
def _napari_theme_has_light_bg(self) -> bool:
355+
"""
356+
Does this theme have a light background?
363357
364-
if points_layer is None:
365-
return
358+
Returns
359+
-------
360+
bool
361+
True if theme's background colour has hsl lighter than 50%, False if darker.
362+
"""
363+
theme = napari.utils.theme.get_theme(self.viewer.theme, as_dict=False)
364+
_, _, bg_lightness = theme.background.as_hsl_tuple()
365+
return bg_lightness > 0.5
366366

367-
self.df = _form_df(
368-
points_layer.data,
369-
{
370-
"metadata": points_layer.metadata,
371-
"properties": points_layer.properties,
372-
},
373-
)
367+
@property
368+
def mpl_style_sheet_path(self) -> Path:
369+
"""
370+
Path to the set Matplotlib style sheet.
371+
"""
372+
if self._napari_theme_has_light_bg():
373+
return Path(__file__).parent / "styles" / "light.mplstyle"
374+
else:
375+
return Path(__file__).parent / "styles" / "dark.mplstyle"
376+
377+
def _load_dataframe(self):
378+
points_layer = None
379+
for layer in self.viewer.layers:
380+
if isinstance(layer, Points):
381+
points_layer = layer
382+
break
374383

375-
# Find the bodyparts names
376-
bodyparts = self.df.columns.get_level_values("bodyparts").unique()
377-
# Get only the body parts that contain the word limb in them
378-
limb_bodyparts = [limb for limb in bodyparts if "limb" in limb.lower()]
384+
if points_layer is None:
385+
return
386+
387+
self.viewer.window.add_dock_widget(self, name="Trajectory plot", area="right")
388+
self.hide()
389+
390+
self.df = _form_df(
391+
points_layer.data,
392+
{
393+
"metadata": points_layer.metadata,
394+
"properties": points_layer.properties,
395+
},
396+
)
397+
for keypoint in self.df.columns.get_level_values("bodyparts").unique():
398+
y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"])
399+
x = np.arange(len(y))
400+
color = points_layer.metadata["face_color_cycles"]["label"][keypoint]
401+
(line,) = self.ax.plot(x, y, color=color, label=keypoint)
402+
self._lines[keypoint] = line
403+
404+
self._refresh_canvas(value=self._n)
379405

380-
for limb in limb_bodyparts:
381-
y = self.df.xs((limb, "y"), axis=1, level=["bodyparts", "coords"])
382-
x = np.arange(len(y))
383-
# color by limb colormap using point layer metadata
384-
color = points_layer.metadata["face_color_cycles"]["label"][limb]
385-
self.ax.plot(x, y, color=color, label=limb)
406+
def _toggle_line_visibility(self, keypoint):
407+
artist = self._lines[keypoint]
408+
artist.set_visible(not artist.get_visible())
409+
self._refresh_canvas(value=self._n)
386410

411+
def _refresh_canvas(self, value):
387412
start = max(0, value - self._window // 2)
388413
end = min(value + self._window // 2, len(self.df))
389414

390415
self.ax.set_xlim(start, end)
391416
self.vline.set_xdata(value)
417+
self.canvas.draw()
392418

393-
self.canvas.draw_idle()
419+
def set_window(self, value):
420+
self._window = value
421+
self.slider_value.setText(str(value))
422+
self.update_plot_range(Event(type_name="", value=[self._n]))
423+
424+
def update_plot_range(self, event):
425+
value = event.value[0]
426+
self._n = value
427+
428+
if self.df is None:
429+
return
430+
431+
self._refresh_canvas(value)
394432

395433

396434
class KeypointControls(QWidget):
@@ -457,12 +495,12 @@ def __init__(self, napari_viewer):
457495
self._trails = None
458496

459497
matplotlib_label = QLabel("Show matplotlib canvas")
498+
self._matplotlib_canvas = KeypointMatplotlibCanvas(self.viewer)
460499
self._matplotlib_cb = QCheckBox()
461500
self._matplotlib_cb.setToolTip("toggle matplotlib canvas visibility")
501+
self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas)
462502
self._matplotlib_cb.setChecked(False)
463503
self._matplotlib_cb.setEnabled(False)
464-
self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas)
465-
self._matplotlib_canvas = None
466504
self._view_scheme_cb = QCheckBox("Show color scheme", parent=self)
467505

468506
hlayout.addWidget(self._matplotlib_cb)
@@ -479,6 +517,11 @@ def __init__(self, napari_viewer):
479517
self._color_scheme_display = self._form_color_scheme_display(self.viewer)
480518
self._view_scheme_cb.toggled.connect(self._show_color_scheme)
481519
self._view_scheme_cb.toggle()
520+
self._display.added.connect(
521+
lambda w: w.part_label.clicked.connect(
522+
self._matplotlib_canvas._toggle_line_visibility
523+
),
524+
)
482525

483526
# Substitute default menu action with custom one
484527
for action in self.viewer.window.file_menu.actions()[::-1]:
@@ -505,9 +548,6 @@ def __init__(self, napari_viewer):
505548
QTimer.singleShot(10, self.start_tutorial)
506549
self.settings.setValue("first_launch", False)
507550

508-
matplotlib_widget = KeypointMatplotlibCanvas(self.viewer)
509-
matplotlib_widget.setVisible(False)
510-
511551
@cached_property
512552
def settings(self):
513553
return QSettings()
@@ -544,13 +584,9 @@ def _show_trails(self, state):
544584

545585
def _show_matplotlib_canvas(self, state):
546586
if state == Qt.Checked:
547-
self._canvas = KeypointMatplotlibCanvas(self.viewer)
548-
self.viewer.window.add_dock_widget(
549-
self._canvas, name="Trajectory plot", area="bottom"
550-
)
551-
self._canvas.show()
587+
self._matplotlib_canvas.show()
552588
else:
553-
self._canvas.close()
589+
self._matplotlib_canvas.hide()
554590

555591
def _form_video_action_menu(self):
556592
group_box = QGroupBox("Video")
@@ -1192,6 +1228,8 @@ def part_name(self, part_name: str):
11921228

11931229

11941230
class ColorSchemeDisplay(QScrollArea):
1231+
added = Signal(object)
1232+
11951233
def __init__(self, parent):
11961234
super().__init__(parent)
11971235

@@ -1235,9 +1273,9 @@ def _build(self):
12351273
def add_entry(self, name, color):
12361274
self.scheme_dict.update({name: color})
12371275

1238-
self._layout.addWidget(
1239-
LabelPair(color, name, self), alignment=Qt.AlignmentFlag.AlignLeft
1240-
)
1276+
widget = LabelPair(color, name, self)
1277+
self._layout.addWidget(widget, alignment=Qt.AlignmentFlag.AlignLeft)
1278+
self.added.emit(widget)
12411279

12421280
def reset(self):
12431281
self.scheme_dict = {}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Dark-theme napari colour scheme for matplotlib plots
2+
3+
# text (very light grey - almost white): #f0f1f2
4+
# foreground (mid grey): #414851
5+
# background (dark blue-gray): #262930
6+
7+
figure.facecolor : none
8+
axes.labelcolor : f0f1f2
9+
axes.facecolor : none
10+
axes.edgecolor : 414851
11+
xtick.color : f0f1f2
12+
ytick.color : f0f1f2
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Light-theme napari colour scheme for matplotlib plots
2+
3+
# text (very dark grey - almost black): #3b3a39
4+
# foreground (mid grey): #d6d0ce
5+
# background (brownish beige): #efebe9
6+
7+
figure.facecolor : none
8+
axes.labelcolor : 3b3a39
9+
axes.facecolor : none
10+
axes.edgecolor : d6d0ce
11+
xtick.color : 3b3a39
12+
ytick.color : 3b3a39

0 commit comments

Comments
 (0)