diff --git a/Orange/data/table.py b/Orange/data/table.py index 7716d187c21..47f7668e530 100644 --- a/Orange/data/table.py +++ b/Orange/data/table.py @@ -1607,7 +1607,7 @@ def transpose(cls, table, feature_names_column="", names = get_unique_names_duplicates(names) attributes = [ContinuousVariable(name) for name in names] else: - places = int(np.ceil(np.log10(n_cols))) + places = int(np.ceil(np.log10(n_cols))) if n_cols else 1 attributes = [ContinuousVariable(f"{feature_name} {i:0{places}}") for i in range(1, n_cols + 1)] if old_domain is not None and feature_names_column: diff --git a/Orange/widgets/utils/graphicsflowlayout.py b/Orange/widgets/utils/graphicsflowlayout.py new file mode 100644 index 00000000000..618f1c70584 --- /dev/null +++ b/Orange/widgets/utils/graphicsflowlayout.py @@ -0,0 +1,191 @@ +from functools import reduce +from types import SimpleNamespace +from typing import Optional, List, Iterable, Tuple + +import numpy as np + +from AnyQt.QtCore import QRectF, QSizeF, Qt, QPointF, QMarginsF +from AnyQt.QtWidgets import QGraphicsLayout, QGraphicsLayoutItem + +import sip + +FLT_MAX = np.finfo(np.float32).max + + +class _FlowLayoutItem(SimpleNamespace): + item: QGraphicsLayoutItem + geom: QRectF + size: QSizeF + row: int = 0 + alignment: Qt.Alignment = 0 + + +class GraphicsFlowLayout(QGraphicsLayout): + def __init__(self, parent: Optional[QGraphicsLayoutItem] = None): + self.__items: List[QGraphicsLayoutItem] = [] + self.__spacing: Tuple[float, float] = (1., 1.) + super().__init__(parent) + sp = self.sizePolicy() + sp.setHeightForWidth(True) + self.setSizePolicy(sp) + + def setVerticalSpacing(self, spacing: float) -> None: + new = (self.__spacing[0], spacing) + if new != self.__spacing: + self.__spacing = new + self.invalidate() + + def verticalSpacing(self) -> float: + return self.__spacing[1] + + def setHorizontalSpacing(self, spacing: float) -> None: + new = (spacing, self.__spacing[1]) + if new != self.__spacing: + self.__spacing = new + self.invalidate() + + def horizontalSpacing(self) -> float: + return self.__spacing[0] + + def setGeometry(self, rect: QRectF) -> None: + super().setGeometry(rect) + margins = QMarginsF(*self.getContentsMargins()) + rect = rect.marginsRemoved(margins) + for item, r in zip(self.__items, self.__doLayout(rect)): + item.setGeometry(r) + + def invalidate(self) -> None: + self.updateGeometry() + super().invalidate() + + def __doLayout(self, rect: QRectF) -> Iterable[QRectF]: + x = y = 0 + rowheight = 0 + width = rect.width() + spacing_x, spacing_y = self.__spacing + first_in_row = True + rows: List[List[QRectF]] = [[]] + + def break_(): + nonlocal x, y, rowheight, first_in_row + y += rowheight + spacing_y + x = 0 + rowheight = 0 + first_in_row = True + rows.append([]) + + items = [_FlowLayoutItem(item=item, geom=QRectF(), size=QSizeF()) + for item in self.__items] + + for flitem in items: + item = flitem.item + sh = item.effectiveSizeHint(Qt.PreferredSize) + if x + sh.width() > width and not first_in_row: + break_() + r = QRectF(rect.x() + x, rect.y() + y, sh.width(), sh.height()) + flitem.geom = r + flitem.size = sh + flitem.row = len(rows) - 1 + rowheight = max(rowheight, sh.height()) + x += sh.width() + spacing_x + first_in_row = False + rows[-1].append(flitem.geom) + + alignment = Qt.AlignVCenter | Qt.AlignLeft + for flitem in items: + row = rows[flitem.row] + row_rect = reduce(QRectF.united, row, QRectF()) + if row_rect.isEmpty(): + continue + flitem.geom = qrect_aligned_to( + flitem.geom, row_rect, alignment & Qt.AlignVertical_Mask) + return [fli.geom for fli in items] + + def sizeHint(self, which: Qt.SizeHint, constraint=QSizeF(-1, -1)) -> QSizeF: + left, top, right, bottom = self.getContentsMargins() + extra_margins = QSizeF(left + right, top + bottom) + if constraint.width() >= 0: + constraint.setWidth( + max(constraint.width() - extra_margins.width(), 0.0)) + + if which == Qt.PreferredSize: + if constraint.width() >= 0: + rect = QRectF(0, 0, constraint.width(), FLT_MAX) + else: + rect = QRectF(0, 0, FLT_MAX, FLT_MAX) + res = self.__doLayout(rect) + sh = reduce(QRectF.united, res, QRectF()).size() + return sh + extra_margins + if which == Qt.MinimumSize: + return reduce(QSizeF.expandedTo, + (item.minimumSize() for item in self.__items), + QSizeF()) + extra_margins + return QSizeF() + + def count(self) -> int: + return len(self.__items) + + def itemAt(self, i: int) -> QGraphicsLayoutItem: + try: + return self.__items[i] + except IndexError: + return None # type: ignore + + def removeAt(self, index: int) -> None: + try: + item = self.__items.pop(index) + except IndexError: + pass + else: + item.setParentLayoutItem(None) + self.invalidate() + + def removeItem(self, item: QGraphicsLayoutItem): + try: + self.__items.remove(item) + except ValueError: + pass + else: + item.setParentLayoutItem(None) + self.invalidate() + + def addItem(self, item: QGraphicsLayoutItem) -> None: + self.insertItem(self.count(), item) + + def insertItem(self, index: int, item: QGraphicsLayoutItem, ) -> None: + self.addChildLayoutItem(item) + if 0 <= index < self.count(): + self.__items.insert(index, item) + else: + self.__items.append(item) + self.updateGeometry() + self.invalidate() + + def __dtor__(self): + items = self.__items + self.__items = [] + for item in items: + item.setParentLayoutItem(None) + if item.ownedByLayout(): + sip.delete(item) + + +def qrect_aligned_to( + rect_a: QRectF, rect_b: QRectF, alignment: Qt.Alignment) -> QRectF: + res = QRectF(rect_a) + valign = alignment & Qt.AlignVertical_Mask + halign = alignment & Qt.AlignHorizontal_Mask + if valign == Qt.AlignTop: + res.moveTop(rect_b.top()) + if valign == Qt.AlignVCenter: + res.moveCenter(QPointF(res.center().x(), rect_b.center().y())) + if valign == Qt.AlignBottom: + res.moveBottom(rect_b.bottom()) + + if halign == Qt.AlignLeft: + res.moveLeft(rect_b.left()) + if halign == Qt.AlignHCenter: + res.moveCenter(QPointF(rect_b.center().x(), res.center().y())) + if halign == Qt.AlignRight: + res.moveRight(rect_b.right()) + return res diff --git a/Orange/widgets/utils/tests/test_graphicsflowlayout.py b/Orange/widgets/utils/tests/test_graphicsflowlayout.py new file mode 100644 index 00000000000..73e389985e1 --- /dev/null +++ b/Orange/widgets/utils/tests/test_graphicsflowlayout.py @@ -0,0 +1,46 @@ +from AnyQt.QtCore import Qt, QSizeF, QRectF +from AnyQt.QtWidgets import QGraphicsWidget + +from Orange.widgets.tests.base import GuiTest +from Orange.widgets.utils.graphicsflowlayout import GraphicsFlowLayout + + +class TestGraphicsFlowLayout(GuiTest): + def test_layout(self): + layout = GraphicsFlowLayout() + layout.setContentsMargins(1, 1, 1, 1) + layout.setHorizontalSpacing(3) + self.assertEqual(layout.horizontalSpacing(), 3) + layout.setVerticalSpacing(3) + self.assertEqual(layout.verticalSpacing(), 3) + + def widget(): + w = QGraphicsWidget() + w.setMinimumSize(QSizeF(10, 10)) + w.setMaximumSize(QSizeF(10, 10)) + return w + + layout.addItem(widget()) + layout.addItem(widget()) + layout.addItem(widget()) + self.assertEqual(layout.count(), 3) + sh = layout.effectiveSizeHint(Qt.PreferredSize) + self.assertEqual(sh, QSizeF(30 + 6 + 2, 12)) + sh = layout.effectiveSizeHint(Qt.PreferredSize, QSizeF(12, -1)) + self.assertEqual(sh, QSizeF(12, 30 + 6 + 2)) + layout.setGeometry(QRectF(0, 0, sh.width(), sh.height())) + w1 = layout.itemAt(0) + self.assertEqual(w1.geometry(), QRectF(1, 1, 10, 10)) + w3 = layout.itemAt(2) + self.assertEqual(w3.geometry(), QRectF(1, 1 + 2 * 10 + 2 * 3, 10, 10)) + + def test_add_remove(self): + layout = GraphicsFlowLayout() + layout.addItem(GraphicsFlowLayout()) + layout.removeAt(0) + self.assertEqual(layout.count(), 0) + layout.addItem(GraphicsFlowLayout()) + item = layout.itemAt(0) + self.assertIs(item.parentLayoutItem(), layout) + layout.removeItem(item) + self.assertIs(item.parentLayoutItem(), None) diff --git a/Orange/widgets/visualize/owheatmap.py b/Orange/widgets/visualize/owheatmap.py index 6ed6376a9b8..6a6cc4713df 100644 --- a/Orange/widgets/visualize/owheatmap.py +++ b/Orange/widgets/visualize/owheatmap.py @@ -1,4 +1,5 @@ import enum +from collections import defaultdict from itertools import islice from typing import ( Iterable, Mapping, Any, TypeVar, Type, NamedTuple, Sequence, Optional, @@ -16,7 +17,8 @@ from AnyQt.QtCore import Qt, QSize, QRectF, QObject from orangewidget.utils.combobox import ComboBox, ComboBoxSearch -from Orange.data import Domain, Table, Variable +from Orange.data import Domain, Table, Variable, DiscreteVariable, \ + ContinuousVariable from Orange.data.sql.table import SqlTable import Orange.distance @@ -48,6 +50,20 @@ def kmeans_compress(X, k=50): return km.get_model(X) +def split_domain(domain: Domain, split_label: str): + """Split the domain based on values of `split_label` value. + """ + groups = defaultdict(list) + for var in domain.attributes: + val = var.attributes.get(split_label) + groups[val].append(var) + if None in groups: + na = groups.pop(None) + return [*groups.items(), ("N/A", na)] + else: + return list(groups.items()) + + def cbselect(cb: QComboBox, value, role: Qt.ItemDataRole = Qt.EditRole) -> None: """ Find and select the `value` in the `cb` QComboBox. @@ -168,8 +184,12 @@ class Outputs: annotation_var = settings.ContextSetting(None) #: color row annotation annotation_color_var = settings.ContextSetting(None) + column_annotation_color_key: Optional[Tuple[str, str]] = settings.ContextSetting(None) + # Discrete variable used to split that data/heatmaps (vertically) split_by_var = settings.ContextSetting(None) + # Split heatmap columns by 'key' (horizontal) + split_columns_key: Optional[Tuple[str, str]] = settings.ContextSetting(None) # Selected row/column clustering method (name) col_clustering_method: str = settings.Setting(Clustering.None_.name) row_clustering_method: str = settings.Setting(Clustering.None_.name) @@ -212,11 +232,7 @@ def __init__(self): self.row_clustering = enum_get( Clustering, self.row_clustering_method, Clustering.None_) - @self.settingsAboutToBePacked.connect - def _(): - self.col_clustering_method = self.col_clustering.name - self.row_clustering_method = self.row_clustering.name - + self.settingsAboutToBePacked.connect(self._save_state_for_serialization) self.keep_aspect = False #: The original data with all features (retained to @@ -226,6 +242,8 @@ def _(): #: merged using k-means self.data = None self.effective_data = None + #: Source of column annotations (derived from self.data) + self.col_annot_data: Optional[Table] = None #: kmeans model used to merge rows of input_data self.kmeans_model = None #: merge indices derived from kmeans @@ -300,6 +318,11 @@ def _(idx, cb=cb): form.addRow("Columns:", self.col_cluster_cb) cluster_box.layout().addLayout(form) box = gui.vBox(self.controlArea, "Split By") + form = QFormLayout( + formAlignment=Qt.AlignLeft, labelAlignment=Qt.AlignLeft, + fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow, + ) + box.layout().addLayout(form) self.row_split_model = DomainModel( placeholder="(None)", @@ -324,7 +347,25 @@ def _(idx, cb=cb): self.row_split_cb.activated.connect( self.__on_split_rows_activated ) - box.layout().addWidget(self.row_split_cb) + self.col_split_model = DomainModel( + placeholder="(None)", + order=DomainModel.MIXED, + valid_types=(Orange.data.DiscreteVariable,), + parent=self, + ) + self.col_split_cb = cb = ComboBoxSearch( + sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLength, + minimumContentsLength=14, + toolTip="Split the heatmap horizontally by column annotation" + ) + self.col_split_cb.setModel(self.col_split_model) + self.connect_control( + "split_columns_var", lambda value, cb=cb: cbselect(cb, value) + ) + self.split_columns_var = None + self.col_split_cb.activated.connect(self.__on_split_cols_activated) + form.addRow("Rows:", self.row_split_cb) + form.addRow("Columns:", self.col_split_cb) box = gui.vBox(self.controlArea, 'Annotation && Legends') @@ -366,13 +407,39 @@ def _(idx, cb=cb): form.addRow("Text", self.annotation_text_cb) form.addRow("Color", self.row_side_color_cb) box.layout().addWidget(annotbox) - posbox = gui.vBox(box, "Column Labels Position", addSpace=False) - posbox.setFlat(True) + annotbox = QGroupBox("Column annotations", flat=True) + form = QFormLayout( + annotbox, + formAlignment=Qt.AlignLeft, + labelAlignment=Qt.AlignLeft, + fieldGrowthPolicy=QFormLayout.AllNonFixedFieldsGrow + ) + self.col_side_color_model = DomainModel( + placeholder="(None)", + valid_types=(DiscreteVariable, ContinuousVariable), + parent=self + ) + self.col_side_color_cb = cb = ComboBoxSearch( + sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLength, + minimumContentsLength=12 + ) + self.col_side_color_cb.setModel(self.col_side_color_model) + self.connect_control( + "column_annotation_color_var", self.column_annotation_color_var_changed, + ) + self.column_annotation_color_var = None + self.col_side_color_cb.activated.connect( + self.__set_column_annotation_color_var_index) + cb = gui.comboBox( - posbox, self, "column_label_pos", + None, self, "column_label_pos", callback=self.update_column_annotations) cb.setModel(create_list_model(ColumnLabelsPosData, parent=self)) cb.setCurrentIndex(self.column_label_pos) + form.addRow("Color", self.col_side_color_cb) + form.addRow("Label position", cb) + box.layout().addWidget(annotbox) + gui.checkBox(self.controlArea, self, "keep_aspect", "Keep aspect ratio", box="Resize", callback=self.__aspect_mode_changed) @@ -411,6 +478,19 @@ class HeatmapScene(QGraphicsScene): ) self.addActions([self.__font_inc, self.__font_dec]) + def _save_state_for_serialization(self): + def desc(var: Optional[Variable]) -> Optional[Tuple[str, str]]: + if var is not None: + return type(var).__name__, var.name + else: + return None + + self.col_clustering_method = self.col_clustering.name + self.row_clustering_method = self.row_clustering.name + + self.column_annotation_color_key = desc(self.column_annotation_color_var) + self.split_columns_key = desc(self.split_columns_var) + @property def center_palette(self): palette = self.color_map_widget.currentData() @@ -461,9 +541,13 @@ def clear(self): self.annotation_model.set_domain(None) self.annotation_var = None self.row_side_color_model.set_domain(None) + self.col_side_color_model.set_domain(None) self.annotation_color_var = None + self.column_annotation_color_var = None self.row_split_model.set_domain(None) + self.col_split_model.set_domain(None) self.split_by_var = None + self.split_columns_var = None self.parts = None self.clear_scene() self.selected_rows = [] @@ -547,14 +631,40 @@ def set_dataset(self, data=None): self.annotation_var = None self.annotation_color_var = None self.row_split_model.set_domain(data.domain) + self.col_annot_data = data.transpose(data[:0].transform(Domain(data.domain.attributes))) + self.col_split_model.set_domain(self.col_annot_data.domain) + self.col_side_color_model.set_domain(self.col_annot_data.domain) if data.domain.has_discrete_class: self.split_by_var = data.domain.class_var else: self.split_by_var = None + self.split_columns_var = None + self.column_annotation_color_var = None self.openContext(self.input_data) if self.split_by_var not in self.row_split_model: self.split_by_var = None + def match(desc: Tuple[str, str], source: Iterable[Variable]): + for v in source: + if desc == (type(v).__name__, v.name): + return v + return None + + def is_variable(obj): + return isinstance(obj, Variable) + + if self.split_columns_key is not None: + self.split_columns_var = match( + self.split_columns_key, + filter(is_variable, self.col_split_model) + ) + + if self.column_annotation_color_key is not None: + self.column_annotation_color_var = match( + self.column_annotation_color_key, + filter(is_variable, self.col_side_color_model) + ) + self.update_heatmaps() if data is not None and self.__pending_selection is not None: assert self.scene.widget is not None @@ -577,6 +687,14 @@ def set_split_variable(self, var): self.split_by_var = var self.update_heatmaps() + def __on_split_cols_activated(self): + self.set_column_split_var(self.col_split_cb.currentData(Qt.EditRole)) + + def set_column_split_var(self, var: Optional[Variable]): + if var is not self.split_columns_var: + self.split_columns_var = var + self.update_heatmaps() + def update_heatmaps(self): if self.data is not None: self.clear_scene() @@ -591,7 +709,7 @@ def update_heatmaps(self): elif self.merge_kmeans and len(self.data) < 3: self.Error.not_enough_instances_k_means() else: - parts = self.construct_heatmaps(self.data, self.split_by_var) + parts = self.construct_heatmaps(self.data, self.split_by_var, self.split_columns_var) self.construct_heatmaps_scene(parts, self.effective_data) self.selected_rows = [] else: @@ -604,7 +722,7 @@ def update_merge(self): self.update_heatmaps() self.commit() - def _make_parts(self, data, group_var=None): + def _make_parts(self, data, group_var=None, column_split_key=None): """ Make initial `Parts` for data, split by group_var, group_key """ @@ -626,11 +744,20 @@ def _make_parts(self, data, group_var=None): row_groups = [RowPart(title=None, indices=range(0, len(data)), cluster=None, cluster_ordered=None)] - col_groups = [ - ColumnPart( - title=None, indices=range(0, len(data.domain.attributes)), - domain=data.domain, cluster=None, cluster_ordered=None) - ] + if column_split_key is not None: + col_groups = split_domain(data.domain, column_split_key) + assert len(col_groups) > 0 + col_indices = [np.array([data.domain.index(var) for var in group]) + for _, group in col_groups] + col_groups = [ColumnPart(title=str(name), domain=d, indices=ind, + cluster=None, cluster_ordered=None) + for (name, d), ind in zip(col_groups, col_indices)] + else: + col_groups = [ + ColumnPart( + title=None, indices=range(0, len(data.domain.attributes)), + domain=data.domain.attributes, cluster=None, cluster_ordered=None) + ] minv, maxv = np.nanmin(data.X), np.nanmax(data.X) return Parts(row_groups, col_groups, span=(minv, maxv)) @@ -668,42 +795,42 @@ def cluster_rows(self, data: Table, parts: 'Parts', ordered=False) -> 'Parts': return parts._replace(rows=row_groups) - def cluster_columns(self, data, parts, ordered=False): - assert len(parts.columns) == 1, "columns split is no longer supported" + def cluster_columns(self, data, parts: 'Parts', ordered=False): assert all(var.is_continuous for var in data.domain.attributes) + col_groups = [] + for col in parts.columns: + if col.cluster is not None: + cluster = col.cluster + else: + cluster = None + if col.cluster_ordered is not None: + cluster_ord = col.cluster_ordered + else: + cluster_ord = None + if col.can_cluster: + need_dist = cluster is None or (ordered and cluster_ord is None) + matrix = None + if need_dist: + subset = data.transform(Domain(col.domain)) + subset = Orange.distance._preprocess(subset) + matrix = np.asarray(Orange.distance.PearsonR(subset, axis=0)) + # nan values break clustering below + matrix = np.nan_to_num(matrix) - col0 = parts.columns[0] - if col0.cluster is not None: - cluster = col0.cluster - else: - cluster = None - if col0.cluster_ordered is not None: - cluster_ord = col0.cluster_ordered - else: - cluster_ord = None - need_dist = cluster is None or (ordered and cluster_ord is None) - matrix = None - if need_dist: - data = Orange.distance._preprocess(data) - matrix = np.asarray(Orange.distance.PearsonR(data, axis=0)) - # nan values break clustering below - matrix = np.nan_to_num(matrix) - - if cluster is None: - assert matrix is not None - assert len(matrix) < self.MaxClustering - cluster = hierarchical.dist_matrix_clustering( - matrix, linkage=hierarchical.WARD - ) - if ordered and cluster_ord is None: - assert len(matrix) < self.MaxOrderedClustering - cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix) + if cluster is None: + assert matrix is not None + assert len(matrix) < self.MaxClustering + cluster = hierarchical.dist_matrix_clustering( + matrix, linkage=hierarchical.WARD + ) + if ordered and cluster_ord is None: + assert len(matrix) < self.MaxOrderedClustering + cluster_ord = hierarchical.optimal_leaf_ordering(cluster, matrix) - col_groups = [col._replace(cluster=cluster, cluster_ordered=cluster_ord) - for col in parts.columns] + col_groups.append(col._replace(cluster=cluster, cluster_ordered=cluster_ord)) return parts._replace(columns=col_groups) - def construct_heatmaps(self, data, group_var=None) -> 'Parts': + def construct_heatmaps(self, data, group_var=None, column_split_key=None) -> 'Parts': if self.merge_kmeans: if self.kmeans_model is None: effective_data = self.input_data.transform( @@ -740,7 +867,9 @@ def construct_heatmaps(self, data, group_var=None) -> 'Parts': self.__update_clustering_enable_state(effective_data) - parts = self._make_parts(effective_data, group_var) + parts = self._make_parts( + effective_data, group_var, + column_split_key.name if column_split_key is not None else None) # Restore/update the row/columns items descriptions from cache if # available rows_cache_key = (group_var, @@ -748,6 +877,10 @@ def construct_heatmaps(self, data, group_var=None) -> 'Parts': if rows_cache_key in self.__rows_cache: parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows) + if column_split_key in self.__columns_cache: + parts = parts._replace( + columns=self.__columns_cache[column_split_key].columns) + if self.row_clustering != Clustering.None_: parts = self.cluster_rows( effective_data, parts, @@ -806,9 +939,15 @@ def setup_scene(self, parts, data): col_names=columns, ) widget.setHeatmaps(parts) + side = self.row_side_colors() if side is not None: widget.setRowSideColorAnnotations(side[0], side[1], name=side[2].name) + + side = self.column_side_colors() + if side is not None: + widget.setColumnSideColorAnnotations(side[0], side[1], name=side[2].name) + widget.setColumnLabelsPosition(self._column_label_pos) widget.setAspectRatioMode( Qt.KeepAspectRatio if self.keep_aspect else Qt.IgnoreAspectRatio @@ -977,7 +1116,7 @@ def row_side_colors(self): merges = self._merge_row_indices() if merges is not None: column_data = aggregate(var, column_data, merges) - data, colormap = self._colorize(var, column_data) + data, colormap = colorize(var, column_data) if var.is_continuous: span = (np.nanmin(column_data), np.nanmax(column_data)) if np.any(np.isnan(span)): @@ -1003,27 +1142,30 @@ def update_row_side_colors(self): else: widget.setRowSideColorAnnotations(colors[0], colors[1], colors[2].name) - def _colorize(self, var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]: - palette = var.palette # type: Palette - colors = np.array( - [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan], - dtype=np.uint8, - ) - if var.is_discrete: - mask = np.isnan(data) - data[mask] = -1 - data = data.astype(int) - if mask.any(): - values = (*var.values, "N/A") + def __set_column_annotation_color_var_index(self, index: int): + key = self.col_side_color_cb.itemData(index, Qt.EditRole) + self.set_column_annotation_color_var(key) + + def column_annotation_color_var_changed(self, value): + cbselect(self.col_side_color_cb, value, Qt.EditRole) + + def set_column_annotation_color_var(self, var): + if self.column_annotation_color_var is not var: + self.column_annotation_color_var = var + colors = self.column_side_colors() + if colors is not None: + self.scene.widget.setColumnSideColorAnnotations( + colors[0], colors[1], colors[2].name, + ) else: - values = var.values - colors = colors[: -1] - return data, CategoricalColorMap(colors, values) - elif var.is_continuous: - cmap = GradientColorMap(colors[:-1]) - return data, cmap - else: - raise TypeError + self.scene.widget.setColumnSideColorAnnotations(None) + + def column_side_colors(self): + var = self.column_annotation_color_var + if var is None: + return None + table = self.col_annot_data + return color_annotation_data(table, var) def update_column_annotations(self): widget = self.scene.widget @@ -1175,6 +1317,13 @@ class ColumnPart(NamedTuple): cluster: Optional[hierarchical.Tree] = None cluster_ordered: Optional[hierarchical.Tree] = None + @property + def can_cluster(self) -> bool: + if isinstance(self.indices, slice): + return (self.indices.stop - self.indices.start) > 1 + else: + return len(self.indices) > 1 + class Parts(NamedTuple): rows: Sequence[RowPart] @@ -1225,6 +1374,40 @@ def column_data_from_table( return data +def color_annotation_data( + table: Table, var: Union[int, str, Variable] +) -> Tuple[np.ndarray, ColorMap, Variable]: + var = table.domain[var] + column_data = column_data_from_table(table, var) + data, colormap = colorize(var, column_data) + return data, colormap, var + + +def colorize(var: Variable, data: np.ndarray) -> Tuple[np.ndarray, ColorMap]: + palette = var.palette # type: Palette + colors = np.array( + [[c.red(), c.green(), c.blue()] for c in palette.qcolors_w_nan], + dtype=np.uint8, + ) + if var.is_discrete: + mask = np.isnan(data) + data = data.astype(int) + data[mask] = -1 + if mask.any(): + values = (*var.values, "N/A") + else: + values = var.values + colors = colors[: -1] + return data, CategoricalColorMap(colors, values) + elif var.is_continuous: + span = np.nanmin(data), np.nanmax(data) + if np.any(np.isnan(span)): + span = 0, 1. + return data, GradientColorMap(colors[:-1], span=span) + else: + raise TypeError + + def aggregate( var: Variable, data: np.ndarray, groupindices: Sequence[Sequence[int]], ) -> np.ndarray: diff --git a/Orange/widgets/visualize/tests/test_owheatmap.py b/Orange/widgets/visualize/tests/test_owheatmap.py index d0dbf978055..7d4c75d9119 100644 --- a/Orange/widgets/visualize/tests/test_owheatmap.py +++ b/Orange/widgets/visualize/tests/test_owheatmap.py @@ -222,6 +222,39 @@ def test_set_split_var_missing(self): self.assertEqual(len(w.parts.rows), len(data.domain.class_var.values) + 1) + def _brown_selected_10(self): + data = self.brown_selected[::5] + data = data.transform( + Domain(data.domain.attributes[:10], data.domain.class_vars, + data.domain.metas + (data.domain["diau g"],))) + data.ensure_copy() + return data + + def test_set_split_column_key(self): + data = self._brown_selected_10() + function = data.domain["function"] + data_t = data.transpose(data) + w = self.widget + self.send_signal(self.widget.Inputs.data, data_t, widget=w) + w.set_column_split_var(function) + self.assertEqual(len(w.parts.columns), len(function.values)) + w.set_column_split_var(None) + self.assertEqual(len(w.parts.columns), 1) + + def test_set_split_column_key_missing(self): + data = self._brown_selected_10() + data.Y[:5] = np.nan + data_t = data.transpose(data) + function = data.domain["function"] + w = self.widget + self.send_signal(self.widget.Inputs.data, data_t, widget=w) + w.set_column_split_var(function) + self.assertEqual(len(w.parts.columns), len(function.values) + 1) + ncols = sum(len(p.indices) for p in w.parts.columns) + self.assertEqual(ncols, len(data_t.domain.attributes)) + w.set_column_split_var(None) + self.assertEqual(len(w.parts.columns), 1) + def test_palette_centering(self): data = np.arange(2).reshape(-1, 1) table = Table.from_numpy(Domain([ContinuousVariable("y")]), data) @@ -277,10 +310,7 @@ def test_row_color_annotations(self): def test_row_color_annotations_with_na(self): widget = self.widget - data = self.brown_selected[::5] - data = data.transform( - Domain(data.domain.attributes[:10], data.domain.class_vars, - data.domain.metas + (data.domain["diau g"], ))) + data = self._brown_selected_10() data.Y[:3] = np.nan data.metas[:3, -1] = np.nan self.send_signal(widget.Inputs.data, data, widget=widget) @@ -295,6 +325,38 @@ def test_row_color_annotations_with_na(self): widget.set_annotation_color_var(None) self.assertFalse(widget.scene.widget.right_side_colors[0].isVisible()) + def test_col_color_annotations(self): + widget = self.widget + data = self._brown_selected_10() + data_t = data.transpose(data) + self.send_signal(widget.Inputs.data, data_t, widget=widget) + # discrete + widget.set_column_annotation_color_var(data.domain["function"]) + self.assertTrue(widget.scene.widget.top_side_colors[0].isVisible()) + # continuous + widget.set_column_annotation_color_var(data.domain["diau g"]) + widget.set_column_annotation_color_var(None) + self.assertFalse(widget.scene.widget.top_side_colors[0].isVisible()) + + def test_col_color_annotations_with_na(self): + widget = self.widget + data = self._brown_selected_10() + data.Y[:3] = np.nan + data.metas[:3, -1] = np.nan + data_t = data.transpose(data) + self.send_signal(widget.Inputs.data, data_t, widget=widget) + widget.set_column_annotation_color_var(data.domain["function"]) + self.assertTrue(widget.scene.widget.top_side_colors[0].isVisible()) + widget.set_column_annotation_color_var(data.domain["diau g"]) + data.Y[:] = np.nan + data.metas[:, -1] = np.nan + data_t = data.transpose(data) + self.send_signal(widget.Inputs.data, data_t, widget=widget) + widget.set_column_annotation_color_var(data.domain["function"]) + widget.set_column_annotation_color_var(data.domain["diau g"]) + widget.set_column_annotation_color_var(None) + self.assertFalse(widget.scene.widget.top_side_colors[0].isVisible()) + def test_summary(self): """Check if status bar is updated when data is received""" info = self.widget.info diff --git a/Orange/widgets/visualize/utils/graphicsrichtextwidget.py b/Orange/widgets/visualize/utils/graphicsrichtextwidget.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Orange/widgets/visualize/utils/heatmap.py b/Orange/widgets/visualize/utils/heatmap.py index 433e57f1095..33c0f4c77e8 100644 --- a/Orange/widgets/visualize/utils/heatmap.py +++ b/Orange/widgets/visualize/utils/heatmap.py @@ -26,6 +26,7 @@ from Orange.widgets.utils import apply_all from Orange.widgets.utils.colorpalettes import DefaultContinuousPalette from Orange.widgets.utils.graphicslayoutitem import SimpleLayoutItem, scaled +from Orange.widgets.utils.graphicsflowlayout import GraphicsFlowLayout from Orange.widgets.utils.graphicspixmapwidget import GraphicsPixmapWidget from Orange.widgets.utils.image import qimage_from_array @@ -290,16 +291,20 @@ class Position(enum.IntFlag): # Start row/column where the heatmap items are inserted # (after the titles/legends/dendrograms) - Row0 = 3 + Row0 = 5 Col0 = 3 # The (color) legend row and column LegendRow, LegendCol = 0, 4 # The column for the vertical dendrogram DendrogramColumn = 1 + # Horizontal split title column + GroupTitleRow = 1 # The row for the horizontal dendrograms - DendrogramRow = 1 + DendrogramRow = 2 # The row for top column annotation labels - TopLabelsRow = 2 + TopLabelsRow = 3 + # Top color annotation row + TopAnnotationRow = 4 # Vertical split title column GroupTitleColumn = 0 @@ -322,6 +327,7 @@ def __init__(self, parent=None, **kwargs): self.col_dendrograms = [] # type: List[Optional[DendrogramWidget]] self.row_dendrograms = [] # type: List[Optional[DendrogramWidget]] self.right_side_colors = [] # type: List[Optional[GraphicsPixmapWidget]] + self.top_side_colors = [] # type: List[Optional[GraphicsPixmapWidget]] self.heatmap_colormap_legend = None self.bottom_legend_container = None self.__layout = GridLayout() @@ -353,6 +359,7 @@ def clear(self): self.col_dendrograms = [] self.row_dendrograms = [] self.right_side_colors = [] + self.top_side_colors = [] self.heatmap_colormap_legend = None self.bottom_legend_container = None self.parts = None @@ -373,12 +380,14 @@ def setHeatmaps(self, parts: 'Parts') -> None: # The row for the horizontal dendrograms DendrogramRow = self.DendrogramRow RightLabelColumn = Col0 + 2 * M + 1 + TopAnnotationRow = self.TopAnnotationRow TopLabelsRow = self.TopLabelsRow BottomLabelsRow = Row0 + N colormap = self.__colormap column_dendrograms: List[Optional[DendrogramWidget]] = [None] * M row_dendrograms: List[Optional[DendrogramWidget]] = [None] * N right_side_colors: List[Optional[GraphicsPixmapWidget]] = [None] * N + top_side_colors: List[Optional[GraphicsPixmapWidget]] = [None] * M data = parts.data if parts.col_names is None: @@ -426,9 +435,10 @@ def setHeatmaps(self, parts: 'Parts') -> None: if colitem.title: item = SimpleLayoutItem( QGraphicsSimpleTextItem(colitem.title, parent=self), - parent=grid + parent=grid, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5) ) - grid.addItem(item, 1, Col0 + 2 * j + 1) + item.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Fixed) + grid.addItem(item, self.GroupTitleRow, Col0 + 2 * j + 1) if colitem.cluster: dendrogram = DendrogramWidget( @@ -501,8 +511,6 @@ def setHeatmaps(self, parts: 'Parts') -> None: objectName="row-labels-right" ) labelslist.setMaximumWidth(300) - pm = QPixmap(1, rowitem.size) - pm.fill(Qt.transparent) rowauxsidecolor = GraphicsPixmapWidget( parent=self, visible=False, scaleContents=True, aspectMode=Qt.IgnoreAspectRatio, @@ -529,10 +537,20 @@ def setHeatmaps(self, parts: 'Parts') -> None: visible=self.__columnLabelPosition & Position.Top, objectName="column-labels-top", ) + colauxsidecolor = GraphicsPixmapWidget( + parent=self, visible=False, + scaleContents=True, aspectMode=Qt.IgnoreAspectRatio, + sizePolicy=QSizePolicy(QSizePolicy.Ignored, + QSizePolicy.Maximum), + minimumSize=QSizeF(-1, 10) + ) + grid.addItem(labelslist, TopLabelsRow, Col0 + 2 * j + 1, Qt.AlignBottom | Qt.AlignLeft) + grid.addItem(colauxsidecolor, TopAnnotationRow, Col0 + 2 * j + 1) col_annotation_widgets.append(labelslist) col_annotation_widgets_top.append(labelslist) + top_side_colors[j] = colauxsidecolor # Bottom attr annotations labelslist = TextListWidget( @@ -552,11 +570,21 @@ def setHeatmaps(self, parts: 'Parts') -> None: row_color_annotation_header.rotate(-90) grid.addItem(SimpleLayoutItem( - row_color_annotation_header, anchor=(0, 1), resizeContents=True, + row_color_annotation_header, anchor=(0, 1), aspectMode=Qt.KeepAspectRatio, - sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Preferred), + sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Preferred), ), - self.TopLabelsRow, RightLabelColumn - 1, + 0, RightLabelColumn - 1, self.TopLabelsRow + 1, 1, + alignment=Qt.AlignBottom + ) + + col_color_annotation_header = QGraphicsSimpleTextItem("", self) + grid.addItem(SimpleLayoutItem( + col_color_annotation_header, anchor=(1, 1), anchorItem=(1, 1), + aspectMode=Qt.KeepAspectRatio, + sizePolicy=QSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed), + ), + TopAnnotationRow, 0, 1, Col0, alignment=Qt.AlignRight ) legend = GradientLegendWidget( @@ -568,13 +596,39 @@ def setHeatmaps(self, parts: 'Parts') -> None: sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Fixed) ) legend.setMaximumWidth(300) - grid.addItem(legend, self.LegendRow, self.LegendCol) - legend_container = QGraphicsWidget( - visible=False, - sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed) + grid.addItem(legend, self.LegendRow, self.LegendCol, 1, M * 2 - 1) + + def container(parent=None, orientation=Qt.Horizontal, margin=0, spacing=0, **kwargs): + widget = QGraphicsWidget(**kwargs) + layout = QGraphicsLinearLayout(orientation) + layout.setContentsMargins(margin, margin, margin, margin) + layout.setSpacing(spacing) + widget.setLayout(layout) + if parent is not None: + widget.setParentItem(parent) + + return widget + # Container for color annotation legends + legend_container = container( + spacing=3, + sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed), + visible=False, objectName="annotation-legend-container" ) - legend_container.setLayout(QGraphicsLinearLayout()) - legend_container.layout().setContentsMargins(0, 0, 0, 0) + legend_container_rows = container( + parent=legend_container, + sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed), + visible=False, objectName="row-annotation-legend-container" + ) + legend_container_cols = container( + parent=legend_container, + sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed), + visible=False, objectName="col-annotation-legend-container" + ) + # ? keep refs to child containers; segfault in scene.clear() ? + legend_container._refs = (legend_container_rows, legend_container_cols) + legend_container.layout().addItem(legend_container_rows) + legend_container.layout().addItem(legend_container_cols) + grid.addItem(legend_container, BottomLabelsRow + 1, Col0 + 1, 1, M * 2 - 1, alignment=Qt.AlignRight) @@ -586,6 +640,7 @@ def setHeatmaps(self, parts: 'Parts') -> None: self.col_dendrograms = column_dendrograms self.row_dendrograms = row_dendrograms self.right_side_colors = right_side_colors + self.top_side_colors = top_side_colors self.heatmap_colormap_legend = legend self.bottom_legend_container = legend_container self.parts = parts @@ -688,50 +743,113 @@ def setRowSideColorAnnotations( name: str Name/title for the annotation column. """ - items = self.right_side_colors col = self.Col0 + 2 * len(self.parts.columns) + legend_layout = self.bottom_legend_container.layout() + legend_container = legend_layout.itemAt(1) + self.__setColorAnnotationsHelper( + data, colormap, name, self.right_side_colors, col, Qt.Vertical, + legend_container + ) + legend_container.setVisible(True) + + def setColumnSideColorAnnotations( + self, data: np.ndarray, colormap: ColorMap = None, name="" + ): + """ + Set an optional column color annotations. + + Parameters + ---------- + data: Optional[np.ndarray] + A sequence such that it is accepted by `colormap.apply`. If None + then the color annotations are cleared. + colormap: ColorMap + name: str + Name/title for the annotation column. + """ + row = self.TopAnnotationRow + legend_layout = self.bottom_legend_container.layout() + legend_container = legend_layout.itemAt(0) + self.__setColorAnnotationsHelper( + data, colormap, name, self.top_side_colors, row, Qt.Horizontal, + legend_container) + legend_container.setVisible(True) + + def __setColorAnnotationsHelper( + self, data: np.ndarray, colormap: ColorMap, name: str, + items: List[GraphicsPixmapWidget], position: int, + orientation: Qt.Orientation, legend_container: QGraphicsWidget): layout = self.__layout - nameitem = layout.itemAt(self.TopLabelsRow, col) - width = QFontMetrics(self.font()).lineSpacing() - legend_container = self.bottom_legend_container + if orientation == Qt.Horizontal: + nameitem = layout.itemAt(position, 0) + else: + nameitem = layout.itemAt(self.TopLabelsRow, position) + size = QFontMetrics(self.font()).lineSpacing() layout_clear(legend_container.layout()) + def grid_set_maximum_size(position: int, size: float): + if orientation == Qt.Horizontal: + layout.setRowMaximumHeight(position, size) + else: + layout.setColumnMaximumWidth(position, size) + + def set_minimum_size(item: QGraphicsLayoutItem, size: float): + if orientation == Qt.Horizontal: + item.setMinimumHeight(size) + else: + item.setMinimumWidth(size) + item.updateGeometry() + + def reset_minimum_size(item: QGraphicsLayoutItem): + set_minimum_size(item, -1) + def set_hidden(item: GraphicsPixmapWidget): item.setVisible(False) - item.setMinimumWidth(-1) - item.updateGeometry() + reset_minimum_size(item,) def set_visible(item: GraphicsPixmapWidget): item.setVisible(True) - item.setMinimumWidth(10) + set_minimum_size(item, 10) + + def set_preferred_size(item, size): + if orientation == Qt.Horizontal: + item.setPreferredHeight(size) + else: + item.setPreferredWidth(size) item.updateGeometry() if data is None: apply_all(filter(None, items), set_hidden) - layout.setColumnMaximumWidth(col, 0) + grid_set_maximum_size(position, 0) + nameitem.item.setVisible(False) nameitem.updateGeometry() legend_container.setVisible(False) return else: apply_all(filter(None, items), set_visible) - layout.setColumnMaximumWidth(col, FLT_MAX) + grid_set_maximum_size(position, FLT_MAX) legend_container.setVisible(True) - parts = self.parts.rows + if orientation == Qt.Horizontal: + parts = self.parts.columns + else: + parts = self.parts.rows for p, item in zip(parts, items): if item is not None: subset = data[p.normalized_indices] subset = colormap.apply(subset) - img = qimage_from_array(subset.reshape((-1, 1, subset.shape[-1]))) + rgbdata = subset.reshape((-1, 1, subset.shape[-1])) + if orientation == Qt.Horizontal: + rgbdata = rgbdata.reshape((1, -1, rgbdata.shape[-1])) + img = qimage_from_array(rgbdata) item.setPixmap(img) item.setVisible(True) - item.setPreferredWidth(width) + set_preferred_size(item, size) nameitem.item.setText(name) nameitem.item.setVisible(True) - nameitem.setPreferredWidth(width) - nameitem.updateGeometry() + set_preferred_size(nameitem, size) container = legend_container.layout() if isinstance(colormap, CategoricalColorMap): @@ -744,8 +862,8 @@ def set_visible(item: GraphicsPixmapWidget): container.addItem(legend) elif isinstance(colormap, GradientColorMap): legend = GradientLegendWidget( - *colormap.span, colormap, - sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Maximum) + *colormap.span, colormap, title=name, + sizePolicy=QSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Maximum), ) legend.setMinimumWidth(100) container.addItem(legend) @@ -1132,10 +1250,20 @@ def boundingRect(self): br = br.adjusted(-adjust, 0, adjust, 0) return br + def showEvent(self, event): + super().showEvent(event) + # AxisItem resizes to 0 width/height when hidden, does not update when + # shown implicitly (i.e. a parent becomes visible). + # Use showLabel(False) which should update the size without actually + # changing anything else (no public interface to explicitly recalc + # fixed sizes). + self.showLabel(False) + class GradientLegendWidget(QGraphicsWidget): def __init__( - self, low, high, colormap: GradientColorMap, parent=None, **kwargs + self, low, high, colormap: GradientColorMap, parent=None, title="", + **kwargs ): kwargs.setdefault( "sizePolicy", QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) @@ -1144,11 +1272,18 @@ def __init__( self.low = low self.high = high self.colormap = colormap + self.title = title layout = QGraphicsLinearLayout(Qt.Vertical) layout.setContentsMargins(0, 0, 0, 0) layout.setSpacing(0) self.setLayout(layout) + if title: + titleitem = SimpleLayoutItem( + QGraphicsSimpleTextItem(title, self), parent=layout, + anchor=(0.5, 1.), anchorItem=(0.5, 1.0) + ) + layout.addItem(titleitem) self.__axis = axis = _GradientLegendAxisItem( orientation="top", maxTickLength=3) axis.setRange(low, high) @@ -1213,8 +1348,11 @@ def __init__( self.__colormap = colormap self.__title = title self.__names = colormap.names - self.__layout = QGraphicsGridLayout() - self.__layout.setSpacing(2) + self.__layout = QGraphicsLinearLayout(Qt.Vertical) + self.__flow = GraphicsFlowLayout() + self.__layout.addItem(self.__flow) + self.__flow.setHorizontalSpacing(4) + self.__flow.setVerticalSpacing(4) self.__orientation = orientation kwargs.setdefault( "sizePolicy", QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum) @@ -1235,10 +1373,16 @@ def orientation(self): return self.__orientation def _clear(self): - items = reversed(list(layout_items(self.__layout))) + items = list(layout_items(self.__flow)) + layout_clear(self.__flow) for item in items: - self.__layout.removeItem(item) + if isinstance(item, SimpleLayoutItem): + remove_item(item.item) + # remove 'title' item if present + items = [item for item in layout_items(self.__layout) + if item is not self.__flow] for item in items: + self.__layout.removeItem(item) if isinstance(item, SimpleLayoutItem): remove_item(item.item) @@ -1248,53 +1392,39 @@ def _setup(self): names = self.__colormap.names title = self.__title layout = self.__layout - assert layout.count() == 0 + flow = self.__flow + assert flow.count() == 0 font = self.font() fm = QFontMetrics(font) size = fm.width("X") - start = 0 + headeritem = None if title: - start = 1 - item = QGraphicsSimpleTextItem(title) - item.setFont(font) headeritem = QGraphicsSimpleTextItem(title) headeritem.setFont(font) - else: - headeritem = None - items = [] - for i, (color, label) in enumerate(zip(colors, names), start=start): - colitem = QGraphicsRectItem(0, 0, size, size) - colitem.setBrush(QColor(*color)) + def centered(item): + return SimpleLayoutItem(item, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5)) + + def legend_item_pair(color: QColor, size: float, text: str): + coloritem = QGraphicsRectItem(0, 0, size, size) + coloritem.setBrush(color) textitem = QGraphicsSimpleTextItem() textitem.setFont(font) - textitem.setText(label) - items.append((colitem, textitem)) + textitem.setText(text) + layout = QGraphicsLinearLayout(Qt.Horizontal) + layout.setSpacing(2) + layout.addItem(centered(coloritem)) + layout.addItem(SimpleLayoutItem(textitem)) + return coloritem, textitem, layout - def centered(item): - return SimpleLayoutItem(item, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5)) + items = [legend_item_pair(QColor(*color), size, name) + for color, name in zip(colors, names)] - def addrowspan(item): - layout.addItem(centered(item), layout.rowCount(), 0, 1, 2) + for sym, label, pair_layout in items: + flow.addItem(pair_layout) - def addrow(symbol, label): - row = layout.rowCount() - layout.addItem(centered(symbol), row, 0) - layout.addItem( - SimpleLayoutItem(label), row, 1, - alignment=Qt.AlignVCenter | Qt.AlignLeft - ) - if self.__orientation == Qt.Vertical: - if headeritem: - addrowspan(headeritem) - apply_all(items, lambda el: addrow(*el)) - else: - for sym, label in items: - layout.addItem(centered(sym), 1, layout.columnCount()) - layout.addItem(SimpleLayoutItem(label), 1, layout.columnCount()) - if headeritem: - layout.addItem( - centered(headeritem), 0, 0, 1, layout.columnCount()) + if headeritem: + layout.insertItem(0, centered(headeritem)) def changeEvent(self, event: QEvent) -> None: if event.type() == QEvent.FontChange: @@ -1305,7 +1435,7 @@ def _updateFont(self, font): w = QFontMetrics(font).width("X") for item in filter( lambda item: isinstance(item, SimpleLayoutItem), - layout_items(self.__layout) + layout_items_recursive(self.__layout) ): if isinstance(item.item, QGraphicsSimpleTextItem): item.item.setFont(font) @@ -1320,6 +1450,16 @@ def layout_items(layout: QGraphicsLayout) -> Iterable[QGraphicsLayoutItem]: yield item +def layout_items_recursive(layout: QGraphicsLayout): + for item in map(layout.itemAt, range(layout.count())): + if item is not None: + if item.isLayout(): + assert isinstance(item, QGraphicsLayout) + yield from layout_items_recursive(item) + else: + yield item + + def layout_clear(layout: QGraphicsLayout) -> None: for i in reversed(range(layout.count())): item = layout.itemAt(i) diff --git a/Orange/widgets/visualize/utils/tests/test_heatmap.py b/Orange/widgets/visualize/utils/tests/test_heatmap.py index 4b26f7a03fd..1aa48c6af4d 100644 --- a/Orange/widgets/visualize/utils/tests/test_heatmap.py +++ b/Orange/widgets/visualize/utils/tests/test_heatmap.py @@ -1,6 +1,7 @@ import numpy as np from AnyQt.QtCore import Qt, QPoint +from AnyQt.QtGui import QFont from AnyQt.QtTest import QTest, QSignalSpy from AnyQt.QtWidgets import QGraphicsScene, QGraphicsView @@ -8,11 +9,11 @@ from orangewidget.tests.base import GuiTest from Orange.clustering.hierarchical import Tree, SingletonData, ClusterData -from Orange.widgets.visualize.utils.heatmap import HeatmapGridWidget, ColorMap, \ - GradientColorMap, CategoricalColorMap +from Orange.widgets.visualize.utils.heatmap import HeatmapGridWidget, \ + GradientColorMap, CategoricalColorMap, CategoricalColorLegend -class TestHeatmapGridWidget(GuiTest): +class _GraphicsGuiTest(GuiTest): scene: QGraphicsScene view: QGraphicsView @@ -30,6 +31,8 @@ def tearDown(self) -> None: self.view = None super().tearDown() + +class TestHeatmapGridWidget(_GraphicsGuiTest): _c2 = Tree(ClusterData((0, 1), 0.5), ( Tree(SingletonData((0, 0), 0, 0), ()), Tree(SingletonData((1, 1), 0, 1), ()), @@ -170,3 +173,20 @@ def test_colormap(self): w.setHeatmaps(self._Data["2-2"]) w.setColorMap(GradientColorMap([[255] * 3, [0] * 3])) w.setColorMap(GradientColorMap([[255] * 3, [0] * 3], center=0)) + + +class TestCategoricalColorLegend(_GraphicsGuiTest): + def ensure_scene_polished(self): + self.view.grab() + + def test_font_propagation(self): + cmap = CategoricalColorMap(np.array([[255] * 3, [0] * 3]), + names=["a", "b"]) + w = CategoricalColorLegend(cmap, title="Title") + self.scene.addItem(w) + font = QFont("Windings") + w.setFont(font) + # needs to be polished for FontChange to be delivered + self.ensure_scene_polished() + self.assertEqual(w.layout().itemAt(0).item.font().family(), + font.family())