Skip to content

Commit c814d00

Browse files
committed
Minor fixes
1 parent 0fb707f commit c814d00

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ project_urls =
3333
packages = find:
3434
install_requires =
3535
dask-image
36+
matplotlib
3637
napari==0.4.18
3738
natsort
3839
numpy

src/napari_deeplabcut/_widgets.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from types import MethodType
1010
from typing import Optional, Sequence, Union
1111

12-
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
12+
from matplotlib.backends.backend_qtagg import FigureCanvas
1313
from matplotlib.figure import Figure
14-
from PyQt5.QtWidgets import QSlider
1514

1615
import numpy as np
1716
from napari._qt.widgets.qt_welcome import QtWelcomeLabel
@@ -38,6 +37,7 @@
3837
QRadioButton,
3938
QScrollArea,
4039
QSizePolicy,
40+
QSlider,
4141
QStyle,
4242
QStyleOption,
4343
QVBoxLayout,
@@ -294,21 +294,23 @@ def on_close(self, event, widget):
294294
else:
295295
event.accept()
296296

297+
297298
class KeypointMatplotlibCanvas(QWidget):
298299
"""
299300
Class about matplotlib canvas in which I will draw the keypoints over a range of frames
300301
It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series.
301302
"""
302-
def __init__(self, napari_viewer):
303-
super().__init__()
303+
304+
def __init__(self, napari_viewer, parent=None):
305+
super().__init__(parent=parent)
304306

305307
self.viewer = napari_viewer
306308
self.figure = Figure()
307309
self.canvas = FigureCanvas(self.figure)
308310
self.ax = self.figure.add_subplot(111)
309-
self.vline = self.ax.axvline(0,0,1, color='k', linestyle='--')
310-
self.ax.set_xlabel('Frame')
311-
self.ax.set_ylabel('Y position')
311+
self.vline = self.ax.axvline(0, 0, 1, color="k", linestyle="--")
312+
self.ax.set_xlabel("Frame")
313+
self.ax.set_ylabel("Y position")
312314
# Add a slot to specify the range of frames to plot
313315
self.slider = QSlider(Qt.Horizontal)
314316
self.slider.setMinimum(50)
@@ -339,24 +341,26 @@ def __init__(self, napari_viewer):
339341
self.viewer.dims.events.current_step.connect(self.update_plot_range)
340342

341343
# Run update plot range once to initialize the plot
342-
self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]]))
343-
344+
self.update_plot_range(
345+
Event(type_name="", value=[self.viewer.dims.current_step[0]])
346+
)
347+
344348
def set_window(self, value):
345349
self._window = value
346350
self.slider_value.setText(str(value))
347-
self.update_plot_range(Event(type_name='',value=[self.viewer.dims.current_step[0]]))
348-
351+
self.update_plot_range(
352+
Event(type_name="", value=[self.viewer.dims.current_step[0]])
353+
)
349354

350355
def update_plot_range(self, event):
351-
352356
value = event.value[0]
353357
if self.df is None:
354358
points_layer = None
355359
for layer in self.viewer.layers:
356360
if isinstance(layer, Points):
357361
points_layer = layer
358362
break
359-
363+
360364
if points_layer is None:
361365
return
362366

@@ -369,25 +373,26 @@ def update_plot_range(self, event):
369373
)
370374

371375
# Find the bodyparts names
372-
bodyparts = self.df.columns.get_level_values('bodyparts').unique()
376+
bodyparts = self.df.columns.get_level_values("bodyparts").unique()
373377
# Get only the body parts that contain the word limb in them
374-
limb_bodyparts = [limb for limb in bodyparts if 'limb' in limb.lower()]
378+
limb_bodyparts = [limb for limb in bodyparts if "limb" in limb.lower()]
375379

376380
for limb in limb_bodyparts:
377-
y = self.df.xs((limb, 'y'), axis=1, level=['bodyparts', 'coords'])
381+
y = self.df.xs((limb, "y"), axis=1, level=["bodyparts", "coords"])
378382
x = np.arange(len(y))
379383
# color by limb colormap using point layer metadata
380-
color = points_layer.metadata['face_color_cycles']['label'][limb]
384+
color = points_layer.metadata["face_color_cycles"]["label"][limb]
381385
self.ax.plot(x, y, color=color, label=limb)
382386

383-
start = max(0, value-self._window//2)
384-
end = min(value + self._window//2, len(self.df))
385-
387+
start = max(0, value - self._window // 2)
388+
end = min(value + self._window // 2, len(self.df))
389+
386390
self.ax.set_xlim(start, end)
387391
self.vline.set_xdata(value)
388392

389393
self.canvas.draw_idle()
390394

395+
391396
class KeypointControls(QWidget):
392397
def __init__(self, napari_viewer):
393398
super().__init__()
@@ -503,7 +508,6 @@ def __init__(self, napari_viewer):
503508
matplotlib_widget = KeypointMatplotlibCanvas(self.viewer)
504509
matplotlib_widget.setVisible(False)
505510

506-
507511
@cached_property
508512
def settings(self):
509513
return QSettings()
@@ -537,11 +541,13 @@ def _show_trails(self, state):
537541
self._trails.visible = True
538542
elif self._trails is not None:
539543
self._trails.visible = False
540-
544+
541545
def _show_matplotlib_canvas(self, state):
542546
if state == Qt.Checked:
543547
self._canvas = KeypointMatplotlibCanvas(self.viewer)
544-
self.viewer.window.add_dock_widget(self._canvas, name="Trajectory plot", area="bottom")
548+
self.viewer.window.add_dock_widget(
549+
self._canvas, name="Trajectory plot", area="bottom"
550+
)
545551
self._canvas.show()
546552
else:
547553
self._canvas.close()

0 commit comments

Comments
 (0)