1717from AnyQt .QtCore import Qt , QSize , QRectF , QObject
1818
1919from orangewidget .utils .combobox import ComboBox , ComboBoxSearch
20- from Orange .data import Domain , Table , Variable
20+ from Orange .data import Domain , Table , Variable , DiscreteVariable , \
21+ ContinuousVariable
2122from Orange .data .sql .table import SqlTable
2223import Orange .distance
2324
@@ -183,6 +184,8 @@ class Outputs:
183184 annotation_var = settings .ContextSetting (None )
184185 #: color row annotation
185186 annotation_color_var = settings .ContextSetting (None )
187+ column_annotation_color_key = settings .ContextSetting (None )
188+
186189 # Discrete variable used to split that data/heatmaps (vertically)
187190 split_by_var = settings .ContextSetting (None )
188191 # Split heatmap columns by 'key' (horizontal)
@@ -408,13 +411,39 @@ def _(idx, cb=cb):
408411 form .addRow ("Text" , self .annotation_text_cb )
409412 form .addRow ("Color" , self .row_side_color_cb )
410413 box .layout ().addWidget (annotbox )
411- posbox = gui .vBox (box , "Column Labels Position" , addSpace = False )
412- posbox .setFlat (True )
414+ annotbox = QGroupBox ("Column annotations" , flat = True )
415+ form = QFormLayout (
416+ annotbox ,
417+ formAlignment = Qt .AlignLeft ,
418+ labelAlignment = Qt .AlignLeft ,
419+ fieldGrowthPolicy = QFormLayout .AllNonFixedFieldsGrow
420+ )
421+ self .col_side_color_model = DomainModel (
422+ placeholder = "(None)" ,
423+ valid_types = (DiscreteVariable , ContinuousVariable ),
424+ parent = self
425+ )
426+ self .col_side_color_cb = cb = ComboBoxSearch (
427+ sizeAdjustPolicy = QComboBox .AdjustToMinimumContentsLength ,
428+ minimumContentsLength = 12
429+ )
430+ self .col_side_color_cb .setModel (self .col_side_color_model )
431+ self .connect_control (
432+ "column_annotation_color_key" , self .column_annotation_color_key_changed ,
433+ )
434+ self .column_annotation_color_key = None
435+ self .col_side_color_cb .activated .connect (
436+ self .__set_column_annotation_color_key_index )
437+
413438 cb = gui .comboBox (
414- posbox , self , "column_label_pos" ,
439+ None , self , "column_label_pos" ,
415440 callback = self .update_column_annotations )
416441 cb .setModel (create_list_model (ColumnLabelsPosData , parent = self ))
417442 cb .setCurrentIndex (self .column_label_pos )
443+ form .addRow ("Color" , self .col_side_color_cb )
444+ form .addRow ("Label position" , cb )
445+ box .layout ().addWidget (annotbox )
446+
418447 gui .checkBox (self .controlArea , self , "keep_aspect" ,
419448 "Keep aspect ratio" , box = "Resize" ,
420449 callback = self .__aspect_mode_changed )
@@ -503,7 +532,9 @@ def clear(self):
503532 self .annotation_model .set_domain (None )
504533 self .annotation_var = None
505534 self .row_side_color_model .set_domain (None )
535+ self .col_side_color_model .set_domain (None )
506536 self .annotation_color_var = None
537+ self .column_annotation_color_key = None
507538 self .row_split_model .set_domain (None )
508539 self .col_split_model .set_domain (None )
509540 self .split_by_var = None
@@ -593,12 +624,13 @@ def set_dataset(self, data=None):
593624 self .row_split_model .set_domain (data .domain )
594625 self .col_annot_data = data .transpose (data [:0 ].transform (Domain (data .domain .attributes )))
595626 self .col_split_model .set_domain (self .col_annot_data .domain )
596-
627+ self . col_side_color_model . set_domain ( self . col_annot_data . domain )
597628 if data .domain .has_discrete_class :
598629 self .split_by_var = data .domain .class_var
599630 else :
600631 self .split_by_var = None
601632 self .split_columns_key = None
633+ self .column_annotation_color_key = None
602634 self .openContext (self .input_data )
603635 if self .split_by_var not in self .row_split_model :
604636 self .split_by_var = None
@@ -607,6 +639,10 @@ def set_dataset(self, data=None):
607639 if idx == - 1 :
608640 self .split_columns_key = None
609641
642+ idx = self .col_side_color_cb .findData (self .column_annotation_color_key , Qt .EditRole )
643+ if idx == - 1 :
644+ self .column_annotation_color_key = None
645+
610646 self .update_heatmaps ()
611647 if data is not None and self .__pending_selection is not None :
612648 assert self .scene .widget is not None
@@ -630,7 +666,7 @@ def set_split_variable(self, var):
630666 self .update_heatmaps ()
631667
632668 def __on_split_cols_activated (self ):
633- self .set_column_split_key (self .col_split_cb .currentData (Qt .UserRole ))
669+ self .set_column_split_key (self .col_split_cb .currentData (Qt .EditRole ))
634670
635671 def set_column_split_key (self , key ):
636672 if key != self .split_columns_key :
@@ -809,7 +845,9 @@ def construct_heatmaps(self, data, group_var=None, column_split_key=None) -> 'Pa
809845
810846 self .__update_clustering_enable_state (effective_data )
811847
812- parts = self ._make_parts (effective_data , group_var , column_split_key )
848+ parts = self ._make_parts (
849+ effective_data , group_var ,
850+ column_split_key .name if column_split_key is not None else None )
813851 # Restore/update the row/columns items descriptions from cache if
814852 # available
815853 rows_cache_key = (group_var ,
@@ -879,9 +917,15 @@ def setup_scene(self, parts, data):
879917 col_names = columns ,
880918 )
881919 widget .setHeatmaps (parts )
920+
882921 side = self .row_side_colors ()
883922 if side is not None :
884923 widget .setRowSideColorAnnotations (side [0 ], side [1 ], name = side [2 ].name )
924+
925+ side = self .column_side_colors ()
926+ if side is not None :
927+ widget .setColumnSideColorAnnotations (side [0 ], side [1 ], name = side [2 ].name )
928+
885929 widget .setColumnLabelsPosition (self ._column_label_pos )
886930 widget .setAspectRatioMode (
887931 Qt .KeepAspectRatio if self .keep_aspect else Qt .IgnoreAspectRatio
@@ -1050,7 +1094,7 @@ def row_side_colors(self):
10501094 merges = self ._merge_row_indices ()
10511095 if merges is not None :
10521096 column_data = aggregate (var , column_data , merges )
1053- data , colormap = self . _colorize (var , column_data )
1097+ data , colormap = colorize (var , column_data )
10541098 if var .is_continuous :
10551099 span = (np .nanmin (column_data ), np .nanmax (column_data ))
10561100 if np .any (np .isnan (span )):
@@ -1076,27 +1120,31 @@ def update_row_side_colors(self):
10761120 else :
10771121 widget .setRowSideColorAnnotations (colors [0 ], colors [1 ], colors [2 ].name )
10781122
1079- def _colorize (self , var : Variable , data : np .ndarray ) -> Tuple [np .ndarray , ColorMap ]:
1080- palette = var .palette # type: Palette
1081- colors = np .array (
1082- [[c .red (), c .green (), c .blue ()] for c in palette .qcolors_w_nan ],
1083- dtype = np .uint8 ,
1084- )
1085- if var .is_discrete :
1086- mask = np .isnan (data )
1087- data [mask ] = - 1
1088- data = data .astype (int )
1089- if mask .any ():
1090- values = (* var .values , "N/A" )
1123+ def __set_column_annotation_color_key_index (self , index : int ):
1124+ key = self .col_side_color_cb .itemData (index , Qt .EditRole )
1125+ self .set_column_annotation_color_key (key )
1126+
1127+ def column_annotation_color_key_changed (self , value ):
1128+ cbselect (self .col_side_color_cb , value , Qt .EditRole )
1129+
1130+ def set_column_annotation_color_key (self , key ):
1131+ if self .column_annotation_color_key != key :
1132+ self .column_annotation_color_key = key
1133+ cbselect (self .col_side_color_cb , key , Qt .EditRole )
1134+ colors = self .column_side_colors ()
1135+ if colors is not None :
1136+ self .scene .widget .setColumnSideColorAnnotations (
1137+ colors [0 ], colors [1 ], colors [2 ].name ,
1138+ )
10911139 else :
1092- values = var . values
1093- colors = colors [: - 1 ]
1094- return data , CategoricalColorMap ( colors , values )
1095- elif var . is_continuous :
1096- cmap = GradientColorMap ( colors [: - 1 ])
1097- return data , cmap
1098- else :
1099- raise TypeError
1140+ self . scene . widget . setColumnSideColorAnnotations ( None )
1141+
1142+ def column_side_colors ( self ):
1143+ var = self . column_annotation_color_key
1144+ if var is None :
1145+ return None
1146+ table = self . col_annot_data
1147+ return color_annotation_data ( table , var )
11001148
11011149 def update_column_annotations (self ):
11021150 widget = self .scene .widget
@@ -1305,6 +1353,40 @@ def column_data_from_table(
13051353 return data
13061354
13071355
1356+ def color_annotation_data (
1357+ table : Table , var : Union [int , str , Variable ]
1358+ ) -> Tuple [np .ndarray , ColorMap , Variable ]:
1359+ var = table .domain [var ]
1360+ column_data = column_data_from_table (table , var )
1361+ data , colormap = colorize (var , column_data )
1362+ return data , colormap , var
1363+
1364+
1365+ def colorize (var : Variable , data : np .ndarray ) -> Tuple [np .ndarray , ColorMap ]:
1366+ palette = var .palette # type: Palette
1367+ colors = np .array (
1368+ [[c .red (), c .green (), c .blue ()] for c in palette .qcolors_w_nan ],
1369+ dtype = np .uint8 ,
1370+ )
1371+ if var .is_discrete :
1372+ mask = np .isnan (data )
1373+ data = data .astype (int )
1374+ data [mask ] = - 1
1375+ if mask .any ():
1376+ values = (* var .values , "N/A" )
1377+ else :
1378+ values = var .values
1379+ colors = colors [: - 1 ]
1380+ return data , CategoricalColorMap (colors , values )
1381+ elif var .is_continuous :
1382+ span = np .nanmin (data ), np .nanmax (data )
1383+ if np .any (np .isnan (span )):
1384+ span = 0 , 1.
1385+ return data , GradientColorMap (colors [:- 1 ], span = span )
1386+ else :
1387+ raise TypeError
1388+
1389+
13081390def aggregate (
13091391 var : Variable , data : np .ndarray , groupindices : Sequence [Sequence [int ]],
13101392) -> np .ndarray :
0 commit comments