Skip to content

Commit 0ac5b15

Browse files
committed
owheatmap: Add column color annotations
1 parent a5af7de commit 0ac5b15

File tree

1 file changed

+110
-28
lines changed

1 file changed

+110
-28
lines changed

Orange/widgets/visualize/owheatmap.py

Lines changed: 110 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from AnyQt.QtCore import Qt, QSize, QRectF, QObject
1818

1919
from orangewidget.utils.combobox import ComboBox, ComboBoxSearch
20-
from Orange.data import Domain, Table, Variable
20+
from Orange.data import Domain, Table, Variable, DiscreteVariable, \
21+
ContinuousVariable
2122
from Orange.data.sql.table import SqlTable
2223
import Orange.distance
2324

@@ -183,6 +184,8 @@ class Outputs:
183184
annotation_var = settings.ContextSetting(None)
184185
#: color row annotation
185186
annotation_color_var = settings.ContextSetting(None)
187+
column_annotation_color_key = settings.ContextSetting(None)
188+
186189
# Discrete variable used to split that data/heatmaps (vertically)
187190
split_by_var = settings.ContextSetting(None)
188191
# Split heatmap columns by 'key' (horizontal)
@@ -408,13 +411,39 @@ def _(idx, cb=cb):
408411
form.addRow("Text", self.annotation_text_cb)
409412
form.addRow("Color", self.row_side_color_cb)
410413
box.layout().addWidget(annotbox)
411-
posbox = gui.vBox(box, "Column Labels Position", addSpace=False)
412-
posbox.setFlat(True)
414+
annotbox = QGroupBox("Column annotations", flat=True)
415+
form = QFormLayout(
416+
annotbox,
417+
formAlignment=Qt.AlignLeft,
418+
labelAlignment=Qt.AlignLeft,
419+
fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow
420+
)
421+
self.col_side_color_model = DomainModel(
422+
placeholder="(None)",
423+
valid_types=(DiscreteVariable, ContinuousVariable),
424+
parent=self
425+
)
426+
self.col_side_color_cb = cb = ComboBoxSearch(
427+
sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength,
428+
minimumContentsLength=12
429+
)
430+
self.col_side_color_cb.setModel(self.col_side_color_model)
431+
self.connect_control(
432+
"column_annotation_color_key", self.column_annotation_color_key_changed,
433+
)
434+
self.column_annotation_color_key = None
435+
self.col_side_color_cb.activated.connect(
436+
self.__set_column_annotation_color_key_index)
437+
413438
cb = gui.comboBox(
414-
posbox, self, "column_label_pos",
439+
None, self, "column_label_pos",
415440
callback=self.update_column_annotations)
416441
cb.setModel(create_list_model(ColumnLabelsPosData, parent=self))
417442
cb.setCurrentIndex(self.column_label_pos)
443+
form.addRow("Color", self.col_side_color_cb)
444+
form.addRow("Label position", cb)
445+
box.layout().addWidget(annotbox)
446+
418447
gui.checkBox(self.controlArea, self, "keep_aspect",
419448
"Keep aspect ratio", box="Resize",
420449
callback=self.__aspect_mode_changed)
@@ -503,7 +532,9 @@ def clear(self):
503532
self.annotation_model.set_domain(None)
504533
self.annotation_var = None
505534
self.row_side_color_model.set_domain(None)
535+
self.col_side_color_model.set_domain(None)
506536
self.annotation_color_var = None
537+
self.column_annotation_color_key = None
507538
self.row_split_model.set_domain(None)
508539
self.col_split_model.set_domain(None)
509540
self.split_by_var = None
@@ -593,12 +624,13 @@ def set_dataset(self, data=None):
593624
self.row_split_model.set_domain(data.domain)
594625
self.col_annot_data = data.transpose(data[:0].transform(Domain(data.domain.attributes)))
595626
self.col_split_model.set_domain(self.col_annot_data.domain)
596-
627+
self.col_side_color_model.set_domain(self.col_annot_data.domain)
597628
if data.domain.has_discrete_class:
598629
self.split_by_var = data.domain.class_var
599630
else:
600631
self.split_by_var = None
601632
self.split_columns_key = None
633+
self.column_annotation_color_key = None
602634
self.openContext(self.input_data)
603635
if self.split_by_var not in self.row_split_model:
604636
self.split_by_var = None
@@ -607,6 +639,10 @@ def set_dataset(self, data=None):
607639
if idx == -1:
608640
self.split_columns_key = None
609641

642+
idx = self.col_side_color_cb.findData(self.column_annotation_color_key, Qt.EditRole)
643+
if idx == -1:
644+
self.column_annotation_color_key = None
645+
610646
self.update_heatmaps()
611647
if data is not None and self.__pending_selection is not None:
612648
assert self.scene.widget is not None
@@ -630,7 +666,7 @@ def set_split_variable(self, var):
630666
self.update_heatmaps()
631667

632668
def __on_split_cols_activated(self):
633-
self.set_column_split_key(self.col_split_cb.currentData(Qt.UserRole))
669+
self.set_column_split_key(self.col_split_cb.currentData(Qt.EditRole))
634670

635671
def set_column_split_key(self, key):
636672
if key != self.split_columns_key:
@@ -809,7 +845,9 @@ def construct_heatmaps(self, data, group_var=None, column_split_key=None) -> 'Pa
809845

810846
self.__update_clustering_enable_state(effective_data)
811847

812-
parts = self._make_parts(effective_data, group_var, column_split_key)
848+
parts = self._make_parts(
849+
effective_data, group_var,
850+
column_split_key.name if column_split_key is not None else None)
813851
# Restore/update the row/columns items descriptions from cache if
814852
# available
815853
rows_cache_key = (group_var,
@@ -879,9 +917,15 @@ def setup_scene(self, parts, data):
879917
col_names=columns,
880918
)
881919
widget.setHeatmaps(parts)
920+
882921
side = self.row_side_colors()
883922
if side is not None:
884923
widget.setRowSideColorAnnotations(side[0], side[1], name=side[2].name)
924+
925+
side = self.column_side_colors()
926+
if side is not None:
927+
widget.setColumnSideColorAnnotations(side[0], side[1], name=side[2].name)
928+
885929
widget.setColumnLabelsPosition(self._column_label_pos)
886930
widget.setAspectRatioMode(
887931
Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio
@@ -1050,7 +1094,7 @@ def row_side_colors(self):
10501094
merges = self._merge_row_indices()
10511095
if merges is not None:
10521096
column_data = aggregate(var, column_data, merges)
1053-
data, colormap = self._colorize(var, column_data)
1097+
data, colormap = colorize(var, column_data)
10541098
if var.is_continuous:
10551099
span = (np.nanmin(column_data), np.nanmax(column_data))
10561100
if np.any(np.isnan(span)):
@@ -1076,27 +1120,31 @@ def update_row_side_colors(self):
10761120
else:
10771121
widget.setRowSideColorAnnotations(colors[0], colors[1], colors[2].name)
10781122

1079-
def _colorize(self, var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
1080-
palette = var.palette # type: Palette
1081-
colors = np.array(
1082-
[[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
1083-
dtype=np.uint8,
1084-
)
1085-
if var.is_discrete:
1086-
mask = np.isnan(data)
1087-
data[mask] = -1
1088-
data = data.astype(int)
1089-
if mask.any():
1090-
values = (*var.values, "N/A")
1123+
def __set_column_annotation_color_key_index(self, index: int):
1124+
key = self.col_side_color_cb.itemData(index, Qt.EditRole)
1125+
self.set_column_annotation_color_key(key)
1126+
1127+
def column_annotation_color_key_changed(self, value):
1128+
cbselect(self.col_side_color_cb, value, Qt.EditRole)
1129+
1130+
def set_column_annotation_color_key(self, key):
1131+
if self.column_annotation_color_key != key:
1132+
self.column_annotation_color_key = key
1133+
cbselect(self.col_side_color_cb, key, Qt.EditRole)
1134+
colors = self.column_side_colors()
1135+
if colors is not None:
1136+
self.scene.widget.setColumnSideColorAnnotations(
1137+
colors[0], colors[1], colors[2].name,
1138+
)
10911139
else:
1092-
values = var.values
1093-
colors = colors[: -1]
1094-
return data, CategoricalColorMap(colors, values)
1095-
elif var.is_continuous:
1096-
cmap = GradientColorMap(colors[:-1])
1097-
return data, cmap
1098-
else:
1099-
raise TypeError
1140+
self.scene.widget.setColumnSideColorAnnotations(None)
1141+
1142+
def column_side_colors(self):
1143+
var = self.column_annotation_color_key
1144+
if var is None:
1145+
return None
1146+
table = self.col_annot_data
1147+
return color_annotation_data(table, var)
11001148

11011149
def update_column_annotations(self):
11021150
widget = self.scene.widget
@@ -1305,6 +1353,40 @@ def column_data_from_table(
13051353
return data
13061354

13071355

1356+
def color_annotation_data(
1357+
table: Table, var: Union[int, str, Variable]
1358+
) -> Tuple[np.ndarray, ColorMap, Variable]:
1359+
var = table.domain[var]
1360+
column_data = column_data_from_table(table, var)
1361+
data, colormap = colorize(var, column_data)
1362+
return data, colormap, var
1363+
1364+
1365+
def colorize(var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]:
1366+
palette = var.palette # type: Palette
1367+
colors = np.array(
1368+
[[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan],
1369+
dtype=np.uint8,
1370+
)
1371+
if var.is_discrete:
1372+
mask = np.isnan(data)
1373+
data = data.astype(int)
1374+
data[mask] = -1
1375+
if mask.any():
1376+
values = (*var.values, "N/A")
1377+
else:
1378+
values = var.values
1379+
colors = colors[: -1]
1380+
return data, CategoricalColorMap(colors, values)
1381+
elif var.is_continuous:
1382+
span = np.nanmin(data), np.nanmax(data)
1383+
if np.any(np.isnan(span)):
1384+
span = 0, 1.
1385+
return data, GradientColorMap(colors[:-1], span=span)
1386+
else:
1387+
raise TypeError
1388+
1389+
13081390
def aggregate(
13091391
var: Variable, data: np.ndarray, groupindices: Sequence[Sequence[int]],
13101392
) -> np.ndarray:

0 commit comments

Comments
 (0)