9
9
from types import MethodType
10
10
from typing import Optional , Sequence , Union
11
11
12
- from matplotlib .backends .backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
12
+ from matplotlib .backends .backend_qtagg import FigureCanvas
13
13
from matplotlib .figure import Figure
14
- from PyQt5 .QtWidgets import QSlider
15
14
16
15
import numpy as np
17
16
from napari ._qt .widgets .qt_welcome import QtWelcomeLabel
38
37
QRadioButton ,
39
38
QScrollArea ,
40
39
QSizePolicy ,
40
+ QSlider ,
41
41
QStyle ,
42
42
QStyleOption ,
43
43
QVBoxLayout ,
@@ -294,21 +294,23 @@ def on_close(self, event, widget):
294
294
else :
295
295
event .accept ()
296
296
297
+
297
298
class KeypointMatplotlibCanvas (QWidget ):
298
299
"""
299
300
Class about matplotlib canvas in which I will draw the keypoints over a range of frames
300
301
It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series.
301
302
"""
302
- def __init__ (self , napari_viewer ):
303
- super ().__init__ ()
303
+
304
+ def __init__ (self , napari_viewer , parent = None ):
305
+ super ().__init__ (parent = parent )
304
306
305
307
self .viewer = napari_viewer
306
308
self .figure = Figure ()
307
309
self .canvas = FigureCanvas (self .figure )
308
310
self .ax = self .figure .add_subplot (111 )
309
- self .vline = self .ax .axvline (0 ,0 , 1 , color = 'k' , linestyle = '--' )
310
- self .ax .set_xlabel (' Frame' )
311
- self .ax .set_ylabel (' Y position' )
311
+ self .vline = self .ax .axvline (0 , 0 , 1 , color = "k" , linestyle = "--" )
312
+ self .ax .set_xlabel (" Frame" )
313
+ self .ax .set_ylabel (" Y position" )
312
314
# Add a slot to specify the range of frames to plot
313
315
self .slider = QSlider (Qt .Horizontal )
314
316
self .slider .setMinimum (50 )
@@ -339,24 +341,26 @@ def __init__(self, napari_viewer):
339
341
self .viewer .dims .events .current_step .connect (self .update_plot_range )
340
342
341
343
# Run update plot range once to initialize the plot
342
- self .update_plot_range (Event (type_name = '' ,value = [self .viewer .dims .current_step [0 ]]))
343
-
344
+ self .update_plot_range (
345
+ Event (type_name = "" , value = [self .viewer .dims .current_step [0 ]])
346
+ )
347
+
344
348
def set_window (self , value ):
345
349
self ._window = value
346
350
self .slider_value .setText (str (value ))
347
- self .update_plot_range (Event (type_name = '' ,value = [self .viewer .dims .current_step [0 ]]))
348
-
351
+ self .update_plot_range (
352
+ Event (type_name = "" , value = [self .viewer .dims .current_step [0 ]])
353
+ )
349
354
350
355
def update_plot_range (self , event ):
351
-
352
356
value = event .value [0 ]
353
357
if self .df is None :
354
358
points_layer = None
355
359
for layer in self .viewer .layers :
356
360
if isinstance (layer , Points ):
357
361
points_layer = layer
358
362
break
359
-
363
+
360
364
if points_layer is None :
361
365
return
362
366
@@ -369,25 +373,26 @@ def update_plot_range(self, event):
369
373
)
370
374
371
375
# Find the bodyparts names
372
- bodyparts = self .df .columns .get_level_values (' bodyparts' ).unique ()
376
+ bodyparts = self .df .columns .get_level_values (" bodyparts" ).unique ()
373
377
# Get only the body parts that contain the word limb in them
374
- limb_bodyparts = [limb for limb in bodyparts if ' limb' in limb .lower ()]
378
+ limb_bodyparts = [limb for limb in bodyparts if " limb" in limb .lower ()]
375
379
376
380
for limb in limb_bodyparts :
377
- y = self .df .xs ((limb , 'y' ), axis = 1 , level = [' bodyparts' , ' coords' ])
381
+ y = self .df .xs ((limb , "y" ), axis = 1 , level = [" bodyparts" , " coords" ])
378
382
x = np .arange (len (y ))
379
383
# color by limb colormap using point layer metadata
380
- color = points_layer .metadata [' face_color_cycles' ][ ' label' ][limb ]
384
+ color = points_layer .metadata [" face_color_cycles" ][ " label" ][limb ]
381
385
self .ax .plot (x , y , color = color , label = limb )
382
386
383
- start = max (0 , value - self ._window // 2 )
384
- end = min (value + self ._window // 2 , len (self .df ))
385
-
387
+ start = max (0 , value - self ._window // 2 )
388
+ end = min (value + self ._window // 2 , len (self .df ))
389
+
386
390
self .ax .set_xlim (start , end )
387
391
self .vline .set_xdata (value )
388
392
389
393
self .canvas .draw_idle ()
390
394
395
+
391
396
class KeypointControls (QWidget ):
392
397
def __init__ (self , napari_viewer ):
393
398
super ().__init__ ()
@@ -503,7 +508,6 @@ def __init__(self, napari_viewer):
503
508
matplotlib_widget = KeypointMatplotlibCanvas (self .viewer )
504
509
matplotlib_widget .setVisible (False )
505
510
506
-
507
511
@cached_property
508
512
def settings (self ):
509
513
return QSettings ()
@@ -537,11 +541,13 @@ def _show_trails(self, state):
537
541
self ._trails .visible = True
538
542
elif self ._trails is not None :
539
543
self ._trails .visible = False
540
-
544
+
541
545
def _show_matplotlib_canvas (self , state ):
542
546
if state == Qt .Checked :
543
547
self ._canvas = KeypointMatplotlibCanvas (self .viewer )
544
- self .viewer .window .add_dock_widget (self ._canvas , name = "Trajectory plot" , area = "bottom" )
548
+ self .viewer .window .add_dock_widget (
549
+ self ._canvas , name = "Trajectory plot" , area = "bottom"
550
+ )
545
551
self ._canvas .show ()
546
552
else :
547
553
self ._canvas .close ()
0 commit comments