9
9
from types import MethodType
10
10
from typing import Optional , Sequence , Union
11
11
12
+ from matplotlib .backends .backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
13
+ from matplotlib .figure import Figure
14
+ from PyQt5 .QtWidgets import QSlider
15
+
12
16
import numpy as np
13
17
from napari ._qt .widgets .qt_welcome import QtWelcomeLabel
14
18
from napari .layers import Image , Points , Shapes , Tracks
@@ -290,6 +294,99 @@ def on_close(self, event, widget):
290
294
else :
291
295
event .accept ()
292
296
297
+ class KeypointMatplotlibCanvas (QWidget ):
298
+ """
299
+ Class about matplotlib canvas in which I will draw the keypoints over a range of frames
300
+ It will be at the bottom of the screen and will use the keypoints from the range of frames to plot them on a x-y time series.
301
+ """
302
+ def __init__ (self , napari_viewer ):
303
+ super ().__init__ ()
304
+
305
+ self .viewer = napari_viewer
306
+ self .figure = Figure ()
307
+ self .canvas = FigureCanvas (self .figure )
308
+ self .ax = self .figure .add_subplot (111 )
309
+ self .vline = self .ax .axvline (0 ,0 ,1 , color = 'k' , linestyle = '--' )
310
+ self .ax .set_xlabel ('Frame' )
311
+ self .ax .set_ylabel ('Y position' )
312
+ # Add a slot to specify the range of frames to plot
313
+ self .slider = QSlider (Qt .Horizontal )
314
+ self .slider .setMinimum (50 )
315
+ self .slider .setMaximum (10000 )
316
+ self .slider .setValue (50 )
317
+ self .slider .setTickPosition (QSlider .TicksBelow )
318
+ self .slider .setTickInterval (50 )
319
+ self .slider_value = QLabel (str (self .slider .value ()))
320
+ self ._window = self .slider .value ()
321
+ # Connect slider to window setter
322
+ self .slider .valueChanged .connect (self .set_window )
323
+
324
+ layout = QVBoxLayout ()
325
+ layout .addWidget (self .canvas )
326
+ layout2 = QHBoxLayout ()
327
+ layout2 .addWidget (self .slider )
328
+ layout2 .addWidget (self .slider_value )
329
+
330
+ layout .addLayout (layout2 )
331
+ self .setLayout (layout )
332
+
333
+ self .frames = []
334
+ self .keypoints = []
335
+ self .df = None
336
+ # Make widget larger
337
+ self .setMinimumHeight (300 )
338
+ # connect sliders to update plot
339
+ self .viewer .dims .events .current_step .connect (self .update_plot_range )
340
+
341
+ # Run update plot range once to initialize the plot
342
+ self .update_plot_range (Event (type_name = '' ,value = [self .viewer .dims .current_step [0 ]]))
343
+
344
+ def set_window (self , value ):
345
+ self ._window = value
346
+ self .slider_value .setText (str (value ))
347
+ self .update_plot_range (Event (type_name = '' ,value = [self .viewer .dims .current_step [0 ]]))
348
+
349
+
350
+ def update_plot_range (self , event ):
351
+
352
+ value = event .value [0 ]
353
+ if self .df is None :
354
+ points_layer = None
355
+ for layer in self .viewer .layers :
356
+ if isinstance (layer , Points ):
357
+ points_layer = layer
358
+ break
359
+
360
+ if points_layer is None :
361
+ return
362
+
363
+ self .df = _form_df (
364
+ points_layer .data ,
365
+ {
366
+ "metadata" : points_layer .metadata ,
367
+ "properties" : points_layer .properties ,
368
+ },
369
+ )
370
+
371
+ # Find the bodyparts names
372
+ bodyparts = self .df .columns .get_level_values ('bodyparts' ).unique ()
373
+ # Get only the body parts that contain the word limb in them
374
+ limb_bodyparts = [limb for limb in bodyparts if 'limb' in limb .lower ()]
375
+
376
+ for limb in limb_bodyparts :
377
+ y = self .df .xs ((limb , 'y' ), axis = 1 , level = ['bodyparts' , 'coords' ])
378
+ x = np .arange (len (y ))
379
+ # color by limb colormap using point layer metadata
380
+ color = points_layer .metadata ['face_color_cycles' ]['label' ][limb ]
381
+ self .ax .plot (x , y , color = color , label = limb )
382
+
383
+ start = max (0 , value - self ._window // 2 )
384
+ end = min (value + self ._window // 2 , len (self .df ))
385
+
386
+ self .ax .set_xlim (start , end )
387
+ self .vline .set_xdata (value )
388
+
389
+ self .canvas .draw_idle ()
293
390
294
391
class KeypointControls (QWidget ):
295
392
def __init__ (self , napari_viewer ):
@@ -354,10 +451,19 @@ def __init__(self, napari_viewer):
354
451
self ._trail_cb .stateChanged .connect (self ._show_trails )
355
452
self ._trails = None
356
453
454
+ matplotlib_label = QLabel ("Show matplotlib canvas" )
455
+ self ._matplotlib_cb = QCheckBox ()
456
+ self ._matplotlib_cb .setToolTip ("toggle matplotlib canvas visibility" )
457
+ self ._matplotlib_cb .setChecked (False )
458
+ self ._matplotlib_cb .setEnabled (False )
459
+ self ._matplotlib_cb .stateChanged .connect (self ._show_matplotlib_canvas )
460
+ self ._matplotlib_canvas = None
357
461
self ._view_scheme_cb = QCheckBox ("Show color scheme" , parent = self )
358
462
359
- hlayout .addWidget (trail_label )
463
+ hlayout .addWidget (self ._matplotlib_cb )
464
+ hlayout .addWidget (matplotlib_label )
360
465
hlayout .addWidget (self ._trail_cb )
466
+ hlayout .addWidget (trail_label )
361
467
hlayout .addWidget (self ._view_scheme_cb )
362
468
363
469
self ._layout .addLayout (hlayout )
@@ -394,6 +500,10 @@ def __init__(self, napari_viewer):
394
500
QTimer .singleShot (10 , self .start_tutorial )
395
501
self .settings .setValue ("first_launch" , False )
396
502
503
+ matplotlib_widget = KeypointMatplotlibCanvas (self .viewer )
504
+ matplotlib_widget .setVisible (False )
505
+
506
+
397
507
@cached_property
398
508
def settings (self ):
399
509
return QSettings ()
@@ -427,6 +537,14 @@ def _show_trails(self, state):
427
537
self ._trails .visible = True
428
538
elif self ._trails is not None :
429
539
self ._trails .visible = False
540
+
541
+ def _show_matplotlib_canvas (self , state ):
542
+ if state == Qt .Checked :
543
+ self ._canvas = KeypointMatplotlibCanvas (self .viewer )
544
+ self .viewer .window .add_dock_widget (self ._canvas , name = "Trajectory plot" , area = "bottom" )
545
+ self ._canvas .show ()
546
+ else :
547
+ self ._canvas .close ()
430
548
431
549
def _form_video_action_menu (self ):
432
550
group_box = QGroupBox ("Video" )
@@ -681,6 +799,7 @@ def on_insert(self, event):
681
799
}
682
800
)
683
801
self ._trail_cb .setEnabled (True )
802
+ self ._matplotlib_cb .setEnabled (True )
684
803
685
804
# Hide the color pickers, as colormaps are strictly defined by users
686
805
controls = self .viewer .window .qt_viewer .dockLayerControls
@@ -710,6 +829,7 @@ def on_remove(self, event):
710
829
menu .deleteLater ()
711
830
menu .destroy ()
712
831
self ._trail_cb .setEnabled (False )
832
+ self ._matplotlib_cb .setEnabled (False )
713
833
self .last_saved_label .hide ()
714
834
elif isinstance (layer , Image ):
715
835
self ._images_meta = dict ()
@@ -718,6 +838,7 @@ def on_remove(self, event):
718
838
self .video_widget .setVisible (False )
719
839
elif isinstance (layer , Tracks ):
720
840
self ._trail_cb .setChecked (False )
841
+ self ._matplotlib_cb .setChecked (False )
721
842
self ._trails = None
722
843
723
844
@register_points_action ("Change labeling mode" )
0 commit comments