Skip to content

Commit 765ef7c

Browse files
committed
owheatmap: option to center palette
1 parent 89f31fc commit 765ef7c

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

Orange/widgets/visualize/owheatmap.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,12 @@ def color_palette_table(colors,
118118
return np.c_[r, g, b]
119119

120120

121-
def levels_with_thresholds(low, high, threshold_low, threshold_high):
121+
def levels_with_thresholds(low, high, threshold_low, threshold_high, center_palette):
122122
lt = low + (high - low) * threshold_low
123123
ht = low + (high - low) * threshold_high
124+
if center_palette:
125+
ht = max(abs(lt), abs(ht))
126+
lt = -max(abs(lt), abs(ht))
124127
return lt, ht
125128

126129

@@ -317,6 +320,8 @@ class Outputs:
317320
selected_data = Output("Selected Data", Table, default=True)
318321
annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
319322

323+
settings_version = 2
324+
320325
settingsHandler = settings.DomainContextHandler()
321326

322327
NoPosition, PositionTop, PositionBottom = 0, 1, 2
@@ -327,6 +332,7 @@ class Outputs:
327332
gamma = settings.Setting(0)
328333
threshold_low = settings.Setting(0.0)
329334
threshold_high = settings.Setting(1.0)
335+
center_palette = settings.Setting(False)
330336

331337
merge_kmeans = settings.Setting(False)
332338
merge_kmeans_k = settings.Setting(50)
@@ -447,6 +453,9 @@ def __init__(self):
447453

448454
colorbox.layout().addLayout(form)
449455

456+
gui.checkBox(colorbox, self, 'center_palette', 'Center colors at 0',
457+
callback=self.update_color_schema)
458+
450459
mergebox = gui.vBox(self.controlArea, "Merge",)
451460
gui.checkBox(mergebox, self, "merge_kmeans", "Merge by k-means",
452461
callback=self.update_sorting_examples)
@@ -982,7 +991,7 @@ def setup_scene(self, parts, data):
982991

983992
hw.set_levels(parts.levels)
984993
hw.set_thresholds(self.threshold_low, self.threshold_high)
985-
hw.set_color_table(palette)
994+
hw.set_color_table(palette, self.center_palette)
986995
hw.set_show_averages(self.averages)
987996
hw.set_heatmap_data(X_part)
988997

@@ -1057,7 +1066,7 @@ def setup_scene(self, parts, data):
10571066
parts.levels[0], parts.levels[1], self.threshold_low, self.threshold_high,
10581067
parent=widget)
10591068

1060-
legend.set_color_table(palette)
1069+
legend.set_color_table(palette, self.center_palette)
10611070
legend.setMinimumSize(QSizeF(100, 20))
10621071
legend.setVisible(self.legend)
10631072

@@ -1318,11 +1327,11 @@ def update_color_schema(self):
13181327
palette = self.color_palette()
13191328
for heatmap in self.heatmap_widgets():
13201329
heatmap.set_thresholds(self.threshold_low, self.threshold_high)
1321-
heatmap.set_color_table(palette)
1330+
heatmap.set_color_table(palette, self.center_palette)
13221331

13231332
for legend in self.legend_widgets():
13241333
legend.set_thresholds(self.threshold_low, self.threshold_high)
1325-
legend.set_color_table(palette)
1334+
legend.set_color_table(palette, self.center_palette)
13261335

13271336
def update_sorting_examples(self):
13281337
self.update_heatmaps()
@@ -1601,6 +1610,7 @@ def __init__(self, parent=None, data=None, **kwargs):
16011610

16021611
self.__levels = None
16031612
self.__threshold_low, self.__threshold_high = 0., 1.
1613+
self.__center_palette = False
16041614
self.__colortable = None
16051615
self.__data = data
16061616

@@ -1677,8 +1687,9 @@ def set_show_averages(self, show):
16771687
self.layout().invalidate()
16781688
self.update()
16791689

1680-
def set_color_table(self, table):
1690+
def set_color_table(self, table, center):
16811691
self.__colortable = table
1692+
self.__center_palette = center
16821693
self._update_pixmap()
16831694
self.update()
16841695

@@ -1699,7 +1710,8 @@ def _update_pixmap(self):
16991710
lut = None
17001711

17011712
ll, lh = self.__levels
1702-
ll, lh = levels_with_thresholds(ll, lh, self.__threshold_low, self.__threshold_high)
1713+
ll, lh = levels_with_thresholds(ll, lh, self.__threshold_low, self.__threshold_high,
1714+
self.__center_palette)
17031715

17041716
argb, _ = pg.makeARGB(
17051717
self.__data, lut=lut, levels=(ll, lh))
@@ -2058,6 +2070,7 @@ def __init__(self, low, high, threshold_low, threshold_high, parent=None):
20582070
self.high = high
20592071
self.threshold_low = threshold_low
20602072
self.threshold_high = threshold_high
2073+
self.center_palette = False
20612074
self.color_table = None
20622075

20632076
layout = QGraphicsLinearLayout(Qt.Vertical)
@@ -2084,8 +2097,9 @@ def __init__(self, low, high, threshold_low, threshold_high, parent=None):
20842097
layout.addItem(self.__pixitem)
20852098
self.__update()
20862099

2087-
def set_color_table(self, color_table):
2100+
def set_color_table(self, color_table, center):
20882101
self.color_table = color_table
2102+
self.center_palette = center
20892103
self.__update()
20902104

20912105
def set_thresholds(self, threshold_low, threshold_high):
@@ -2097,7 +2111,8 @@ def __update(self):
20972111
data = np.linspace(self.low, self.high, num=1000)
20982112
data = data.reshape((1, -1))
20992113
ll, lh = levels_with_thresholds(self.low, self.high,
2100-
self.threshold_low, self.threshold_high)
2114+
self.threshold_low, self.threshold_high,
2115+
self.center_palette)
21012116
argb, _ = pg.makeARGB(data, lut=self.color_table,
21022117
levels=(ll, lh))
21032118
qimg = pg.makeQImage(argb, transpose=False)

Orange/widgets/visualize/tests/test_owheatmap.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin, datasets
1414

1515

16+
def image_row_colors(image):
17+
colors = np.full((image.height(), 3), np.nan)
18+
for r in range(image.height()):
19+
c = image.pixelColor(0, r)
20+
colors[r] = c.red(), c.green(), c.blue()
21+
return colors
22+
23+
1624
class TestOWHeatMap(WidgetTest, WidgetOutputsTestMixin):
1725
@classmethod
1826
def setUpClass(cls):
@@ -164,10 +172,7 @@ def test_use_enough_colors(self):
164172
self.widget.update_color_schema()
165173
heatmap_widget = self.widget.heatmap_widget_grid[0][0]
166174
image = heatmap_widget.heatmap_item.pixmap().toImage()
167-
colors = np.full((len(data), 3), np.nan)
168-
for r in range(len(data)):
169-
c = image.pixelColor(0, r)
170-
colors[r] = c.red(), c.green(), c.blue()
175+
colors = image_row_colors(image)
171176
unique_colors = len(np.unique(colors, axis=0))
172177
self.assertLessEqual(len(data)*self.widget.threshold_low, unique_colors)
173178

@@ -209,6 +214,29 @@ def test_set_split_var(self):
209214
self.assertIs(w.split_by_var, None)
210215
self.assertEqual(len(w.heatmapparts.rows), 1)
211216

217+
def test_center_palette(self):
218+
data = np.arange(2).reshape(-1, 1)
219+
table = Table.from_numpy(Domain([ContinuousVariable("y")]), data)
220+
self.send_signal(self.widget.Inputs.data, table)
221+
222+
cb_model = self.widget.color_cb.model()
223+
ind = cb_model.indexFromItem(cb_model.findItems("Green-Black-Red")[0]).row()
224+
self.widget.palette_index = ind
225+
226+
desired_uncentered = [[0, 255, 0],
227+
[255, 0, 0]]
228+
229+
desired_centered = [[0, 0, 0],
230+
[255, 0, 0]]
231+
232+
for center, desired in [(False, desired_uncentered), (True, desired_centered)]:
233+
self.widget.center_palette = center
234+
self.widget.update_color_schema()
235+
heatmap_widget = self.widget.heatmap_widget_grid[0][0]
236+
image = heatmap_widget.heatmap_item.pixmap().toImage()
237+
colors = image_row_colors(image)
238+
np.testing.assert_almost_equal(colors, desired)
239+
212240

213241
if __name__ == "__main__":
214242
unittest.main()

0 commit comments

Comments
 (0)