4
4
from datetime import datetime
5
5
from functools import partial , cached_property
6
6
from math import ceil , log10
7
+ import matplotlib .style as mplstyle
8
+ import napari
7
9
import pandas as pd
8
10
from pathlib import Path
9
11
from types import MethodType
10
12
from typing import Optional , Sequence , Union
11
13
12
14
from matplotlib .backends .backend_qtagg import FigureCanvas
13
- from matplotlib .figure import Figure
14
15
15
16
import numpy as np
16
17
from napari ._qt .widgets .qt_welcome import QtWelcomeLabel
@@ -305,9 +306,10 @@ def __init__(self, napari_viewer, parent=None):
305
306
super ().__init__ (parent = parent )
306
307
307
308
self .viewer = napari_viewer
308
- self .figure = Figure ()
309
- self .canvas = FigureCanvas (self .figure )
310
- self .ax = self .figure .add_subplot (111 )
309
+ with mplstyle .context (self .mpl_style_sheet_path ):
310
+ self .canvas = FigureCanvas ()
311
+ self .canvas .figure .set_layout_engine ("constrained" )
312
+ self .ax = self .canvas .figure .subplots ()
311
313
self .vline = self .ax .axvline (0 , 0 , 1 , color = "k" , linestyle = "--" )
312
314
self .ax .set_xlabel ("Frame" )
313
315
self .ax .set_ylabel ("Y position" )
@@ -341,56 +343,92 @@ def __init__(self, napari_viewer, parent=None):
341
343
self .viewer .dims .events .current_step .connect (self .update_plot_range )
342
344
343
345
# Run update plot range once to initialize the plot
346
+ self ._n = 0
344
347
self .update_plot_range (
345
348
Event (type_name = "" , value = [self .viewer .dims .current_step [0 ]])
346
349
)
347
350
348
- def set_window (self , value ):
349
- self ._window = value
350
- self .slider_value .setText (str (value ))
351
- self .update_plot_range (
352
- Event (type_name = "" , value = [self .viewer .dims .current_step [0 ]])
353
- )
351
+ self .viewer .layers .events .inserted .connect (self ._load_dataframe )
352
+ self ._lines = {}
354
353
355
- def update_plot_range (self , event ):
356
- value = event .value [0 ]
357
- if self .df is None :
358
- points_layer = None
359
- for layer in self .viewer .layers :
360
- if isinstance (layer , Points ):
361
- points_layer = layer
362
- break
354
+ def _napari_theme_has_light_bg (self ) -> bool :
355
+ """
356
+ Does this theme have a light background?
363
357
364
- if points_layer is None :
365
- return
358
+ Returns
359
+ -------
360
+ bool
361
+ True if theme's background colour has hsl lighter than 50%, False if darker.
362
+ """
363
+ theme = napari .utils .theme .get_theme (self .viewer .theme , as_dict = False )
364
+ _ , _ , bg_lightness = theme .background .as_hsl_tuple ()
365
+ return bg_lightness > 0.5
366
366
367
- self .df = _form_df (
368
- points_layer .data ,
369
- {
370
- "metadata" : points_layer .metadata ,
371
- "properties" : points_layer .properties ,
372
- },
373
- )
367
+ @property
368
+ def mpl_style_sheet_path (self ) -> Path :
369
+ """
370
+ Path to the set Matplotlib style sheet.
371
+ """
372
+ if self ._napari_theme_has_light_bg ():
373
+ return Path (__file__ ).parent / "styles" / "light.mplstyle"
374
+ else :
375
+ return Path (__file__ ).parent / "styles" / "dark.mplstyle"
376
+
377
+ def _load_dataframe (self ):
378
+ points_layer = None
379
+ for layer in self .viewer .layers :
380
+ if isinstance (layer , Points ):
381
+ points_layer = layer
382
+ break
374
383
375
- # Find the bodyparts names
376
- bodyparts = self .df .columns .get_level_values ("bodyparts" ).unique ()
377
- # Get only the body parts that contain the word limb in them
378
- limb_bodyparts = [limb for limb in bodyparts if "limb" in limb .lower ()]
384
+ if points_layer is None :
385
+ return
386
+
387
+ self .viewer .window .add_dock_widget (self , name = "Trajectory plot" , area = "right" )
388
+ self .hide ()
389
+
390
+ self .df = _form_df (
391
+ points_layer .data ,
392
+ {
393
+ "metadata" : points_layer .metadata ,
394
+ "properties" : points_layer .properties ,
395
+ },
396
+ )
397
+ for keypoint in self .df .columns .get_level_values ("bodyparts" ).unique ():
398
+ y = self .df .xs ((keypoint , "y" ), axis = 1 , level = ["bodyparts" , "coords" ])
399
+ x = np .arange (len (y ))
400
+ color = points_layer .metadata ["face_color_cycles" ]["label" ][keypoint ]
401
+ (line ,) = self .ax .plot (x , y , color = color , label = keypoint )
402
+ self ._lines [keypoint ] = line
403
+
404
+ self ._refresh_canvas (value = self ._n )
379
405
380
- for limb in limb_bodyparts :
381
- y = self .df .xs ((limb , "y" ), axis = 1 , level = ["bodyparts" , "coords" ])
382
- x = np .arange (len (y ))
383
- # color by limb colormap using point layer metadata
384
- color = points_layer .metadata ["face_color_cycles" ]["label" ][limb ]
385
- self .ax .plot (x , y , color = color , label = limb )
406
+ def _toggle_line_visibility (self , keypoint ):
407
+ artist = self ._lines [keypoint ]
408
+ artist .set_visible (not artist .get_visible ())
409
+ self ._refresh_canvas (value = self ._n )
386
410
411
+ def _refresh_canvas (self , value ):
387
412
start = max (0 , value - self ._window // 2 )
388
413
end = min (value + self ._window // 2 , len (self .df ))
389
414
390
415
self .ax .set_xlim (start , end )
391
416
self .vline .set_xdata (value )
417
+ self .canvas .draw ()
392
418
393
- self .canvas .draw_idle ()
419
+ def set_window (self , value ):
420
+ self ._window = value
421
+ self .slider_value .setText (str (value ))
422
+ self .update_plot_range (Event (type_name = "" , value = [self ._n ]))
423
+
424
+ def update_plot_range (self , event ):
425
+ value = event .value [0 ]
426
+ self ._n = value
427
+
428
+ if self .df is None :
429
+ return
430
+
431
+ self ._refresh_canvas (value )
394
432
395
433
396
434
class KeypointControls (QWidget ):
@@ -457,12 +495,12 @@ def __init__(self, napari_viewer):
457
495
self ._trails = None
458
496
459
497
matplotlib_label = QLabel ("Show matplotlib canvas" )
498
+ self ._matplotlib_canvas = KeypointMatplotlibCanvas (self .viewer )
460
499
self ._matplotlib_cb = QCheckBox ()
461
500
self ._matplotlib_cb .setToolTip ("toggle matplotlib canvas visibility" )
501
+ self ._matplotlib_cb .stateChanged .connect (self ._show_matplotlib_canvas )
462
502
self ._matplotlib_cb .setChecked (False )
463
503
self ._matplotlib_cb .setEnabled (False )
464
- self ._matplotlib_cb .stateChanged .connect (self ._show_matplotlib_canvas )
465
- self ._matplotlib_canvas = None
466
504
self ._view_scheme_cb = QCheckBox ("Show color scheme" , parent = self )
467
505
468
506
hlayout .addWidget (self ._matplotlib_cb )
@@ -479,6 +517,11 @@ def __init__(self, napari_viewer):
479
517
self ._color_scheme_display = self ._form_color_scheme_display (self .viewer )
480
518
self ._view_scheme_cb .toggled .connect (self ._show_color_scheme )
481
519
self ._view_scheme_cb .toggle ()
520
+ self ._display .added .connect (
521
+ lambda w : w .part_label .clicked .connect (
522
+ self ._matplotlib_canvas ._toggle_line_visibility
523
+ ),
524
+ )
482
525
483
526
# Substitute default menu action with custom one
484
527
for action in self .viewer .window .file_menu .actions ()[::- 1 ]:
@@ -505,9 +548,6 @@ def __init__(self, napari_viewer):
505
548
QTimer .singleShot (10 , self .start_tutorial )
506
549
self .settings .setValue ("first_launch" , False )
507
550
508
- matplotlib_widget = KeypointMatplotlibCanvas (self .viewer )
509
- matplotlib_widget .setVisible (False )
510
-
511
551
@cached_property
512
552
def settings (self ):
513
553
return QSettings ()
@@ -544,13 +584,9 @@ def _show_trails(self, state):
544
584
545
585
def _show_matplotlib_canvas (self , state ):
546
586
if state == Qt .Checked :
547
- self ._canvas = KeypointMatplotlibCanvas (self .viewer )
548
- self .viewer .window .add_dock_widget (
549
- self ._canvas , name = "Trajectory plot" , area = "bottom"
550
- )
551
- self ._canvas .show ()
587
+ self ._matplotlib_canvas .show ()
552
588
else :
553
- self ._canvas . close ()
589
+ self ._matplotlib_canvas . hide ()
554
590
555
591
def _form_video_action_menu (self ):
556
592
group_box = QGroupBox ("Video" )
@@ -1192,6 +1228,8 @@ def part_name(self, part_name: str):
1192
1228
1193
1229
1194
1230
class ColorSchemeDisplay (QScrollArea ):
1231
+ added = Signal (object )
1232
+
1195
1233
def __init__ (self , parent ):
1196
1234
super ().__init__ (parent )
1197
1235
@@ -1235,9 +1273,9 @@ def _build(self):
1235
1273
def add_entry (self , name , color ):
1236
1274
self .scheme_dict .update ({name : color })
1237
1275
1238
- self . _layout . addWidget (
1239
- LabelPair ( color , name , self ), alignment = Qt .AlignmentFlag .AlignLeft
1240
- )
1276
+ widget = LabelPair ( color , name , self )
1277
+ self . _layout . addWidget ( widget , alignment = Qt .AlignmentFlag .AlignLeft )
1278
+ self . added . emit ( widget )
1241
1279
1242
1280
def reset (self ):
1243
1281
self .scheme_dict = {}
0 commit comments