Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions Orange/widgets/utils/colorgradientselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@

from AnyQt.QtCore import Qt, QSize, QAbstractItemModel, Property
from AnyQt.QtWidgets import (
QWidget, QSlider, QFormLayout, QComboBox, QStyle
)
QWidget, QSlider, QFormLayout, QComboBox, QStyle,
QHBoxLayout, QLineEdit, QLabel)
from AnyQt.QtCore import Signal
from AnyQt.QtGui import QFontMetrics, QDoubleValidator

from Orange.widgets.utils import itemmodels
from Orange.widgets.utils import itemmodels, colorpalettes


class ColorGradientSelection(QWidget):
activated = Signal(int)

currentIndexChanged = Signal(int)
thresholdsChanged = Signal(float, float)
centerChanged = Signal(float)

def __init__(self, *args, thresholds=(0.0, 1.0), **kwargs):
def __init__(self, *args, thresholds=(0.0, 1.0), center=None, **kwargs):
super().__init__(*args, **kwargs)

low = round(clip(thresholds[0], 0., 1.), 2)
high = round(clip(thresholds[1], 0., 1.), 2)
high = max(low, high)
self.__threshold_low, self.__threshold_high = low, high
self.__center = center
form = QFormLayout(
formAlignment=Qt.AlignLeft,
labelAlignment=Qt.AlignLeft,
Expand All @@ -43,6 +46,29 @@ def __init__(self, *args, thresholds=(0.0, 1.0), **kwargs):
self.gradient_cb.activated[int].connect(self.activated)
self.gradient_cb.currentIndexChanged.connect(self.currentIndexChanged)

if center is not None:
def __on_center_changed():
self.__center = float(self.center_edit.text() or "0")
self.centerChanged.emit(self.__center)

self.center_box = QWidget()
center_layout = QHBoxLayout()
self.center_box.setLayout(center_layout)
width = QFontMetrics(self.font()).boundingRect("9999999").width()
self.center_edit = QLineEdit(
text=f"{self.__center}",
maximumWidth=width, placeholderText="0", alignment=Qt.AlignRight)
self.center_edit.setValidator(QDoubleValidator())
self.center_edit.editingFinished.connect(__on_center_changed)
center_layout.setContentsMargins(0, 0, 0, 0)
center_layout.addStretch(1)
center_layout.addWidget(QLabel("Centered at"))
center_layout.addWidget(self.center_edit)
self.gradient_cb.currentIndexChanged.connect(
self.__update_center_visibility)
else:
self.center_box = None

slider_low = QSlider(
objectName="threshold-low-slider", minimum=0, maximum=100,
value=int(low * 100), orientation=Qt.Horizontal,
Expand All @@ -60,6 +86,8 @@ def __init__(self, *args, thresholds=(0.0, 1.0), **kwargs):
"gradient from the higher end")
)
form.setWidget(0, QFormLayout.SpanningRole, self.gradient_cb)
if self.center_box:
form.setWidget(1, QFormLayout.SpanningRole, self.center_box)
form.addRow(self.tr("Low:"), slider_low)
form.addRow(self.tr("High:"), slider_high)
self.slider_low = slider_low
Expand All @@ -79,6 +107,7 @@ def findData(self, data: Any, role: Qt.ItemDataRole) -> int:

def setCurrentIndex(self, index: int) -> None:
self.gradient_cb.setCurrentIndex(index)
self.__update_center_visibility()

def currentIndex(self) -> int:
return self.gradient_cb.currentIndex()
Expand Down Expand Up @@ -109,6 +138,14 @@ def thresholdHigh(self) -> float:
def setThresholdHigh(self, high: float) -> None:
self.setThresholds(min(self.__threshold_low, high), high)

def center(self) -> float:
return self.__center

def setCenter(self, center: float) -> None:
self.__center = center
self.center_edit.setText(f"{center}")
self.centerChanged.emit(center)

thresholdHigh_ = Property(
float, thresholdLow, setThresholdLow, notify=thresholdsChanged)

Expand Down Expand Up @@ -146,6 +183,15 @@ def setThresholds(self, low: float, high: float) -> None:
self.slider_high.setSliderPosition(high * 100)
self.thresholdsChanged.emit(high, low)

def __update_center_visibility(self):
if self.center_box is None:
return

palette = self.currentData()
self.center_box.setVisible(
isinstance(palette, colorpalettes.Palette)
and palette.flags & palette.Flags.Diverging != 0)


def clip(a, amin, amax):
return min(max(a, amin), amax)
43 changes: 42 additions & 1 deletion Orange/widgets/utils/tests/test_colorgradientselection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from unittest.mock import Mock

import numpy as np

from AnyQt.QtTest import QSignalSpy
from AnyQt.QtCore import Qt, QStringListModel
from AnyQt.QtCore import Qt, QStringListModel, QModelIndex

from Orange.widgets.utils import itemmodels
from Orange.widgets.utils.colorgradientselection import ColorGradientSelection
from Orange.widgets.tests.base import GuiTest


class TestColorGradientSelection(GuiTest):
def test_constructor(self):
w = ColorGradientSelection(thresholds=(0.1, 0.9))
Expand Down Expand Up @@ -68,3 +72,40 @@ def test_slider_move(self):
low, high = changed[-1]
self.assertLessEqual(low, high)
self.assertEqual(high, 0.0)

def test_center(self):
w = ColorGradientSelection(center=42)
self.assertEqual(w.center(), 42)
w.setCenter(40)
self.assertEqual(w.center(), 40)

def test_center_visibility(self):
w = ColorGradientSelection(center=0)
w.center_box.setVisible = Mock()
model = itemmodels.ContinuousPalettesModel()
w.setModel(model)
for row in range(model.rowCount(QModelIndex())):
palette = model.data(model.index(row, 0), Qt.UserRole)
if palette:
if palette.flags & palette.Diverging:
diverging = row
else:
nondiverging = row

w.setCurrentIndex(diverging)
w.center_box.setVisible.assert_called_with(True)
w.setCurrentIndex(nondiverging)
w.center_box.setVisible.assert_called_with(False)
w.setCurrentIndex(diverging)
w.center_box.setVisible.assert_called_with(True)

w = ColorGradientSelection()
self.assertIsNone(w.center_box)

def test_center_changed(self):
w = ColorGradientSelection(center=42)
changed = QSignalSpy(w.centerChanged)
w.center_edit.setText("41")
w.center_edit.editingFinished.emit()
self.assertEqual(w.center(), 41)
self.assertEqual(list(changed), [[41]])
16 changes: 15 additions & 1 deletion Orange/widgets/visualize/owheatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class Outputs:

threshold_low = settings.Setting(0.0)
threshold_high = settings.Setting(1.0)
color_center = settings.Setting(0)

merge_kmeans = settings.Setting(False)
merge_kmeans_k = settings.Setting(50)
Expand Down Expand Up @@ -202,6 +203,12 @@ class Error(widget.OWWidget.Error):
class Warning(widget.OWWidget.Warning):
empty_clusters = Msg("Empty clusters were removed")

UserAdviceMessages = [
widget.Message(
"For data with a meaningful mid-point, "
"choose one of diverging palettes.",
"diverging_palette")]

def __init__(self):
super().__init__()
self.__pending_selection = self.selected_rows
Expand Down Expand Up @@ -244,6 +251,7 @@ def _():

self.color_map_widget = cmw = ColorGradientSelection(
thresholds=(self.threshold_low, self.threshold_high),
center=self.color_center
)
model = itemmodels.ContinuousPalettesModel(parent=self)
cmw.setModel(model)
Expand All @@ -257,6 +265,12 @@ def _set_thresholds(low, high):
self.threshold_low, self.threshold_high = low, high
self.update_color_schema()
cmw.thresholdsChanged.connect(_set_thresholds)

def _set_centering(center):
self.color_center = center
self.update_color_schema()
cmw.centerChanged.connect(_set_centering)

colorbox.layout().addWidget(self.color_map_widget)

mergebox = gui.vBox(self.controlArea, "Merge",)
Expand Down Expand Up @@ -449,7 +463,7 @@ def color_palette(self):
def color_map(self) -> GradientColorMap:
return GradientColorMap(
self.color_palette(), (self.threshold_low, self.threshold_high),
0 if self.center_palette else None
self.color_map_widget.center() if self.center_palette else None
)

def clear(self):
Expand Down
17 changes: 16 additions & 1 deletion Orange/widgets/visualize/tests/test_owheatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# pylint: disable=missing-docstring, protected-access
import warnings
import unittest
from unittest.mock import patch
from unittest.mock import patch, Mock

import numpy as np
from sklearn.exceptions import ConvergenceWarning
Expand Down Expand Up @@ -245,6 +245,21 @@ def test_palette_centering(self):
colors = image_row_colors(image)
np.testing.assert_almost_equal(colors, desired)

def test_centering_threshold_change(self):
data = np.arange(2).reshape(-1, 1)
table = Table.from_numpy(Domain([ContinuousVariable("y")]), data)
self.send_signal(self.widget.Inputs.data, table)

cmw = self.widget.color_map_widget
palette_index = cmw.findData(
colorpalettes.ContinuousPalettes["diverging_bwr_40_95_c42"],
Qt.UserRole)
cmw.setCurrentIndex(palette_index)

self.widget.update_color_schema = Mock()
cmw.centerChanged.emit(42)
self.widget.update_color_schema.assert_called()

def test_palette_center(self):
widget = self.widget
model = widget.color_map_widget.model()
Expand Down
2 changes: 1 addition & 1 deletion doc/visual-programming/source/widgets/visualize/heatmap.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Plots a heat map for a pair of attributes.

![](images/HeatMap-stamped.png)

1. The color scheme legend. **Low** and **High** are thresholds for the color palette (low for attributes with low values and high for attributes with high values).
1. The color scheme legend. **Low** and **High** are thresholds for the color palette (low for attributes with low values and high for attributes with high values). Selecting one of diverging palettes, which have two extreme colors and a neutral (black or white) color at the midpoint, enables an option to set a meaningful mid-point value (default is 0).
2. Merge data.
3. Sort columns and rows:
- **No Sorting** (lists attributes as found in the dataset)
Expand Down