Skip to content

Commit 7f77934

Browse files
committed
owheatmap: Add column color annotations
1 parent acf1577 commit 7f77934

File tree

2 files changed

+136
-29
lines changed

2 files changed

+136
-29
lines changed

Orange/widgets/visualize/owheatmap.py

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from AnyQt.QtCore import Qt, QSize, QRectF, QObject
1818

1919
from orangewidget.utils.combobox import ComboBox, ComboBoxSearch
20-
from Orange.data import Domain, Table, Variable, DiscreteVariable
20+
from Orange.data import Domain, Table, Variable, DiscreteVariable, \
21+
ContinuousVariable
2122
from Orange.data.sql.table import SqlTable
2223
import Orange.distance
2324

@@ -26,7 +27,7 @@
2627
from Orange.widgets.utils.itemmodels import DomainModel
2728
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
2829
from Orange.widgets.utils.graphicsview import GraphicsWidgetView
29-
from Orange.widgets.utils.colorpalettes import DiscretePalette, Palette
30+
from Orange.widgets.utils.colorpalettes import Palette
3031

3132
from Orange.widgets.utils.annotated_data import (create_annotated_table,
3233
ANNOTATED_DATA_SIGNAL_NAME)
@@ -412,13 +413,36 @@ def _(idx, cb=cb):
412413
form.addRow("Text", self.annotation_text_cb)
413414
form.addRow("Color", self.row_side_color_cb)
414415
box.layout().addWidget(annotbox)
415-
posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
416-
posbox.setFlat(True)
416+
annotbox = QGroupBox("Column annotations", flat=True)
417+
form = QFormLayout(
418+
annotbox,
419+
formAlignment=Qt.AlignLeft,
420+
labelAlignment=Qt.AlignLeft,
421+
fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
422+
)
423+
self.col_side_color_model = DomainModel(
424+
placeholder="(None)",
425+
valid_types=(DiscreteVariable, ContinuousVariable),
426+
parent=self
427+
)
428+
self.col_side_color_cb = cb = ComboBoxSearch(
429+
sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
430+
minimumContentsLength=12
431+
)
432+
self.col_side_color_cb.setModel(self.col_side_color_model)
433+
self.column_annotation_color_var = None
434+
self.col_side_color_cb.activated.connect(self.__set_column_annotation_color_key_index)
435+
# posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
436+
# posbox.setFlat(True)
417437
cb = gui.comboBox(
418-
posbox, self, "column_label_pos",
438+
None, self, "column_label_pos",
419439
callback=self.update_column_annotations)
420440
cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
421441
cb.setCurrentIndex(self.column_label_pos)
442+
form.addRow("Color", self.col_side_color_cb)
443+
form.addRow("Label position", cb)
444+
box.layout().addWidget(annotbox)
445+
422446
gui.checkBox(self.controlArea, self, "keep_aspect",
423447
"Keep aspect ratio", box="Resize",
424448
callback=self.__aspect_mode_changed)
@@ -596,7 +620,7 @@ def set_dataset(self, data=None):
596620
self.row_split_model.set_domain(data.domain)
597621
self.col_annot_data = data.transpose(data[:0].transform(Domain(data.domain.attributes)))
598622
self.col_split_model.set_domain(self.col_annot_data.domain)
599-
623+
self.col_side_color_model.set_domain(self.col_annot_data.domain)
600624
if data.domain.has_discrete_class:
601625
self.split_by_var = data.domain.class_var
602626
else:
@@ -633,7 +657,7 @@ def set_split_variable(self, var):
633657
self.update_heatmaps()
634658

635659
def __on_split_cols_activated(self):
636-
self.set_column_split_key(self.col_split_cb.currentData(Qt.UserRole))
660+
self.set_column_split_key(self.col_split_cb.currentData(Qt.EditRole))
637661

638662
def set_column_split_key(self, key):
639663
if key != self.split_columns_key:
@@ -812,7 +836,9 @@ def construct_heatmaps(self, data, group_var=None, column_split_key=None) -> 'Pa
812836

813837
self.__update_clustering_enable_state(effective_data)
814838

815-
parts = self._make_parts(effective_data, group_var, column_split_key)
839+
parts = self._make_parts(
840+
effective_data, group_var,
841+
column_split_key.name if column_split_key is not None else None)
816842
# Restore/update the row/columns items descriptions from cache if
817843
# available
818844
rows_cache_key = (group_var,
@@ -882,9 +908,15 @@ def setup_scene(self, parts, data):
882908
col_names=columns,
883909
)
884910
widget.setHeatmaps(parts)
911+
885912
side = self.row_side_colors()
886913
if side is not None:
887914
widget.setRowSideColorAnnotations(side[0], side[1], name=side[2].name)
915+
916+
side = self.column_side_colors()
917+
if side is not None:
918+
widget.setColumnSideColorAnnotations(side[0], side[1], name=side[2].name)
919+
888920
widget.setColumnLabelsPosition(self._column_label_pos)
889921
widget.setAspectRatioMode(
890922
Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio
@@ -1065,7 +1097,7 @@ def row_side_colors(self):
10651097
merges = self._merge_row_indices()
10661098
if merges is not None:
10671099
column_data = aggregate(var, column_data, merges)
1068-
data, colormap = self._colorize(var, column_data)
1100+
data, colormap = colorize(var, column_data)
10691101
if var.is_continuous:
10701102
span = (np.nanmin(column_data), np.nanmax(column_data))
10711103
if np.any(np.isnan(span)):
@@ -1091,27 +1123,27 @@ def update_row_side_colors(self):
10911123
else:
10921124
widget.setRowSideColorAnnotations(colors[0], colors[1], colors[2].name)
10931125

1094-
def _colorize(self, var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
1095-
palette = var.palette # type: Palette
1096-
colors = np.array(
1097-
[[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
1098-
dtype=np.uint8,
1099-
)
1100-
if var.is_discrete:
1101-
mask = np.isnan(data)
1102-
data[mask] = -1
1103-
data = data.astype(int)
1104-
if mask.any():
1105-
values = (*var.values, "N/A")
1126+
def __set_column_annotation_color_key_index(self, index: int):
1127+
key = self.col_side_color_cb.itemData(index, Qt.EditRole)
1128+
self.set_column_annotation_color_key(key)
1129+
1130+
def set_column_annotation_color_key(self, key):
1131+
if self.col_side_color_model != key:
1132+
self.column_annotation_color_var = key
1133+
colors = self.column_side_colors()
1134+
if colors is not None:
1135+
self.scene.widget.setColumnSideColorAnnotations(
1136+
colors[0], colors[1], colors[2].name,
1137+
)
11061138
else:
1107-
values = var.values
1108-
colors = colors[: -1]
1109-
return data, CategoricalColorMap(colors, values)
1110-
elif var.is_continuous:
1111-
cmap = GradientColorMap(colors[:-1])
1112-
return data, cmap
1113-
else:
1114-
raise TypeError
1139+
self.scene.widget.setColumnSideColorAnnotations(None)
1140+
1141+
def column_side_colors(self):
1142+
var = self.column_annotation_color_var
1143+
if var is None:
1144+
return None
1145+
table = self.col_annot_data
1146+
return color_annotation_data(table, var)
11151147

11161148
def update_column_annotations(self):
11171149
widget = self.scene.widget
@@ -1320,6 +1352,40 @@ def column_data_from_table(
13201352
return data
13211353

13221354

1355+
def color_annotation_data(
1356+
table: Table, var: Union[int, str, Variable]
1357+
) -> Tuple[np.ndarray, ColorMap, Variable]:
1358+
var = table.domain[var]
1359+
column_data = column_data_from_table(table, var)
1360+
data, colormap = colorize(var, column_data)
1361+
return data, colormap, var
1362+
1363+
1364+
def colorize(var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
1365+
palette = var.palette # type: Palette
1366+
colors = np.array(
1367+
[[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
1368+
dtype=np.uint8,
1369+
)
1370+
if var.is_discrete:
1371+
mask = np.isnan(data)
1372+
data = data.astype(int)
1373+
data[mask] = -1
1374+
if mask.any():
1375+
values = (*var.values, "N/A")
1376+
else:
1377+
values = var.values
1378+
colors = colors[: -1]
1379+
return data, CategoricalColorMap(colors, values)
1380+
elif var.is_continuous:
1381+
span = np.nanmin(data), np.nanmax(data)
1382+
if np.any(np.isnan(span)):
1383+
span = 0, 1.
1384+
return data, GradientColorMap(colors[:-1], span=span)
1385+
else:
1386+
raise TypeError
1387+
1388+
13231389
def aggregate(
13241390
var: Variable, data: np.ndarray, groupindices: Sequence[Sequence[int]],
13251391
) -> np.ndarray:

Orange/widgets/visualize/tests/test_owheatmap.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def test_row_color_annotations_with_na(self):
290290
data = data.transform(
291291
Domain(data.domain.attributes[:10], data.domain.class_vars,
292292
data.domain.metas + (data.domain["diau g"], )))
293+
data.ensure_copy()
293294
data.Y[:3] = np.nan
294295
data.metas[:3, -1] = np.nan
295296
self.send_signal(widget.Inputs.data, data, widget=widget)
@@ -304,6 +305,46 @@ def test_row_color_annotations_with_na(self):
304305
widget.set_annotation_color_var(None)
305306
self.assertFalse(widget.scene.widget.right_side_colors[0].isVisible())
306307

308+
def test_col_color_annotations(self):
309+
widget = self.widget
310+
data = self.brown_selected[::5]
311+
data = data.transform(
312+
Domain(data.domain.attributes[:10], data.domain.class_vars,
313+
data.domain.metas + (data.domain["diau g"],)))
314+
data.ensure_copy()
315+
data_t = data.transpose(data)
316+
self.send_signal(widget.Inputs.data, data_t, widget=widget)
317+
# discrete
318+
widget.set_column_annotation_color_key(data.domain["function"])
319+
self.assertTrue(widget.scene.widget.top_side_colors[0].isVisible())
320+
# continuous
321+
widget.set_column_annotation_color_key(data.domain["diau g"])
322+
widget.set_column_annotation_color_key(None)
323+
self.assertFalse(widget.scene.widget.top_side_colors[0].isVisible())
324+
325+
def test_col_color_annotations_with_na(self):
326+
widget = self.widget
327+
data = self.brown_selected[::5]
328+
data = data.transform(
329+
Domain(data.domain.attributes[:10], data.domain.class_vars,
330+
data.domain.metas + (data.domain["diau g"], )))
331+
data.ensure_copy()
332+
data.Y[:3] = np.nan
333+
data.metas[:3, -1] = np.nan
334+
data_t = data.transpose(data)
335+
self.send_signal(widget.Inputs.data, data_t, widget=widget)
336+
widget.set_column_annotation_color_key(data.domain["function"])
337+
self.assertTrue(widget.scene.widget.top_side_colors[0].isVisible())
338+
widget.set_column_annotation_color_key(data.domain["diau g"])
339+
data.Y[:] = np.nan
340+
data.metas[:, -1] = np.nan
341+
data_t = data.transpose(data)
342+
self.send_signal(widget.Inputs.data, data_t, widget=widget)
343+
widget.set_column_annotation_color_key(data.domain["function"])
344+
widget.set_column_annotation_color_key(data.domain["diau g"])
345+
widget.set_column_annotation_color_key(None)
346+
self.assertFalse(widget.scene.widget.top_side_colors[0].isVisible())
347+
307348
def test_summary(self):
308349
"""Check if status bar is updated when data is received"""
309350
info = self.widget.info

0 commit comments

Comments
 (0)