From fe36a98b4de70b1194d9528fabfa9988424ff211 Mon Sep 17 00:00:00 2001 From: hausmanns Date: Mon, 21 Aug 2023 14:16:39 +0200 Subject: [PATCH 1/3] Added keypoint matplotlib viewer --- src/napari_deeplabcut/_widgets.py | 123 +++++++++++++++++++++++++++++- 1 file changed, 122 insertions(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 8111067..5311c49 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -9,6 +9,10 @@ from types import MethodType from typing import Optional, Sequence, Union +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from PyQt5.QtWidgets import QSlider + import numpy as np from napari._qt.widgets.qt_welcome import QtWelcomeLabel from napari.layers import Image, Points, Shapes, Tracks @@ -290,6 +294,99 @@ def on_close(self, event, widget): else: event.accept() +class KeypointMatplotlibCanvas(QWidget): + """ + Class about matplotlib canvas in which I will draw the keypoints over a range of frames + It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series. + """ + def __init__(self, napari_viewer): + super().__init__() + + self.viewer = napari_viewer + self.figure = Figure() + self.canvas = FigureCanvas(self.figure) + self.ax = self.figure.add_subplot(111) + self.vline = self.ax.axvline(0,0,1, color='k', linestyle='--') + self.ax.set_xlabel('Frame') + self.ax.set_ylabel('Y position') + # Add a slot to specify the range of frames to plot + self.slider = QSlider(Qt.Horizontal) + self.slider.setMinimum(50) + self.slider.setMaximum(10000) + self.slider.setValue(50) + self.slider.setTickPosition(QSlider.TicksBelow) + self.slider.setTickInterval(50) + self.slider_value = QLabel(str(self.slider.value())) + self._window = self.slider.value() + # Connect slider to window setter + self.slider.valueChanged.connect(self.set_window) + + layout = QVBoxLayout() + layout.addWidget(self.canvas) + layout2 = QHBoxLayout() + layout2.addWidget(self.slider) + layout2.addWidget(self.slider_value) + + layout.addLayout(layout2) + self.setLayout(layout) + + self.frames = [] + self.keypoints = [] + self.df = None + # Make widget larger + self.setMinimumHeight(300) + # connect sliders to update plot + self.viewer.dims.events.current_step.connect(self.update_plot_range) + + # Run update plot range once to initialize the plot + self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]])) + + def set_window(self, value): + self._window = value + self.slider_value.setText(str(value)) + self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]])) + + + def update_plot_range(self, event): + + value = event.value[0] + if self.df is None: + points_layer = None + for layer in self.viewer.layers: + if isinstance(layer, Points): + points_layer = layer + break + + if points_layer is None: + return + + self.df = _form_df( + points_layer.data, + { + "metadata": points_layer.metadata, + "properties": points_layer.properties, + }, + ) + + # Find the bodyparts names + bodyparts = self.df.columns.get_level_values('bodyparts').unique() + # Get only the body parts that contain the word limb in them + limb_bodyparts = [limb for limb in bodyparts if 'limb' in limb.lower()] + + for limb in limb_bodyparts: + y = self.df.xs((limb, 'y'), axis=1, level=['bodyparts', 'coords']) + x = np.arange(len(y)) + # color by limb colormap using point layer metadata + color = points_layer.metadata['face_color_cycles']['label'][limb] + self.ax.plot(x, y, color=color, label=limb) + + start = max(0, value-self._window//2) + end = min(value + self._window//2, len(self.df)) + + self.ax.set_xlim(start, end) + self.vline.set_xdata(value) + + self.canvas.draw_idle() class KeypointControls(QWidget): def __init__(self, napari_viewer): @@ -354,10 +451,19 @@ def __init__(self, napari_viewer): self._trail_cb.stateChanged.connect(self._show_trails) self._trails = None + matplotlib_label = QLabel("Show matplotlib canvas") + self._matplotlib_cb = QCheckBox() + self._matplotlib_cb.setToolTip("toggle matplotlib canvas visibility") + self._matplotlib_cb.setChecked(False) + self._matplotlib_cb.setEnabled(False) + self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas) + self._matplotlib_canvas = None self._view_scheme_cb = QCheckBox("Show color scheme", parent=self) - hlayout.addWidget(trail_label) + hlayout.addWidget(self._matplotlib_cb) + hlayout.addWidget(matplotlib_label) hlayout.addWidget(self._trail_cb) + hlayout.addWidget(trail_label) hlayout.addWidget(self._view_scheme_cb) self._layout.addLayout(hlayout) @@ -394,6 +500,10 @@ def __init__(self, napari_viewer): QTimer.singleShot(10, self.start_tutorial) self.settings.setValue("first_launch", False) + matplotlib_widget = KeypointMatplotlibCanvas(self.viewer) + matplotlib_widget.setVisible(False) + + @cached_property def settings(self): return QSettings() @@ -427,6 +537,14 @@ def _show_trails(self, state): self._trails.visible = True elif self._trails is not None: self._trails.visible = False + + def _show_matplotlib_canvas(self, state): + if state == Qt.Checked: + self._canvas = KeypointMatplotlibCanvas(self.viewer) + self.viewer.window.add_dock_widget(self._canvas, name="Trajectory plot", area="bottom") + self._canvas.show() + else: + self._canvas.close() def _form_video_action_menu(self): group_box = QGroupBox("Video") @@ -681,6 +799,7 @@ def on_insert(self, event): } ) self._trail_cb.setEnabled(True) + self._matplotlib_cb.setEnabled(True) # Hide the color pickers, as colormaps are strictly defined by users controls = self.viewer.window.qt_viewer.dockLayerControls @@ -710,6 +829,7 @@ def on_remove(self, event): menu.deleteLater() menu.destroy() self._trail_cb.setEnabled(False) + self._matplotlib_cb.setEnabled(False) self.last_saved_label.hide() elif isinstance(layer, Image): self._images_meta = dict() @@ -718,6 +838,7 @@ def on_remove(self, event): self.video_widget.setVisible(False) elif isinstance(layer, Tracks): self._trail_cb.setChecked(False) + self._matplotlib_cb.setChecked(False) self._trails = None @register_points_action("Change labeling mode") From f5e3bd3893b37b58ab17af99520563bdeaeeb650 Mon Sep 17 00:00:00 2001 From: hausmanns Date: Mon, 21 Aug 2023 14:24:00 +0200 Subject: [PATCH 2/3] Added skeleton display, still hardcoded --- src/napari_deeplabcut/_reader.py | 24 ++++++++++++++++++++++++ src/napari_deeplabcut/_widgets.py | 21 +++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index 5a9754b..4464c3e 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -227,7 +227,31 @@ def read_hdf(filename: str) -> List[LayerData]: metadata["metadata"]["root"] = os.path.split(filename)[0] # Store file name in case the layer's name is edited by the user metadata["metadata"]["name"] = metadata["name"] + + limb_pairs = [['nose','LeftForelimb'], + ['nose', 'RightForelimb'], + ['LeftForelimb', 'RightForelimb'], + ['LeftHindlimb', 'RightHindlimb'], + ['TailBase', 'LeftHindlimb'], + ['TailBase', 'RightHindlimb'], + ['TailBase', 'Tail1'], + ['Tail1', 'Tail2'], + ['Tail2', 'Tail3'] + ] + n = temp.shape[0] + vectors = np.zeros((n*len(limb_pairs), 2, 3)) + for i, (kpt1, kpt2) in enumerate(limb_pairs): + origin = temp.xs(kpt1, level='bodyparts', axis=1).to_numpy()[:,:2] + end = temp.xs(kpt2, level='bodyparts', axis=1).to_numpy()[:,:2] + vec = end-origin + vectors[i*n:(i+1)*n, 0, [2,1]] = origin + vectors[i*n:(i+1)*n, 1, [2,1]] = vec + vectors[i*n:(i+1)*n, :, 0] = np.arange(temp.shape[0])[:, None] + layers.append((data, metadata, "points")) + layers.append((vectors, {'edge_width':1, 'edge_color':'yellow'}, "vectors")) + + return layers diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 5311c49..69ae4f7 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -458,10 +458,22 @@ def __init__(self, napari_viewer): self._matplotlib_cb.setEnabled(False) self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas) self._matplotlib_canvas = None + + # Add checkbox to show skeleton + skeleton_label = QLabel("Show skeleton") + self._skeleton_cb = QCheckBox() + self._skeleton_cb.setToolTip("toggle skeleton visibility") + self._skeleton_cb.setChecked(False) + self._skeleton_cb.setEnabled(False) + self._skeleton_cb.stateChanged.connect(self._show_skeleton) + self._skeleton = None + self._view_scheme_cb = QCheckBox("Show color scheme", parent=self) hlayout.addWidget(self._matplotlib_cb) hlayout.addWidget(matplotlib_label) + hlayout.addWidget(self._skeleton_cb) + hlayout.addWidget(skeleton_label) hlayout.addWidget(self._trail_cb) hlayout.addWidget(trail_label) hlayout.addWidget(self._view_scheme_cb) @@ -546,6 +558,12 @@ def _show_matplotlib_canvas(self, state): else: self._canvas.close() + def _show_skeleton(self, state): + if state == Qt.Checked: + # Check if "skeleton" and "skeleton_color" are in the config.yaml metadata + return True + + def _form_video_action_menu(self): group_box = QGroupBox("Video") layout = QVBoxLayout() @@ -800,6 +818,7 @@ def on_insert(self, event): ) self._trail_cb.setEnabled(True) self._matplotlib_cb.setEnabled(True) + self._skeleton_cb.setEnabled(True) # Hide the color pickers, as colormaps are strictly defined by users controls = self.viewer.window.qt_viewer.dockLayerControls @@ -830,6 +849,7 @@ def on_remove(self, event): menu.destroy() self._trail_cb.setEnabled(False) self._matplotlib_cb.setEnabled(False) + self._skeleton_cb.setEnabled(False) self.last_saved_label.hide() elif isinstance(layer, Image): self._images_meta = dict() @@ -839,6 +859,7 @@ def on_remove(self, event): elif isinstance(layer, Tracks): self._trail_cb.setChecked(False) self._matplotlib_cb.setChecked(False) + self._skeleton_cb.setChecked(False) self._trails = None @register_points_action("Change labeling mode") From f1757febfc388bc8348279c6facc7ede2b189c35 Mon Sep 17 00:00:00 2001 From: hausmanns Date: Fri, 25 Aug 2023 18:36:07 +0200 Subject: [PATCH 3/3] Removed functions from matplotlib branch --- src/napari_deeplabcut/_reader.py | 3 + src/napari_deeplabcut/_widgets.py | 120 ------------------------------ 2 files changed, 3 insertions(+), 120 deletions(-) diff --git a/src/napari_deeplabcut/_reader.py b/src/napari_deeplabcut/_reader.py index 4464c3e..ca06849 100644 --- a/src/napari_deeplabcut/_reader.py +++ b/src/napari_deeplabcut/_reader.py @@ -249,6 +249,9 @@ def read_hdf(filename: str) -> List[LayerData]: vectors[i*n:(i+1)*n, :, 0] = np.arange(temp.shape[0])[:, None] layers.append((data, metadata, "points")) + # Make a dictionary of the colors of the bodyparts based on the colormap + # and the bodyparts + layers.append((vectors, {'edge_width':1, 'edge_color':'yellow'}, "vectors")) diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index 69ae4f7..608378c 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -9,10 +9,6 @@ from types import MethodType from typing import Optional, Sequence, Union -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.figure import Figure -from PyQt5.QtWidgets import QSlider - import numpy as np from napari._qt.widgets.qt_welcome import QtWelcomeLabel from napari.layers import Image, Points, Shapes, Tracks @@ -294,100 +290,6 @@ def on_close(self, event, widget): else: event.accept() -class KeypointMatplotlibCanvas(QWidget): - """ - Class about matplotlib canvas in which I will draw the keypoints over a range of frames - It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series. - """ - def __init__(self, napari_viewer): - super().__init__() - - self.viewer = napari_viewer - self.figure = Figure() - self.canvas = FigureCanvas(self.figure) - self.ax = self.figure.add_subplot(111) - self.vline = self.ax.axvline(0,0,1, color='k', linestyle='--') - self.ax.set_xlabel('Frame') - self.ax.set_ylabel('Y position') - # Add a slot to specify the range of frames to plot - self.slider = QSlider(Qt.Horizontal) - self.slider.setMinimum(50) - self.slider.setMaximum(10000) - self.slider.setValue(50) - self.slider.setTickPosition(QSlider.TicksBelow) - self.slider.setTickInterval(50) - self.slider_value = QLabel(str(self.slider.value())) - self._window = self.slider.value() - # Connect slider to window setter - self.slider.valueChanged.connect(self.set_window) - - layout = QVBoxLayout() - layout.addWidget(self.canvas) - layout2 = QHBoxLayout() - layout2.addWidget(self.slider) - layout2.addWidget(self.slider_value) - - layout.addLayout(layout2) - self.setLayout(layout) - - self.frames = [] - self.keypoints = [] - self.df = None - # Make widget larger - self.setMinimumHeight(300) - # connect sliders to update plot - self.viewer.dims.events.current_step.connect(self.update_plot_range) - - # Run update plot range once to initialize the plot - self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]])) - - def set_window(self, value): - self._window = value - self.slider_value.setText(str(value)) - self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]])) - - - def update_plot_range(self, event): - - value = event.value[0] - if self.df is None: - points_layer = None - for layer in self.viewer.layers: - if isinstance(layer, Points): - points_layer = layer - break - - if points_layer is None: - return - - self.df = _form_df( - points_layer.data, - { - "metadata": points_layer.metadata, - "properties": points_layer.properties, - }, - ) - - # Find the bodyparts names - bodyparts = self.df.columns.get_level_values('bodyparts').unique() - # Get only the body parts that contain the word limb in them - limb_bodyparts = [limb for limb in bodyparts if 'limb' in limb.lower()] - - for limb in limb_bodyparts: - y = self.df.xs((limb, 'y'), axis=1, level=['bodyparts', 'coords']) - x = np.arange(len(y)) - # color by limb colormap using point layer metadata - color = points_layer.metadata['face_color_cycles']['label'][limb] - self.ax.plot(x, y, color=color, label=limb) - - start = max(0, value-self._window//2) - end = min(value + self._window//2, len(self.df)) - - self.ax.set_xlim(start, end) - self.vline.set_xdata(value) - - self.canvas.draw_idle() - class KeypointControls(QWidget): def __init__(self, napari_viewer): super().__init__() @@ -451,14 +353,6 @@ def __init__(self, napari_viewer): self._trail_cb.stateChanged.connect(self._show_trails) self._trails = None - matplotlib_label = QLabel("Show matplotlib canvas") - self._matplotlib_cb = QCheckBox() - self._matplotlib_cb.setToolTip("toggle matplotlib canvas visibility") - self._matplotlib_cb.setChecked(False) - self._matplotlib_cb.setEnabled(False) - self._matplotlib_cb.stateChanged.connect(self._show_matplotlib_canvas) - self._matplotlib_canvas = None - # Add checkbox to show skeleton skeleton_label = QLabel("Show skeleton") self._skeleton_cb = QCheckBox() @@ -470,8 +364,6 @@ def __init__(self, napari_viewer): self._view_scheme_cb = QCheckBox("Show color scheme", parent=self) - hlayout.addWidget(self._matplotlib_cb) - hlayout.addWidget(matplotlib_label) hlayout.addWidget(self._skeleton_cb) hlayout.addWidget(skeleton_label) hlayout.addWidget(self._trail_cb) @@ -512,8 +404,6 @@ def __init__(self, napari_viewer): QTimer.singleShot(10, self.start_tutorial) self.settings.setValue("first_launch", False) - matplotlib_widget = KeypointMatplotlibCanvas(self.viewer) - matplotlib_widget.setVisible(False) @cached_property @@ -550,13 +440,6 @@ def _show_trails(self, state): elif self._trails is not None: self._trails.visible = False - def _show_matplotlib_canvas(self, state): - if state == Qt.Checked: - self._canvas = KeypointMatplotlibCanvas(self.viewer) - self.viewer.window.add_dock_widget(self._canvas, name="Trajectory plot", area="bottom") - self._canvas.show() - else: - self._canvas.close() def _show_skeleton(self, state): if state == Qt.Checked: @@ -817,7 +700,6 @@ def on_insert(self, event): } ) self._trail_cb.setEnabled(True) - self._matplotlib_cb.setEnabled(True) self._skeleton_cb.setEnabled(True) # Hide the color pickers, as colormaps are strictly defined by users @@ -848,7 +730,6 @@ def on_remove(self, event): menu.deleteLater() menu.destroy() self._trail_cb.setEnabled(False) - self._matplotlib_cb.setEnabled(False) self._skeleton_cb.setEnabled(False) self.last_saved_label.hide() elif isinstance(layer, Image): @@ -858,7 +739,6 @@ def on_remove(self, event): self.video_widget.setVisible(False) elif isinstance(layer, Tracks): self._trail_cb.setChecked(False) - self._matplotlib_cb.setChecked(False) self._skeleton_cb.setChecked(False) self._trails = None