From 784589c16a5f978a24d52dd50c93a48fe04530c0 Mon Sep 17 00:00:00 2001 From: Ales Erjavec Date: Mon, 30 Jan 2017 15:14:46 +0100 Subject: [PATCH] owpaintdata: Adjust color model to input dataset Fix a IndexError when the input dataset's class variable has more then 17 values. --- Orange/widgets/data/owpaintdata.py | 13 +++++++++++++ Orange/widgets/data/tests/test_owpaintdata.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/Orange/widgets/data/owpaintdata.py b/Orange/widgets/data/owpaintdata.py index 42a69ff82d3..e51acbf881a 100644 --- a/Orange/widgets/data/owpaintdata.py +++ b/Orange/widgets/data/owpaintdata.py @@ -794,6 +794,7 @@ def __init__(self): self.input_data = None self.input_classes = [] + self.input_colors = None self.input_has_attr2 = True self.current_tool = None self._selected_indices = None @@ -1016,9 +1017,12 @@ def _check_and_set_data(data): if data.domain.class_vars: self.Warning.continuous_target() self.input_classes = ["C1"] + self.input_colors = None y = np.zeros(len(data)) else: self.input_classes = y.values + self.input_colors = y.colors + y = data[:, y].Y self.input_has_attr2 = len(data.domain.attributes) >= 2 @@ -1036,7 +1040,16 @@ def reset_to_input(self): self.undo_stack.clear() index = self.selected_class_label() + if self.input_colors is not None: + colors = self.input_colors + else: + colors = colorpalette.DefaultRGBColors + palette = colorpalette.ColorPaletteGenerator( + number_of_colors=len(colors), rgb_colors=colors) + self.colors = palette + self.class_model.colors = palette self.class_model[:] = self.input_classes + newindex = min(max(index, 0), len(self.class_model) - 1) itemmodels.select_row(self.classValuesView, newindex) diff --git a/Orange/widgets/data/tests/test_owpaintdata.py b/Orange/widgets/data/tests/test_owpaintdata.py index 6ff2cefa2de..4df12d22206 100644 --- a/Orange/widgets/data/tests/test_owpaintdata.py +++ b/Orange/widgets/data/tests/test_owpaintdata.py @@ -4,7 +4,7 @@ import numpy as np from AnyQt.QtCore import QRectF, QPointF -from Orange.data import Table +from Orange.data import Table, DiscreteVariable, ContinuousVariable, Domain from Orange.widgets.data import owpaintdata from Orange.widgets.data.owpaintdata import OWPaintData from Orange.widgets.tests.base import WidgetTest @@ -43,3 +43,12 @@ def test_output_shares_internal_buffer(self): np.testing.assert_equal(output1.Y, output1_copy.Y) self.assertTrue(np.any(output1.X != output2.X)) + + def test_20_values_class(self): + domain = Domain( + [ContinuousVariable("A"), + ContinuousVariable("B")], + DiscreteVariable("C", values=[chr(ord("a") + i) for i in range(20)]) + ) + data = Table(domain, [[0.1, 0.2, "a"], [0.4, 0.7, "t"]]) + self.send_signal("Data", data)