1717from AnyQt .QtCore import Qt , QSize , QRectF , QObject
1818
1919from orangewidget .utils .combobox import ComboBox , ComboBoxSearch
20- from Orange .data import Domain , Table , Variable , DiscreteVariable
20+ from Orange .data import Domain , Table , Variable , DiscreteVariable , \
21+ ContinuousVariable
2122from Orange .data .sql .table import SqlTable
2223import Orange .distance
2324
2627from Orange .widgets .utils .itemmodels import DomainModel
2728from Orange .widgets .utils .stickygraphicsview import StickyGraphicsView
2829from Orange .widgets .utils .graphicsview import GraphicsWidgetView
29- from Orange .widgets .utils .colorpalettes import DiscretePalette , Palette
30+ from Orange .widgets .utils .colorpalettes import Palette
3031
3132from Orange .widgets .utils .annotated_data import (create_annotated_table ,
3233 ANNOTATED_DATA_SIGNAL_NAME )
@@ -412,13 +413,36 @@ def _(idx, cb=cb):
412413 form .addRow ("Text" , self .annotation_text_cb )
413414 form .addRow ("Color" , self .row_side_color_cb )
414415 box .layout ().addWidget (annotbox )
415- posbox = gui .vBox (box , "Column Labels Position" , addSpace = False )
416- posbox .setFlat (True )
416+ annotbox = QGroupBox ("Column annotations" , flat = True )
417+ form = QFormLayout (
418+ annotbox ,
419+ formAlignment = Qt .AlignLeft ,
420+ labelAlignment = Qt .AlignLeft ,
421+ fieldGrowthPolicy = QFormLayout .AllNonFixedFieldsGrow
422+ )
423+ self .col_side_color_model = DomainModel (
424+ placeholder = "(None)" ,
425+ valid_types = (DiscreteVariable , ContinuousVariable ),
426+ parent = self
427+ )
428+ self .col_side_color_cb = cb = ComboBoxSearch (
429+ sizeAdjustPolicy = QComboBox .AdjustToMinimumContentsLength ,
430+ minimumContentsLength = 12
431+ )
432+ self .col_side_color_cb .setModel (self .col_side_color_model )
433+ self .column_annotation_color_var = None
434+ self .col_side_color_cb .activated .connect (self .__set_column_annotation_color_key_index )
435+ # posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
436+ # posbox.setFlat(True)
417437 cb = gui .comboBox (
418- posbox , self , "column_label_pos" ,
438+ None , self , "column_label_pos" ,
419439 callback = self .update_column_annotations )
420440 cb .setModel (create_list_model (ColumnLabelsPosData , parent = self ))
421441 cb .setCurrentIndex (self .column_label_pos )
442+ form .addRow ("Color" , self .col_side_color_cb )
443+ form .addRow ("Label position" , cb )
444+ box .layout ().addWidget (annotbox )
445+
422446 gui .checkBox (self .controlArea , self , "keep_aspect" ,
423447 "Keep aspect ratio" , box = "Resize" ,
424448 callback = self .__aspect_mode_changed )
@@ -596,7 +620,7 @@ def set_dataset(self, data=None):
596620 self .row_split_model .set_domain (data .domain )
597621 self .col_annot_data = data .transpose (data [:0 ].transform (Domain (data .domain .attributes )))
598622 self .col_split_model .set_domain (self .col_annot_data .domain )
599-
623+ self . col_side_color_model . set_domain ( self . col_annot_data . domain )
600624 if data .domain .has_discrete_class :
601625 self .split_by_var = data .domain .class_var
602626 else :
@@ -633,7 +657,7 @@ def set_split_variable(self, var):
633657 self .update_heatmaps ()
634658
635659 def __on_split_cols_activated (self ):
636- self .set_column_split_key (self .col_split_cb .currentData (Qt .UserRole ))
660+ self .set_column_split_key (self .col_split_cb .currentData (Qt .EditRole ))
637661
638662 def set_column_split_key (self , key ):
639663 if key != self .split_columns_key :
@@ -812,7 +836,9 @@ def construct_heatmaps(self, data, group_var=None, column_split_key=None) -> 'Pa
812836
813837 self .__update_clustering_enable_state (effective_data )
814838
815- parts = self ._make_parts (effective_data , group_var , column_split_key )
839+ parts = self ._make_parts (
840+ effective_data , group_var ,
841+ column_split_key .name if column_split_key is not None else None )
816842 # Restore/update the row/columns items descriptions from cache if
817843 # available
818844 rows_cache_key = (group_var ,
@@ -882,9 +908,15 @@ def setup_scene(self, parts, data):
882908 col_names = columns ,
883909 )
884910 widget .setHeatmaps (parts )
911+
885912 side = self .row_side_colors ()
886913 if side is not None :
887914 widget .setRowSideColorAnnotations (side [0 ], side [1 ], name = side [2 ].name )
915+
916+ side = self .column_side_colors ()
917+ if side is not None :
918+ widget .setColumnSideColorAnnotations (side [0 ], side [1 ], name = side [2 ].name )
919+
888920 widget .setColumnLabelsPosition (self ._column_label_pos )
889921 widget .setAspectRatioMode (
890922 Qt .KeepAspectRatio if self .keep_aspect else Qt .IgnoreAspectRatio
@@ -1065,7 +1097,7 @@ def row_side_colors(self):
10651097 merges = self ._merge_row_indices ()
10661098 if merges is not None :
10671099 column_data = aggregate (var , column_data , merges )
1068- data , colormap = self . _colorize (var , column_data )
1100+ data , colormap = colorize (var , column_data )
10691101 if var .is_continuous :
10701102 span = (np .nanmin (column_data ), np .nanmax (column_data ))
10711103 if np .any (np .isnan (span )):
@@ -1091,27 +1123,27 @@ def update_row_side_colors(self):
10911123 else :
10921124 widget .setRowSideColorAnnotations (colors [0 ], colors [1 ], colors [2 ].name )
10931125
1094- def _colorize (self , var : Variable , data : np . ndarray ) -> Tuple [ np . ndarray , ColorMap ] :
1095- palette = var . palette # type: Palette
1096- colors = np . array (
1097- [[ c . red (), c . green (), c . blue ()] for c in palette . qcolors_w_nan ],
1098- dtype = np . uint8 ,
1099- )
1100- if var . is_discrete :
1101- mask = np . isnan ( data )
1102- data [ mask ] = - 1
1103- data = data . astype ( int )
1104- if mask . any ():
1105- values = ( * var . values , "N/A" )
1126+ def __set_column_annotation_color_key_index (self , index : int ) :
1127+ key = self . col_side_color_cb . itemData ( index , Qt . EditRole )
1128+ self . set_column_annotation_color_key ( key )
1129+
1130+ def set_column_annotation_color_key ( self , key ):
1131+ if self . col_side_color_model != key :
1132+ self . column_annotation_color_var = key
1133+ colors = self . column_side_colors ( )
1134+ if colors is not None :
1135+ self . scene . widget . setColumnSideColorAnnotations (
1136+ colors [ 0 ], colors [ 1 ], colors [ 2 ]. name ,
1137+ )
11061138 else :
1107- values = var . values
1108- colors = colors [: - 1 ]
1109- return data , CategoricalColorMap ( colors , values )
1110- elif var . is_continuous :
1111- cmap = GradientColorMap ( colors [: - 1 ])
1112- return data , cmap
1113- else :
1114- raise TypeError
1139+ self . scene . widget . setColumnSideColorAnnotations ( None )
1140+
1141+ def column_side_colors ( self ):
1142+ var = self . column_annotation_color_var
1143+ if var is None :
1144+ return None
1145+ table = self . col_annot_data
1146+ return color_annotation_data ( table , var )
11151147
11161148 def update_column_annotations (self ):
11171149 widget = self .scene .widget
@@ -1320,6 +1352,40 @@ def column_data_from_table(
13201352 return data
13211353
13221354
1355+ def color_annotation_data (
1356+ table : Table , var : Union [int , str , Variable ]
1357+ ) -> Tuple [np .ndarray , ColorMap , Variable ]:
1358+ var = table .domain [var ]
1359+ column_data = column_data_from_table (table , var )
1360+ data , colormap = colorize (var , column_data )
1361+ return data , colormap , var
1362+
1363+
1364+ def colorize (var : Variable , data : np .ndarray ) -> Tuple [np .ndarray , ColorMap ]:
1365+ palette = var .palette # type: Palette
1366+ colors = np .array (
1367+ [[c .red (), c .green (), c .blue ()] for c in palette .qcolors_w_nan ],
1368+ dtype = np .uint8 ,
1369+ )
1370+ if var .is_discrete :
1371+ mask = np .isnan (data )
1372+ data = data .astype (int )
1373+ data [mask ] = - 1
1374+ if mask .any ():
1375+ values = (* var .values , "N/A" )
1376+ else :
1377+ values = var .values
1378+ colors = colors [: - 1 ]
1379+ return data , CategoricalColorMap (colors , values )
1380+ elif var .is_continuous :
1381+ span = np .nanmin (data ), np .nanmax (data )
1382+ if np .any (np .isnan (span )):
1383+ span = 0 , 1.
1384+ return data , GradientColorMap (colors [:- 1 ], span = span )
1385+ else :
1386+ raise TypeError
1387+
1388+
13231389def aggregate (
13241390 var : Variable , data : np .ndarray , groupindices : Sequence [Sequence [int ]],
13251391) -> np .ndarray :
0 commit comments