Skip to content

Commit c58eaa4

Browse files
authored
Merge pull request #4234 from ales-erjavec/owheatmap-split-by
[ENH] owheatmap: Add Split By combo box
2 parents b08fc7e + f56c634 commit c58eaa4

File tree

2 files changed

+95
-143
lines changed

2 files changed

+95
-143
lines changed

Orange/widgets/visualize/owheatmap.py

Lines changed: 84 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import itertools
33

4-
from collections import defaultdict, namedtuple
4+
from collections import namedtuple
55
from types import SimpleNamespace as namespace
66

77
import numpy as np
@@ -11,7 +11,7 @@
1111
QSizePolicy, QGraphicsScene, QGraphicsView, QGraphicsRectItem,
1212
QGraphicsWidget, QGraphicsSimpleTextItem, QGraphicsPixmapItem,
1313
QGraphicsGridLayout, QGraphicsLinearLayout, QGraphicsLayoutItem,
14-
QFormLayout, QApplication
14+
QFormLayout, QApplication, QComboBox
1515
)
1616
from AnyQt.QtGui import (
1717
QFontMetrics, QPen, QPixmap, QColor, QLinearGradient, QPainter,
@@ -24,12 +24,14 @@
2424
)
2525
import pyqtgraph as pg
2626

27+
from orangewidget.utils.combobox import ComboBox
28+
2729
from Orange.data import Domain, Table
2830
from Orange.data.sql.table import SqlTable
2931
import Orange.distance
3032

3133
from Orange.clustering import hierarchical, kmeans
32-
from Orange.widgets.utils.itemmodels import DomainModel
34+
from Orange.widgets.utils.itemmodels import DomainModel, VariableListModel
3335
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
3436
from Orange.widgets.utils import colorbrewer
3537
from Orange.widgets.utils.annotated_data import (create_annotated_table,
@@ -41,108 +43,11 @@
4143
from Orange.widgets.widget import Msg, Input, Output
4244

4345

44-
def split_domain(domain, split_label):
45-
"""Split the domain based on values of `split_label` value.
46-
"""
47-
groups = defaultdict(list)
48-
for attr in domain.attributes:
49-
groups[attr.attributes.get(split_label)].append(attr)
50-
51-
attr_values = [attr.attributes.get(split_label)
52-
for attr in domain.attributes]
53-
54-
domains = []
55-
for value, attrs in groups.items():
56-
group_domain = Domain(attrs, domain.class_vars, domain.metas)
57-
58-
domains.append((value, group_domain))
59-
60-
if domains:
61-
assert all(len(dom) == len(domains[0][1]) for _, dom in domains)
62-
63-
return sorted(domains, key=lambda t: attr_values.index(t[0]))
64-
65-
66-
def vstack_by_subdomain(data, sub_domains):
67-
domain = sub_domains[0]
68-
newtable = Table(domain)
69-
70-
for sub_dom in sub_domains:
71-
sub_data = data.transform(sub_dom)
72-
# TODO: improve O(N ** 2)
73-
newtable.extend(sub_data)
74-
75-
return newtable
76-
77-
78-
def select_by_class(data, class_):
79-
indices = select_by_class_indices(data, class_)
80-
return data[indices]
81-
82-
83-
def select_by_class_indices(data, class_):
84-
col, _ = data.get_column_view(data.domain.class_var)
85-
return col == class_
86-
87-
88-
def group_by_unordered(iterable, key):
89-
groups = defaultdict(list)
90-
for item in iterable:
91-
groups[key(item)].append(item)
92-
return groups.items()
93-
94-
95-
def barycenter(a, axis=0):
96-
assert 0 <= axis < 2
97-
a = np.asarray(a)
98-
N = a.shape[axis]
99-
tileshape = [1 if i != axis else a.shape[i] for i in range(a.ndim)]
100-
xshape = list(a.shape)
101-
xshape[axis] = 1
102-
X = np.tile(np.reshape(np.arange(N), tileshape), xshape)
103-
amin = np.nanmin(a, axis=axis, keepdims=True)
104-
weights = a - amin
105-
weights[np.isnan(weights)] = 0
106-
wsum = np.sum(weights, axis=axis)
107-
mask = wsum <= np.finfo(float).eps
108-
if axis == 1:
109-
weights[mask, :] = 1
110-
else:
111-
weights[:, mask] = 1
112-
113-
return np.average(X, weights=weights, axis=axis)
114-
115-
11646
def kmeans_compress(X, k=50):
11747
km = kmeans.KMeans(n_clusters=k, n_init=5, random_state=42)
11848
return km.get_model(X)
11949

12050

121-
def candidate_split_labels(data):
122-
"""
123-
Return candidate labels on which we can split the data.
124-
"""
125-
groups = defaultdict(list)
126-
for attr in data.domain.attributes:
127-
for item in attr.attributes.items():
128-
groups[item].append(attr)
129-
130-
by_keys = defaultdict(list)
131-
for (key, _), attrs in groups.items():
132-
by_keys[key].append(attrs)
133-
134-
# Find the keys for which all values have the same number
135-
# of attributes.
136-
candidates = []
137-
for key, groups in by_keys.items():
138-
count = len(groups[0])
139-
if all(len(attrs) == count for attrs in groups) and \
140-
len(groups) > 1 and count > 1:
141-
candidates.append(key)
142-
143-
return candidates
144-
145-
14651
def leaf_indices(tree):
14752
return [leaf.value.index for leaf in hierarchical.leaves(tree)]
14853

@@ -384,9 +289,23 @@ def cluster_ord(self):
384289
[name for name, _, in _color_palettes].index("Blue-Yellow")
385290

386291

292+
def cbselect(cb: QComboBox, value, role: Qt.ItemDataRole = Qt.EditRole) -> None:
293+
"""
294+
Find and select the `value` in the `cb` QComboBox.
295+
296+
Parameters
297+
----------
298+
cb: QComboBox
299+
value: Any
300+
role: Qt.ItemDataRole
301+
The data role in the combo box model to match value against
302+
"""
303+
cb.setCurrentIndex(cb.findData(value, role))
304+
305+
387306
class OWHeatMap(widget.OWWidget):
388307
name = "Heat Map"
389-
description = "Plot a heat map for a pair of attributes."
308+
description = "Plot a data matrix heatmap."
390309
icon = "icons/Heatmap.svg"
391310
priority = 260
392311
keywords = []
@@ -418,6 +337,8 @@ class Outputs:
418337
legend = settings.Setting(True)
419338
# Annotations
420339
annotation_var = settings.ContextSetting(None)
340+
split_by_var = settings.ContextSetting(None)
341+
421342
# Stored color palette settings
422343
color_settings = settings.Setting(None)
423344
user_palettes = settings.Setting([])
@@ -542,6 +463,33 @@ def __init__(self):
542463
cluster_box, self, "row_clustering", "Rows",
543464
callback=self.update_clustering_examples)
544465

466+
box = gui.vBox(self.controlArea, "Split By")
467+
468+
self.row_split_model = DomainModel(
469+
placeholder="(None)",
470+
valid_types=(Orange.data.DiscreteVariable,),
471+
parent=self,
472+
)
473+
self.row_split_cb = cb = ComboBox(
474+
enabled=not self.merge_kmeans,
475+
sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
476+
minimumContentsLength=14,
477+
toolTip="Split the heatmap vertically by a categorical column"
478+
)
479+
self.row_split_cb.setModel(self.row_split_model)
480+
self.connect_control(
481+
"split_by_var", lambda value, cb=cb: cbselect(cb, value)
482+
)
483+
self.connect_control(
484+
"merge_kmeans", self.row_split_cb.setDisabled
485+
)
486+
self.split_by_var = None
487+
488+
self.row_split_cb.activated.connect(
489+
self.__on_split_rows_activated
490+
)
491+
box.layout().addWidget(self.row_split_cb)
492+
545493
box = gui.vBox(self.controlArea, 'Annotation && Legends')
546494

547495
gui.checkBox(box, self, 'legend', 'Show legend',
@@ -626,6 +574,8 @@ def clear(self):
626574
self.merge_indices = None
627575
self.annotation_model.set_domain(None)
628576
self.annotation_var = None
577+
self.row_split_model.set_domain(None)
578+
self.split_by_var = None
629579
self.clear_scene()
630580
self.selected_rows = []
631581
self.__columns_cache.clear()
@@ -705,7 +655,14 @@ def set_dataset(self, data=None):
705655
if data is not None:
706656
self.annotation_model.set_domain(self.input_data.domain)
707657
self.annotation_var = None
658+
self.row_split_model.set_domain(data.domain)
659+
if data.domain.has_discrete_class:
660+
self.split_by_var = data.domain.class_var
661+
else:
662+
self.split_by_var = None
708663
self.openContext(self.input_data)
664+
if self.split_by_var not in self.row_split_model:
665+
self.split_by_var = None
709666

710667
self.update_heatmaps()
711668
if data is not None and self.__pending_selection is not None:
@@ -715,6 +672,14 @@ def set_dataset(self, data=None):
715672

716673
self.unconditional_commit()
717674

675+
def __on_split_rows_activated(self):
676+
self.set_split_variable(self.row_split_cb.currentData(Qt.EditRole))
677+
678+
def set_split_variable(self, var):
679+
if var != self.split_by_var:
680+
self.split_by_var = var
681+
self.update_heatmaps()
682+
718683
def update_heatmaps(self):
719684
if self.data is not None:
720685
self.clear_scene()
@@ -727,7 +692,9 @@ def update_heatmaps(self):
727692
elif self.merge_kmeans and len(self.data) < 3:
728693
self.Error.not_enough_instances_k_means()
729694
else:
730-
self.construct_heatmaps(self.data)
695+
self.heatmapparts = self.construct_heatmaps(
696+
self.data, self.split_by_var
697+
)
731698
self.construct_heatmaps_scene(
732699
self.heatmapparts, self.effective_data)
733700
self.selected_rows = []
@@ -741,7 +708,7 @@ def update_merge(self):
741708
self.update_heatmaps()
742709
self.commit()
743710

744-
def _make_parts(self, data, group_var=None, group_key=None):
711+
def _make_parts(self, data, group_var=None):
745712
"""
746713
Make initial `Parts` for data, split by group_var, group_key
747714
"""
@@ -758,20 +725,11 @@ def _make_parts(self, data, group_var=None, group_key=None):
758725
sortindices=None,
759726
cluster=None, cluster_ordered=None)]
760727

761-
if group_key is not None:
762-
col_groups = split_domain(data.domain, group_key)
763-
assert len(col_groups) > 0
764-
col_indices = [np.array([data.domain.index(var) for var in group])
765-
for _, group in col_groups]
766-
col_groups = [ColumnPart(title=name, domain=d, indices=ind,
767-
cluster=None, cluster_ordered=None)
768-
for (name, d), ind in zip(col_groups, col_indices)]
769-
else:
770-
col_groups = [
771-
ColumnPart(
772-
title=None, indices=slice(0, len(data.domain.attributes)),
773-
domain=data.domain, cluster=None, cluster_ordered=None)
774-
]
728+
col_groups = [
729+
ColumnPart(
730+
title=None, indices=slice(0, len(data.domain.attributes)),
731+
domain=data.domain, cluster=None, cluster_ordered=None)
732+
]
775733

776734
minv, maxv = np.nanmin(data.X), np.nanmax(data.X)
777735
return Parts(row_groups, col_groups, span=(minv, maxv))
@@ -806,11 +764,10 @@ def cluster_rows(self, data, parts):
806764

807765
row_groups.append(row._replace(cluster=cluster, cluster_ordered=cluster_ord))
808766

809-
return parts._replace(columns=parts.columns, rows=row_groups)
767+
return parts._replace(rows=row_groups)
810768

811769
def cluster_columns(self, data, parts):
812-
if len(parts.columns) > 1:
813-
data = vstack_by_subdomain(data, [col.domain for col in parts.columns])
770+
assert len(parts.columns) == 1, "columns split is no longer supported"
814771
assert all(var.is_continuous for var in data.domain.attributes)
815772

816773
col0 = parts.columns[0]
@@ -839,21 +796,9 @@ def cluster_columns(self, data, parts):
839796

840797
col_groups = [col._replace(cluster=cluster, cluster_ordered=cluster_ord)
841798
for col in parts.columns]
842-
return parts._replace(columns=col_groups, rows=parts.rows)
799+
return parts._replace(columns=col_groups)
843800

844-
def construct_heatmaps(self, data, split_label=None):
845-
if split_label is not None:
846-
groups = split_domain(data.domain, split_label)
847-
assert len(groups) > 0
848-
else:
849-
groups = [("", data.domain)]
850-
851-
if data.domain.has_discrete_class:
852-
group_var = data.domain.class_var
853-
else:
854-
group_var = None
855-
856-
group_label = split_label
801+
def construct_heatmaps(self, data, group_var=None) -> 'Parts':
857802
if self.merge_kmeans:
858803
if self.kmeans_model is None:
859804
effective_data = self.input_data.transform(
@@ -890,18 +835,14 @@ def construct_heatmaps(self, data, split_label=None):
890835

891836
self.__update_clustering_enable_state(effective_data)
892837

893-
parts = self._make_parts(effective_data, group_var, group_label)
838+
parts = self._make_parts(effective_data, group_var)
894839
# Restore/update the row/columns items descriptions from cache if
895840
# available
896841
rows_cache_key = (group_var,
897842
self.merge_kmeans_k if self.merge_kmeans else None)
898843
if rows_cache_key in self.__rows_cache:
899844
parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows)
900845

901-
if group_label in self.__columns_cache:
902-
parts = parts._replace(
903-
columns=self.__columns_cache[group_label].columns)
904-
905846
if self.row_clustering:
906847
assert len(effective_data) <= OWHeatMap._MaxOrderedClustering
907848
parts = self.cluster_rows(effective_data, parts)
@@ -913,9 +854,7 @@ def construct_heatmaps(self, data, split_label=None):
913854

914855
# Cache the updated parts
915856
self.__rows_cache[rows_cache_key] = parts
916-
self.__columns_cache[group_label] = parts
917-
918-
self.heatmapparts = parts
857+
return parts
919858

920859
def construct_heatmaps_scene(self, parts, data):
921860
def select_row(item):
@@ -1521,8 +1460,10 @@ def send_report(self):
15211460
self.report_items((
15221461
("Columns:", "Clustering" if self.col_clustering else "No sorting"),
15231462
("Rows:", "Clustering" if self.row_clustering else "No sorting"),
1463+
("Split:",
1464+
self.split_by_var is not None and self.split_by_var.name),
15241465
("Row annotation",
1525-
self.annotation_var is not None and self.annotation_var.name)
1466+
self.annotation_var is not None and self.annotation_var.name),
15261467
))
15271468
self.report_plot()
15281469

Orange/widgets/visualize/tests/test_owheatmap.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,17 @@ def test_saved_selection(self):
198198
self.send_signal(w.Inputs.data, iris, widget=w)
199199
self.assertEqual(len(self.get_output(w.Outputs.selected_data)), 21)
200200

201+
def test_set_split_var(self):
202+
data = Table("brown-selected")
203+
w = self.widget
204+
self.send_signal(self.widget.Inputs.data, data, widget=w)
205+
self.assertIs(w.split_by_var, data.domain.class_var)
206+
self.assertEqual(len(w.heatmapparts.rows),
207+
len(data.domain.class_var.values))
208+
w.set_split_variable(None)
209+
self.assertIs(w.split_by_var, None)
210+
self.assertEqual(len(w.heatmapparts.rows), 1)
211+
201212

202213
if __name__ == "__main__":
203214
unittest.main()

0 commit comments

Comments
 (0)