Skip to content

Commit fe36a98

Browse files
committed
Added keypoint matplotlib viewer
1 parent 34da7a7 commit fe36a98

File tree

1 file changed

+122
-1
lines changed

1 file changed

+122
-1
lines changed

src/napari_deeplabcut/_widgets.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from types import MethodType
1010
from typing import Optional, Sequence, Union
1111

12+
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
13+
from matplotlib.figure import Figure
14+
from PyQt5.QtWidgets import QSlider
15+
1216
import numpy as np
1317
from napari._qt.widgets.qt_welcome import QtWelcomeLabel
1418
from napari.layers import Image, Points, Shapes, Tracks
@@ -290,6 +294,99 @@ def on_close(self, event, widget):
290294
else:
291295
event.accept()
292296

297+
class KeypointMatplotlibCanvas(QWidget):
298+
"""
299+
Class about matplotlib canvas in which I will draw the keypoints over a range of frames
300+
It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series.
301+
"""
302+
def __init__(self, napari_viewer):
303+
super().__init__()
304+
305+
self.viewer = napari_viewer
306+
self.figure = Figure()
307+
self.canvas = FigureCanvas(self.figure)
308+
self.ax = self.figure.add_subplot(111)
309+
self.vline = self.ax.axvline(0,0,1, color='k', linestyle='--')
310+
self.ax.set_xlabel('Frame')
311+
self.ax.set_ylabel('Y position')
312+
# Add a slot to specify the range of frames to plot
313+
self.slider = QSlider(Qt.Horizontal)
314+
self.slider.setMinimum(50)
315+
self.slider.setMaximum(10000)
316+
self.slider.setValue(50)
317+
self.slider.setTickPosition(QSlider.TicksBelow)
318+
self.slider.setTickInterval(50)
319+
self.slider_value = QLabel(str(self.slider.value()))
320+
self._window = self.slider.value()
321+
# Connect slider to window setter
322+
self.slider.valueChanged.connect(self.set_window)
323+
324+
layout = QVBoxLayout()
325+
layout.addWidget(self.canvas)
326+
layout2 = QHBoxLayout()
327+
layout2.addWidget(self.slider)
328+
layout2.addWidget(self.slider_value)
329+
330+
layout.addLayout(layout2)
331+
self.setLayout(layout)
332+
333+
self.frames = []
334+
self.keypoints = []
335+
self.df = None
336+
# Make widget larger
337+
self.setMinimumHeight(300)
338+
# connect sliders to update plot
339+
self.viewer.dims.events.current_step.connect(self.update_plot_range)
340+
341+
# Run update plot range once to initialize the plot
342+
self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]]))
343+
344+
def set_window(self, value):
345+
self._window = value
346+
self.slider_value.setText(str(value))
347+
self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]]))
348+
349+
350+
def update_plot_range(self, event):
351+
352+
value = event.value[0]
353+
if self.df is None:
354+
points_layer = None
355+
for layer in self.viewer.layers:
356+
if isinstance(layer, Points):
357+
points_layer = layer
358+
break
359+
360+
if points_layer is None:
361+
return
362+
363+
self.df = _form_df(
364+
points_layer.data,
365+
{
366+
"metadata": points_layer.metadata,
367+
"properties": points_layer.properties,
368+
},
369+
)
370+
371+
# Find the bodyparts names
372+
bodyparts = self.df.columns.get_level_values('bodyparts').unique()
373+
# Get only the body parts that contain the word limb in them
374+
limb_bodyparts = [limb for limb in bodyparts if 'limb' in limb.lower()]
375+
376+
for limb in limb_bodyparts:
377+
y = self.df.xs((limb, 'y'), axis=1, level=['bodyparts', 'coords'])
378+
x = np.arange(len(y))
379+
# color by limb colormap using point layer metadata
380+
color = points_layer.metadata['face_color_cycles']['label'][limb]
381+
self.ax.plot(x, y, color=color, label=limb)
382+
383+
start = max(0, value-self._window//2)
384+
end = min(value + self._window//2, len(self.df))
385+
386+
self.ax.set_xlim(start, end)
387+
self.vline.set_xdata(value)
388+
389+
self.canvas.draw_idle()
293390

294391
class KeypointControls(QWidget):
295392
def __init__(self, napari_viewer):
@@ -354,10 +451,19 @@ def __init__(self, napari_viewer):
354451
self._trail_cb.stateChanged.connect(self._show_trails)
355452
self._trails = None
356453

454+
matplotlib_label = QLabel("Show matplotlib canvas")
455+
self._matplotlib_cb = QCheckBox()
456+
self._matplotlib_cb.setToolTip("toggle matplotlib canvas visibility")
457+
self._matplotlib_cb.setChecked(False)
458+
self._matplotlib_cb.setEnabled(False)
459+
self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas)
460+
self._matplotlib_canvas = None
357461
self._view_scheme_cb = QCheckBox("Show color scheme", parent=self)
358462

359-
hlayout.addWidget(trail_label)
463+
hlayout.addWidget(self._matplotlib_cb)
464+
hlayout.addWidget(matplotlib_label)
360465
hlayout.addWidget(self._trail_cb)
466+
hlayout.addWidget(trail_label)
361467
hlayout.addWidget(self._view_scheme_cb)
362468

363469
self._layout.addLayout(hlayout)
@@ -394,6 +500,10 @@ def __init__(self, napari_viewer):
394500
QTimer.singleShot(10, self.start_tutorial)
395501
self.settings.setValue("first_launch", False)
396502

503+
matplotlib_widget = KeypointMatplotlibCanvas(self.viewer)
504+
matplotlib_widget.setVisible(False)
505+
506+
397507
@cached_property
398508
def settings(self):
399509
return QSettings()
@@ -427,6 +537,14 @@ def _show_trails(self, state):
427537
self._trails.visible = True
428538
elif self._trails is not None:
429539
self._trails.visible = False
540+
541+
def _show_matplotlib_canvas(self, state):
542+
if state == Qt.Checked:
543+
self._canvas = KeypointMatplotlibCanvas(self.viewer)
544+
self.viewer.window.add_dock_widget(self._canvas, name="Trajectory plot", area="bottom")
545+
self._canvas.show()
546+
else:
547+
self._canvas.close()
430548

431549
def _form_video_action_menu(self):
432550
group_box = QGroupBox("Video")
@@ -681,6 +799,7 @@ def on_insert(self, event):
681799
}
682800
)
683801
self._trail_cb.setEnabled(True)
802+
self._matplotlib_cb.setEnabled(True)
684803

685804
# Hide the color pickers, as colormaps are strictly defined by users
686805
controls = self.viewer.window.qt_viewer.dockLayerControls
@@ -710,6 +829,7 @@ def on_remove(self, event):
710829
menu.deleteLater()
711830
menu.destroy()
712831
self._trail_cb.setEnabled(False)
832+
self._matplotlib_cb.setEnabled(False)
713833
self.last_saved_label.hide()
714834
elif isinstance(layer, Image):
715835
self._images_meta = dict()
@@ -718,6 +838,7 @@ def on_remove(self, event):
718838
self.video_widget.setVisible(False)
719839
elif isinstance(layer, Tracks):
720840
self._trail_cb.setChecked(False)
841+
self._matplotlib_cb.setChecked(False)
721842
self._trails = None
722843

723844
@register_points_action("Change labeling mode")

0 commit comments

Comments
 (0)