Skip to content

Commit 0e04f2f

Browse files
authored
Merge pull request #90 from hausmanns/seb/matplotlib_keypoint_viewer
2 parents 13a8050 + ff625fa commit 0e04f2f

26 files changed

+275
-25
lines changed

.github/workflows/plugin_preview.yml

Lines changed: 0 additions & 21 deletions
This file was deleted.

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
include LICENSE
22
include README.md
33
include src/napari_deeplabcut/assets/*.svg
4+
include src/napari_deeplabcut/styles/*.mplstyle
45

56
recursive-exclude * __pycache__
67
recursive-exclude * *.py[co]

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ project_urls =
3333
packages = find:
3434
install_requires =
3535
dask-image
36+
matplotlib>=3.3
3637
napari==0.4.18
3738
natsort
3839
numpy

src/napari_deeplabcut/_widgets.py

Lines changed: 249 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
from datetime import datetime
55
from functools import partial, cached_property
66
from math import ceil, log10
7+
import matplotlib.style as mplstyle
8+
import napari
79
import pandas as pd
810
from pathlib import Path
911
from types import MethodType
1012
from typing import Optional, Sequence, Union
1113

14+
from matplotlib.backends.backend_qtagg import FigureCanvas, NavigationToolbar2QT
15+
1216
import numpy as np
1317
from napari._qt.widgets.qt_welcome import QtWelcomeLabel
1418
from napari.layers import Image, Points, Shapes, Tracks
@@ -34,6 +38,7 @@
3438
QRadioButton,
3539
QScrollArea,
3640
QSizePolicy,
41+
QSlider,
3742
QStyle,
3843
QStyleOption,
3944
QVBoxLayout,
@@ -291,6 +296,221 @@ def on_close(self, event, widget):
291296
event.accept()
292297

293298

299+
# Class taken from https://github.com/matplotlib/napari-matplotlib/blob/53aa5ec95c1f3901e21dedce8347d3f95efe1f79/src/napari_matplotlib/base.py#L309
300+
class NapariNavigationToolbar(NavigationToolbar2QT):
301+
"""Custom Toolbar style for Napari."""
302+
303+
def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def]
304+
super().__init__(*args, **kwargs)
305+
self.setIconSize(QSize(28, 28))
306+
307+
def _update_buttons_checked(self) -> None:
308+
"""Update toggle tool icons when selected/unselected."""
309+
super()._update_buttons_checked()
310+
icon_dir = self.parentWidget()._get_path_to_icon()
311+
312+
# changes pan/zoom icons depending on state (checked or not)
313+
if "pan" in self._actions:
314+
if self._actions["pan"].isChecked():
315+
self._actions["pan"].setIcon(
316+
QIcon(os.path.join(icon_dir, "Pan_checked.png"))
317+
)
318+
else:
319+
self._actions["pan"].setIcon(
320+
QIcon(os.path.join(icon_dir, "Pan.png"))
321+
)
322+
if "zoom" in self._actions:
323+
if self._actions["zoom"].isChecked():
324+
self._actions["zoom"].setIcon(
325+
QIcon(os.path.join(icon_dir, "Zoom_checked.png"))
326+
)
327+
else:
328+
self._actions["zoom"].setIcon(
329+
QIcon(os.path.join(icon_dir, "Zoom.png"))
330+
)
331+
332+
333+
class KeypointMatplotlibCanvas(QWidget):
334+
"""
335+
Class about matplotlib canvas in which I will draw the keypoints over a range of frames
336+
It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series.
337+
"""
338+
339+
def __init__(self, napari_viewer, parent=None):
340+
super().__init__(parent=parent)
341+
342+
self.viewer = napari_viewer
343+
with mplstyle.context(self.mpl_style_sheet_path):
344+
self.canvas = FigureCanvas()
345+
self.canvas.figure.set_layout_engine("constrained")
346+
self.ax = self.canvas.figure.subplots()
347+
self.toolbar = NapariNavigationToolbar(self.canvas, parent=self)
348+
self._replace_toolbar_icons()
349+
self.canvas.mpl_connect("button_press_event", self.on_doubleclick)
350+
self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--")
351+
self.ax.set_xlabel("Frame")
352+
self.ax.set_ylabel("Y position")
353+
# Add a slot to specify the range of frames to plot
354+
self.slider = QSlider(Qt.Horizontal)
355+
self.slider.setMinimum(50)
356+
self.slider.setMaximum(10000)
357+
self.slider.setValue(50)
358+
self.slider.setTickPosition(QSlider.TicksBelow)
359+
self.slider.setTickInterval(50)
360+
self.slider_value = QLabel(str(self.slider.value()))
361+
self._window = self.slider.value()
362+
# Connect slider to window setter
363+
self.slider.valueChanged.connect(self.set_window)
364+
365+
layout = QVBoxLayout()
366+
layout.addWidget(self.canvas)
367+
layout.addWidget(self.toolbar)
368+
layout2 = QHBoxLayout()
369+
layout2.addWidget(self.slider)
370+
layout2.addWidget(self.slider_value)
371+
372+
layout.addLayout(layout2)
373+
self.setLayout(layout)
374+
375+
self.frames = []
376+
self.keypoints = []
377+
self.df = None
378+
# Make widget larger
379+
self.setMinimumHeight(300)
380+
# connect sliders to update plot
381+
self.viewer.dims.events.current_step.connect(self.update_plot_range)
382+
383+
# Run update plot range once to initialize the plot
384+
self._n = 0
385+
self.update_plot_range(
386+
Event(type_name="", value=[self.viewer.dims.current_step[0]])
387+
)
388+
389+
self.viewer.layers.events.inserted.connect(self._load_dataframe)
390+
self._lines = {}
391+
392+
def on_doubleclick(self, event):
393+
if event.dblclick:
394+
show = list(self._lines.values())[0][0].get_visible()
395+
for lines in self._lines.values():
396+
for l in lines:
397+
l.set_visible(not show)
398+
self._refresh_canvas(value=self._n)
399+
400+
def _napari_theme_has_light_bg(self) -> bool:
401+
"""
402+
Does this theme have a light background?
403+
404+
Returns
405+
-------
406+
bool
407+
True if theme's background colour has hsl lighter than 50%, False if darker.
408+
"""
409+
theme = napari.utils.theme.get_theme(self.viewer.theme, as_dict=False)
410+
_, _, bg_lightness = theme.background.as_hsl_tuple()
411+
return bg_lightness > 0.5
412+
413+
@property
414+
def mpl_style_sheet_path(self) -> Path:
415+
"""
416+
Path to the set Matplotlib style sheet.
417+
"""
418+
if self._napari_theme_has_light_bg():
419+
return Path(__file__).parent / "styles" / "light.mplstyle"
420+
else:
421+
return Path(__file__).parent / "styles" / "dark.mplstyle"
422+
423+
def _get_path_to_icon(self) -> Path:
424+
"""
425+
Get the icons directory (which is theme-dependent).
426+
427+
Icons modified from
428+
https://github.com/matplotlib/matplotlib/tree/main/lib/matplotlib/mpl-data/images
429+
"""
430+
icon_root = Path(__file__).parent / "assets"
431+
if self._napari_theme_has_light_bg():
432+
return icon_root / "black"
433+
else:
434+
return icon_root / "white"
435+
436+
def _replace_toolbar_icons(self) -> None:
437+
"""
438+
Modifies toolbar icons to match the napari theme, and add some tooltips.
439+
"""
440+
icon_dir = self._get_path_to_icon()
441+
for action in self.toolbar.actions():
442+
text = action.text()
443+
if text == "Pan":
444+
action.setToolTip(
445+
"Pan/Zoom: Left button pans; Right button zooms; "
446+
"Click once to activate; Click again to deactivate"
447+
)
448+
if text == "Zoom":
449+
action.setToolTip(
450+
"Zoom to rectangle; Click once to activate; "
451+
"Click again to deactivate"
452+
)
453+
if len(text) > 0: # i.e. not a separator item
454+
icon_path = os.path.join(icon_dir, text + ".png")
455+
action.setIcon(QIcon(icon_path))
456+
457+
def _load_dataframe(self):
458+
points_layer = None
459+
for layer in self.viewer.layers:
460+
if isinstance(layer, Points):
461+
points_layer = layer
462+
break
463+
464+
if points_layer is None:
465+
return
466+
467+
self.viewer.window.add_dock_widget(self, name="Trajectory plot", area="right")
468+
self.hide()
469+
470+
self.df = _form_df(
471+
points_layer.data,
472+
{
473+
"metadata": points_layer.metadata,
474+
"properties": points_layer.properties,
475+
},
476+
)
477+
for keypoint in self.df.columns.get_level_values("bodyparts").unique():
478+
y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"])
479+
x = np.arange(len(y))
480+
color = points_layer.metadata["face_color_cycles"]["label"][keypoint]
481+
lines = self.ax.plot(x, y, color=color, label=keypoint)
482+
self._lines[keypoint] = lines
483+
484+
self._refresh_canvas(value=self._n)
485+
486+
def _toggle_line_visibility(self, keypoint):
487+
for artist in self._lines[keypoint]:
488+
artist.set_visible(not artist.get_visible())
489+
self._refresh_canvas(value=self._n)
490+
491+
def _refresh_canvas(self, value):
492+
start = max(0, value - self._window // 2)
493+
end = min(value + self._window // 2, len(self.df))
494+
495+
self.ax.set_xlim(start, end)
496+
self.vline.set_xdata(value)
497+
self.canvas.draw()
498+
499+
def set_window(self, value):
500+
self._window = value
501+
self.slider_value.setText(str(value))
502+
self.update_plot_range(Event(type_name="", value=[self._n]))
503+
504+
def update_plot_range(self, event):
505+
value = event.value[0]
506+
self._n = value
507+
508+
if self.df is None:
509+
return
510+
511+
self._refresh_canvas(value)
512+
513+
294514
class KeypointControls(QWidget):
295515
def __init__(self, napari_viewer):
296516
super().__init__()
@@ -354,10 +574,19 @@ def __init__(self, napari_viewer):
354574
self._trail_cb.stateChanged.connect(self._show_trails)
355575
self._trails = None
356576

577+
matplotlib_label = QLabel("Show matplotlib canvas")
578+
self._matplotlib_canvas = KeypointMatplotlibCanvas(self.viewer)
579+
self._matplotlib_cb = QCheckBox()
580+
self._matplotlib_cb.setToolTip("toggle matplotlib canvas visibility")
581+
self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas)
582+
self._matplotlib_cb.setChecked(False)
583+
self._matplotlib_cb.setEnabled(False)
357584
self._view_scheme_cb = QCheckBox("Show color scheme", parent=self)
358585

359-
hlayout.addWidget(trail_label)
586+
hlayout.addWidget(self._matplotlib_cb)
587+
hlayout.addWidget(matplotlib_label)
360588
hlayout.addWidget(self._trail_cb)
589+
hlayout.addWidget(trail_label)
361590
hlayout.addWidget(self._view_scheme_cb)
362591

363592
self._layout.addLayout(hlayout)
@@ -368,6 +597,11 @@ def __init__(self, napari_viewer):
368597
self._color_scheme_display = self._form_color_scheme_display(self.viewer)
369598
self._view_scheme_cb.toggled.connect(self._show_color_scheme)
370599
self._view_scheme_cb.toggle()
600+
self._display.added.connect(
601+
lambda w: w.part_label.clicked.connect(
602+
self._matplotlib_canvas._toggle_line_visibility
603+
),
604+
)
371605

372606
# Substitute default menu action with custom one
373607
for action in self.viewer.window.file_menu.actions()[::-1]:
@@ -428,6 +662,12 @@ def _show_trails(self, state):
428662
elif self._trails is not None:
429663
self._trails.visible = False
430664

665+
def _show_matplotlib_canvas(self, state):
666+
if state == Qt.Checked:
667+
self._matplotlib_canvas.show()
668+
else:
669+
self._matplotlib_canvas.hide()
670+
431671
def _form_video_action_menu(self):
432672
group_box = QGroupBox("Video")
433673
layout = QVBoxLayout()
@@ -681,6 +921,7 @@ def on_insert(self, event):
681921
}
682922
)
683923
self._trail_cb.setEnabled(True)
924+
self._matplotlib_cb.setEnabled(True)
684925

685926
# Hide the color pickers, as colormaps are strictly defined by users
686927
controls = self.viewer.window.qt_viewer.dockLayerControls
@@ -710,6 +951,7 @@ def on_remove(self, event):
710951
menu.deleteLater()
711952
menu.destroy()
712953
self._trail_cb.setEnabled(False)
954+
self._matplotlib_cb.setEnabled(False)
713955
self.last_saved_label.hide()
714956
elif isinstance(layer, Image):
715957
self._images_meta = dict()
@@ -718,6 +960,7 @@ def on_remove(self, event):
718960
self.video_widget.setVisible(False)
719961
elif isinstance(layer, Tracks):
720962
self._trail_cb.setChecked(False)
963+
self._matplotlib_cb.setChecked(False)
721964
self._trails = None
722965

723966
@register_points_action("Change labeling mode")
@@ -1065,6 +1308,8 @@ def part_name(self, part_name: str):
10651308

10661309

10671310
class ColorSchemeDisplay(QScrollArea):
1311+
added = Signal(object)
1312+
10681313
def __init__(self, parent):
10691314
super().__init__(parent)
10701315

@@ -1108,9 +1353,9 @@ def _build(self):
11081353
def add_entry(self, name, color):
11091354
self.scheme_dict.update({name: color})
11101355

1111-
self._layout.addWidget(
1112-
LabelPair(color, name, self), alignment=Qt.AlignmentFlag.AlignLeft
1113-
)
1356+
widget = LabelPair(color, name, self)
1357+
self._layout.addWidget(widget, alignment=Qt.AlignmentFlag.AlignLeft)
1358+
self.added.emit(widget)
11141359

11151360
def reset(self):
11161361
self.scheme_dict = {}
6.76 KB
Loading
7.08 KB
Loading
6.65 KB
Loading
7.24 KB
Loading
7.14 KB
Loading
12.1 KB
Loading

0 commit comments

Comments
 (0)