4
4
from datetime import datetime
5
5
from functools import partial , cached_property
6
6
from math import ceil , log10
7
+ import matplotlib .pyplot as plt
7
8
import matplotlib .style as mplstyle
8
9
import napari
9
10
import pandas as pd
21
22
from napari .layers .utils .layer_utils import _features_to_properties
22
23
from napari .utils .events import Event
23
24
from napari .utils .history import get_save_history , update_save_history
24
- from qtpy .QtCore import Qt , QTimer , Signal , QSize , QPoint , QSettings
25
- from qtpy .QtGui import QPainter , QIcon , QAction , QCursor
25
+ from qtpy .QtCore import Qt , QTimer , Signal , QPoint , QSettings
26
+ from qtpy .QtGui import QPainter , QAction , QCursor
26
27
from qtpy .QtWidgets import (
27
28
QButtonGroup ,
28
29
QCheckBox ,
45
46
QWidget ,
46
47
)
47
48
48
- ICON_FOLDER = os .path .join (os .path .dirname (__file__ ), "assets" )
49
-
50
49
from napari_deeplabcut import keypoints
51
50
from napari_deeplabcut ._reader import _load_config
52
51
from napari_deeplabcut ._writer import _write_config , _write_image , _form_df
53
52
from napari_deeplabcut .misc import (
54
53
encode_categories ,
55
54
to_os_dir_sep ,
56
55
guarantee_multiindex_rows ,
56
+ build_color_cycles
57
57
)
58
58
59
59
@@ -622,6 +622,13 @@ def __init__(self, napari_viewer):
622
622
launch_tutorial .triggered .connect (self .start_tutorial )
623
623
self .viewer .window .view_menu .addAction (launch_tutorial )
624
624
625
+ # Hide some unused viewer buttons
626
+ self .viewer .window ._qt_viewer .viewerButtons .gridViewButton .hide ()
627
+ self .viewer .window ._qt_viewer .viewerButtons .rollDimsButton .hide ()
628
+ self .viewer .window ._qt_viewer .viewerButtons .transposeDimsButton .hide ()
629
+ self .viewer .window ._qt_viewer .layerButtons .newPointsButton .setDisabled (True )
630
+ self .viewer .window ._qt_viewer .layerButtons .newLabelsButton .setDisabled (True )
631
+
625
632
if self .settings .value ("first_launch" , True ) and not os .environ .get (
626
633
"hide_tutorial" , False
627
634
):
@@ -650,13 +657,17 @@ def _show_trails(self, state):
650
657
store = list (self ._stores .values ())[0 ]
651
658
inds = encode_categories (store .layer .properties ["label" ])
652
659
temp = np .c_ [inds , store .layer .data ]
660
+ cmap = "viridis"
661
+ for layer in self .viewer .layers :
662
+ if isinstance (layer , Points ) and layer .metadata :
663
+ cmap = layer .metadata ["colormap_name" ]
653
664
self ._trails = self .viewer .add_tracks (
654
665
temp ,
655
666
tail_length = 50 ,
656
667
head_length = 50 ,
657
668
tail_width = 6 ,
658
669
name = "trails" ,
659
- colormap = "viridis" ,
670
+ colormap = cmap ,
660
671
)
661
672
self ._trails .visible = True
662
673
elif self ._trails is not None :
@@ -933,6 +944,11 @@ def on_insert(self, event):
933
944
# Hide out of slice checkbox
934
945
point_controls .outOfSliceCheckBox .hide ()
935
946
point_controls .layout ().itemAt (15 ).widget ().hide ()
947
+ # Add dropdown menu for colormap picking
948
+ colormap_selector = DropdownMenu (plt .colormaps , self )
949
+ colormap_selector .update_to (layer .metadata ["colormap_name" ])
950
+ colormap_selector .currentTextChanged .connect (self ._update_colormap )
951
+ point_controls .layout ().addRow ("colormap" , colormap_selector )
936
952
937
953
for layer_ in self .viewer .layers :
938
954
if not isinstance (layer_ , Image ):
@@ -963,6 +979,20 @@ def on_remove(self, event):
963
979
self ._matplotlib_cb .setChecked (False )
964
980
self ._trails = None
965
981
982
+ def _update_colormap (self , colormap_name ):
983
+ for layer in self .viewer .layers :
984
+ if isinstance (layer , Points ) and layer .metadata :
985
+ face_color_cycle_maps = build_color_cycles (
986
+ layer .metadata ["header" ], colormap_name ,
987
+ )
988
+ layer .metadata ["face_color_cycles" ] = face_color_cycle_maps
989
+ face_color_prop = layer ._face .color_properties .name
990
+ layer .face_color = face_color_prop
991
+ layer .face_color_cycle = face_color_cycle_maps [face_color_prop ]
992
+ layer .events .face_color ()
993
+ self ._update_color_scheme ()
994
+ break
995
+
966
996
@register_points_action ("Change labeling mode" )
967
997
def cycle_through_label_modes (self , * args ):
968
998
self .label_mode = next (keypoints .LabelMode )
@@ -975,8 +1005,15 @@ def label_mode(self):
975
1005
def label_mode (self , mode : Union [str , keypoints .LabelMode ]):
976
1006
self ._label_mode = keypoints .LabelMode (mode )
977
1007
self .viewer .status = self .label_mode
1008
+ mode_ = str (mode )
1009
+ if mode_ == "loop" :
1010
+ for menu in self ._menus :
1011
+ menu ._locked = True
1012
+ else :
1013
+ for menu in self ._menus :
1014
+ menu ._locked = False
978
1015
for btn in self ._radio_group .buttons ():
979
- if btn .text () == str ( mode ) :
1016
+ if btn .text () == mode_ :
980
1017
btn .setChecked (True )
981
1018
break
982
1019
@@ -1038,11 +1075,6 @@ def __init__(
1038
1075
layout2 = QVBoxLayout ()
1039
1076
for menu in self .menus .values ():
1040
1077
layout2 .addWidget (menu )
1041
- self .lock_button = QPushButton ("Lock selection" )
1042
- self .lock_button .setIcon (QIcon (os .path .join (ICON_FOLDER , "unlock.svg" )))
1043
- self .lock_button .setIconSize (QSize (24 , 24 ))
1044
- self .lock_button .clicked .connect (self ._lock_current_keypoint )
1045
- layout2 .addWidget (self .lock_button )
1046
1078
group_box .setLayout (layout2 )
1047
1079
layout1 .addWidget (group_box )
1048
1080
self .setLayout (layout1 )
@@ -1072,15 +1104,6 @@ def _update_items(self):
1072
1104
self .menus ["id" ].update_items (list (self .id2label ))
1073
1105
self .menus ["label" ].update_items (self .id2label [id_ ])
1074
1106
1075
- def _lock_current_keypoint (self ):
1076
- self ._locked = not self ._locked
1077
- if self ._locked :
1078
- self .lock_button .setText ("Unlock selection" )
1079
- self .lock_button .setIcon (QIcon (os .path .join (ICON_FOLDER , "lock.svg" )))
1080
- else :
1081
- self .lock_button .setText ("Lock selection" )
1082
- self .lock_button .setIcon (QIcon (os .path .join (ICON_FOLDER , "unlock.svg" )))
1083
-
1084
1107
def update_menus (self , event ):
1085
1108
keypoint = self .store .current_keypoint
1086
1109
for attr , menu in self .menus .items ():
@@ -1097,7 +1120,7 @@ def refresh_label_menu(self, text: str):
1097
1120
1098
1121
def smart_reset (self , event ):
1099
1122
"""Set current keypoint to the first unlabeled one."""
1100
- if self ._locked :
1123
+ if self ._locked : # The currently selected point is not updated
1101
1124
return
1102
1125
unannotated = ""
1103
1126
already_annotated = self .store .annotated_keypoints
@@ -1360,4 +1383,5 @@ def add_entry(self, name, color):
1360
1383
def reset (self ):
1361
1384
self .scheme_dict = {}
1362
1385
for i in reversed (range (self ._layout .count ())):
1363
- self ._layout .itemAt (i ).widget ().deleteLater ()
1386
+ w = self ._layout .itemAt (i ).widget ()
1387
+ self ._layout .removeWidget (w )
0 commit comments