11import math
22import itertools
33
4- from collections import defaultdict , namedtuple
4+ from collections import namedtuple
55from types import SimpleNamespace as namespace
66
77import numpy as np
1111 QSizePolicy , QGraphicsScene , QGraphicsView , QGraphicsRectItem ,
1212 QGraphicsWidget , QGraphicsSimpleTextItem , QGraphicsPixmapItem ,
1313 QGraphicsGridLayout , QGraphicsLinearLayout , QGraphicsLayoutItem ,
14- QFormLayout , QApplication
14+ QFormLayout , QApplication , QComboBox
1515)
1616from AnyQt .QtGui import (
1717 QFontMetrics , QPen , QPixmap , QColor , QLinearGradient , QPainter ,
2424)
2525import pyqtgraph as pg
2626
27+ from orangewidget .utils .combobox import ComboBox
28+
2729from Orange .data import Domain , Table
2830from Orange .data .sql .table import SqlTable
2931import Orange .distance
3032
3133from Orange .clustering import hierarchical , kmeans
32- from Orange .widgets .utils .itemmodels import DomainModel
34+ from Orange .widgets .utils .itemmodels import DomainModel , VariableListModel
3335from Orange .widgets .utils .stickygraphicsview import StickyGraphicsView
3436from Orange .widgets .utils import colorbrewer
3537from Orange .widgets .utils .annotated_data import (create_annotated_table ,
4143from Orange .widgets .widget import Msg , Input , Output
4244
4345
44- def split_domain (domain , split_label ):
45- """Split the domain based on values of `split_label` value.
46- """
47- groups = defaultdict (list )
48- for attr in domain .attributes :
49- groups [attr .attributes .get (split_label )].append (attr )
50-
51- attr_values = [attr .attributes .get (split_label )
52- for attr in domain .attributes ]
53-
54- domains = []
55- for value , attrs in groups .items ():
56- group_domain = Domain (attrs , domain .class_vars , domain .metas )
57-
58- domains .append ((value , group_domain ))
59-
60- if domains :
61- assert all (len (dom ) == len (domains [0 ][1 ]) for _ , dom in domains )
62-
63- return sorted (domains , key = lambda t : attr_values .index (t [0 ]))
64-
65-
66- def vstack_by_subdomain (data , sub_domains ):
67- domain = sub_domains [0 ]
68- newtable = Table (domain )
69-
70- for sub_dom in sub_domains :
71- sub_data = data .transform (sub_dom )
72- # TODO: improve O(N ** 2)
73- newtable .extend (sub_data )
74-
75- return newtable
76-
77-
78- def select_by_class (data , class_ ):
79- indices = select_by_class_indices (data , class_ )
80- return data [indices ]
81-
82-
83- def select_by_class_indices (data , class_ ):
84- col , _ = data .get_column_view (data .domain .class_var )
85- return col == class_
86-
87-
88- def group_by_unordered (iterable , key ):
89- groups = defaultdict (list )
90- for item in iterable :
91- groups [key (item )].append (item )
92- return groups .items ()
93-
94-
95- def barycenter (a , axis = 0 ):
96- assert 0 <= axis < 2
97- a = np .asarray (a )
98- N = a .shape [axis ]
99- tileshape = [1 if i != axis else a .shape [i ] for i in range (a .ndim )]
100- xshape = list (a .shape )
101- xshape [axis ] = 1
102- X = np .tile (np .reshape (np .arange (N ), tileshape ), xshape )
103- amin = np .nanmin (a , axis = axis , keepdims = True )
104- weights = a - amin
105- weights [np .isnan (weights )] = 0
106- wsum = np .sum (weights , axis = axis )
107- mask = wsum <= np .finfo (float ).eps
108- if axis == 1 :
109- weights [mask , :] = 1
110- else :
111- weights [:, mask ] = 1
112-
113- return np .average (X , weights = weights , axis = axis )
114-
115-
11646def kmeans_compress (X , k = 50 ):
11747 km = kmeans .KMeans (n_clusters = k , n_init = 5 , random_state = 42 )
11848 return km .get_model (X )
11949
12050
121- def candidate_split_labels (data ):
122- """
123- Return candidate labels on which we can split the data.
124- """
125- groups = defaultdict (list )
126- for attr in data .domain .attributes :
127- for item in attr .attributes .items ():
128- groups [item ].append (attr )
129-
130- by_keys = defaultdict (list )
131- for (key , _ ), attrs in groups .items ():
132- by_keys [key ].append (attrs )
133-
134- # Find the keys for which all values have the same number
135- # of attributes.
136- candidates = []
137- for key , groups in by_keys .items ():
138- count = len (groups [0 ])
139- if all (len (attrs ) == count for attrs in groups ) and \
140- len (groups ) > 1 and count > 1 :
141- candidates .append (key )
142-
143- return candidates
144-
145-
14651def leaf_indices (tree ):
14752 return [leaf .value .index for leaf in hierarchical .leaves (tree )]
14853
@@ -384,9 +289,23 @@ def cluster_ord(self):
384289 [name for name , _ , in _color_palettes ].index ("Blue-Yellow" )
385290
386291
292+ def cbselect (cb : QComboBox , value , role : Qt .ItemDataRole = Qt .EditRole ) -> None :
293+ """
294+ Find and select the `value` in the `cb` QComboBox.
295+
296+ Parameters
297+ ----------
298+ cb: QComboBox
299+ value: Any
300+ role: Qt.ItemDataRole
301+ The data role in the combo box model to match value against
302+ """
303+ cb .setCurrentIndex (cb .findData (value , role ))
304+
305+
387306class OWHeatMap (widget .OWWidget ):
388307 name = "Heat Map"
389- description = "Plot a heat map for a pair of attributes ."
308+ description = "Plot a data matrix heatmap ."
390309 icon = "icons/Heatmap.svg"
391310 priority = 260
392311 keywords = []
@@ -418,6 +337,8 @@ class Outputs:
418337 legend = settings .Setting (True )
419338 # Annotations
420339 annotation_var = settings .ContextSetting (None )
340+ split_by_var = settings .ContextSetting (None )
341+
421342 # Stored color palette settings
422343 color_settings = settings .Setting (None )
423344 user_palettes = settings .Setting ([])
@@ -542,6 +463,33 @@ def __init__(self):
542463 cluster_box , self , "row_clustering" , "Rows" ,
543464 callback = self .update_clustering_examples )
544465
466+ box = gui .vBox (self .controlArea , "Split By" )
467+
468+ self .row_split_model = DomainModel (
469+ placeholder = "(None)" ,
470+ valid_types = (Orange .data .DiscreteVariable ,),
471+ parent = self ,
472+ )
473+ self .row_split_cb = cb = ComboBox (
474+ enabled = not self .merge_kmeans ,
475+ sizeAdjustPolicy = ComboBox .AdjustToMinimumContentsLengthWithIcon ,
476+ minimumContentsLength = 14 ,
477+ toolTip = "Split the heatmap vertically by a categorical column"
478+ )
479+ self .row_split_cb .setModel (self .row_split_model )
480+ self .connect_control (
481+ "split_by_var" , lambda value , cb = cb : cbselect (cb , value )
482+ )
483+ self .connect_control (
484+ "merge_kmeans" , self .row_split_cb .setDisabled
485+ )
486+ self .split_by_var = None
487+
488+ self .row_split_cb .activated .connect (
489+ self .__on_split_rows_activated
490+ )
491+ box .layout ().addWidget (self .row_split_cb )
492+
545493 box = gui .vBox (self .controlArea , 'Annotation && Legends' )
546494
547495 gui .checkBox (box , self , 'legend' , 'Show legend' ,
@@ -626,6 +574,8 @@ def clear(self):
626574 self .merge_indices = None
627575 self .annotation_model .set_domain (None )
628576 self .annotation_var = None
577+ self .row_split_model .set_domain (None )
578+ self .split_by_var = None
629579 self .clear_scene ()
630580 self .selected_rows = []
631581 self .__columns_cache .clear ()
@@ -705,7 +655,14 @@ def set_dataset(self, data=None):
705655 if data is not None :
706656 self .annotation_model .set_domain (self .input_data .domain )
707657 self .annotation_var = None
658+ self .row_split_model .set_domain (data .domain )
659+ if data .domain .has_discrete_class :
660+ self .split_by_var = data .domain .class_var
661+ else :
662+ self .split_by_var = None
708663 self .openContext (self .input_data )
664+ if self .split_by_var not in self .row_split_model :
665+ self .split_by_var = None
709666
710667 self .update_heatmaps ()
711668 if data is not None and self .__pending_selection is not None :
@@ -715,6 +672,14 @@ def set_dataset(self, data=None):
715672
716673 self .unconditional_commit ()
717674
675+ def __on_split_rows_activated (self ):
676+ self .set_split_variable (self .row_split_cb .currentData (Qt .EditRole ))
677+
678+ def set_split_variable (self , var ):
679+ if var != self .split_by_var :
680+ self .split_by_var = var
681+ self .update_heatmaps ()
682+
718683 def update_heatmaps (self ):
719684 if self .data is not None :
720685 self .clear_scene ()
@@ -727,7 +692,9 @@ def update_heatmaps(self):
727692 elif self .merge_kmeans and len (self .data ) < 3 :
728693 self .Error .not_enough_instances_k_means ()
729694 else :
730- self .construct_heatmaps (self .data )
695+ self .heatmapparts = self .construct_heatmaps (
696+ self .data , self .split_by_var
697+ )
731698 self .construct_heatmaps_scene (
732699 self .heatmapparts , self .effective_data )
733700 self .selected_rows = []
@@ -741,7 +708,7 @@ def update_merge(self):
741708 self .update_heatmaps ()
742709 self .commit ()
743710
744- def _make_parts (self , data , group_var = None , group_key = None ):
711+ def _make_parts (self , data , group_var = None ):
745712 """
746713 Make initial `Parts` for data, split by group_var, group_key
747714 """
@@ -758,20 +725,11 @@ def _make_parts(self, data, group_var=None, group_key=None):
758725 sortindices = None ,
759726 cluster = None , cluster_ordered = None )]
760727
761- if group_key is not None :
762- col_groups = split_domain (data .domain , group_key )
763- assert len (col_groups ) > 0
764- col_indices = [np .array ([data .domain .index (var ) for var in group ])
765- for _ , group in col_groups ]
766- col_groups = [ColumnPart (title = name , domain = d , indices = ind ,
767- cluster = None , cluster_ordered = None )
768- for (name , d ), ind in zip (col_groups , col_indices )]
769- else :
770- col_groups = [
771- ColumnPart (
772- title = None , indices = slice (0 , len (data .domain .attributes )),
773- domain = data .domain , cluster = None , cluster_ordered = None )
774- ]
728+ col_groups = [
729+ ColumnPart (
730+ title = None , indices = slice (0 , len (data .domain .attributes )),
731+ domain = data .domain , cluster = None , cluster_ordered = None )
732+ ]
775733
776734 minv , maxv = np .nanmin (data .X ), np .nanmax (data .X )
777735 return Parts (row_groups , col_groups , span = (minv , maxv ))
@@ -806,11 +764,10 @@ def cluster_rows(self, data, parts):
806764
807765 row_groups .append (row ._replace (cluster = cluster , cluster_ordered = cluster_ord ))
808766
809- return parts ._replace (columns = parts . columns , rows = row_groups )
767+ return parts ._replace (rows = row_groups )
810768
811769 def cluster_columns (self , data , parts ):
812- if len (parts .columns ) > 1 :
813- data = vstack_by_subdomain (data , [col .domain for col in parts .columns ])
770+ assert len (parts .columns ) == 1 , "columns split is no longer supported"
814771 assert all (var .is_continuous for var in data .domain .attributes )
815772
816773 col0 = parts .columns [0 ]
@@ -839,21 +796,9 @@ def cluster_columns(self, data, parts):
839796
840797 col_groups = [col ._replace (cluster = cluster , cluster_ordered = cluster_ord )
841798 for col in parts .columns ]
842- return parts ._replace (columns = col_groups , rows = parts . rows )
799+ return parts ._replace (columns = col_groups )
843800
844- def construct_heatmaps (self , data , split_label = None ):
845- if split_label is not None :
846- groups = split_domain (data .domain , split_label )
847- assert len (groups ) > 0
848- else :
849- groups = [("" , data .domain )]
850-
851- if data .domain .has_discrete_class :
852- group_var = data .domain .class_var
853- else :
854- group_var = None
855-
856- group_label = split_label
801+ def construct_heatmaps (self , data , group_var = None ) -> 'Parts' :
857802 if self .merge_kmeans :
858803 if self .kmeans_model is None :
859804 effective_data = self .input_data .transform (
@@ -890,18 +835,14 @@ def construct_heatmaps(self, data, split_label=None):
890835
891836 self .__update_clustering_enable_state (effective_data )
892837
893- parts = self ._make_parts (effective_data , group_var , group_label )
838+ parts = self ._make_parts (effective_data , group_var )
894839 # Restore/update the row/columns items descriptions from cache if
895840 # available
896841 rows_cache_key = (group_var ,
897842 self .merge_kmeans_k if self .merge_kmeans else None )
898843 if rows_cache_key in self .__rows_cache :
899844 parts = parts ._replace (rows = self .__rows_cache [rows_cache_key ].rows )
900845
901- if group_label in self .__columns_cache :
902- parts = parts ._replace (
903- columns = self .__columns_cache [group_label ].columns )
904-
905846 if self .row_clustering :
906847 assert len (effective_data ) <= OWHeatMap ._MaxOrderedClustering
907848 parts = self .cluster_rows (effective_data , parts )
@@ -913,9 +854,7 @@ def construct_heatmaps(self, data, split_label=None):
913854
914855 # Cache the updated parts
915856 self .__rows_cache [rows_cache_key ] = parts
916- self .__columns_cache [group_label ] = parts
917-
918- self .heatmapparts = parts
857+ return parts
919858
920859 def construct_heatmaps_scene (self , parts , data ):
921860 def select_row (item ):
@@ -1521,8 +1460,10 @@ def send_report(self):
15211460 self .report_items ((
15221461 ("Columns:" , "Clustering" if self .col_clustering else "No sorting" ),
15231462 ("Rows:" , "Clustering" if self .row_clustering else "No sorting" ),
1463+ ("Split:" ,
1464+ self .split_by_var is not None and self .split_by_var .name ),
15241465 ("Row annotation" ,
1525- self .annotation_var is not None and self .annotation_var .name )
1466+ self .annotation_var is not None and self .annotation_var .name ),
15261467 ))
15271468 self .report_plot ()
15281469
0 commit comments