Skip to content

Commit 4a8c2eb

Browse files
authored
Merge pull request #3053 from VesnaT/scatter_plot_group_colors
[FIX] OWScatterOWScatterPlotGraph: Match group colors with marker colors
2 parents 9884dee + c13d5a6 commit 4a8c2eb

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

Orange/widgets/utils/annotated_data.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,16 @@ def create_groups_table(data, selection,
117117
var_name=ANNOTATED_DATA_FEATURE_NAME):
118118
if data is None:
119119
return None
120-
values = ["G{}".format(i + 1) for i in range(np.max(selection))]
120+
max_sel = np.max(selection)
121+
values = ["G{}".format(i + 1) for i in range(max_sel)]
121122
if include_unselected:
122-
values.insert(0, "Unselected")
123+
# Place Unselected instances in the "last group", so that the group
124+
# colors and scatter diagram marker colors will match
125+
values.append("Unselected")
126+
mask = (selection != 0)
127+
selection = selection.copy()
128+
selection[mask] = selection[mask] - 1
129+
selection[~mask] = selection[~mask] = max_sel
123130
else:
124131
mask = np.flatnonzero(selection)
125132
data = data[mask]

Orange/widgets/utils/tests/test_annotated_data.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import numpy as np
55

66
from Orange.data import Table, Variable
7+
from Orange.data.filter import SameValue
78
from Orange.widgets.utils.annotated_data import (
89
create_annotated_table, get_next_name, get_unique_names,
9-
ANNOTATED_DATA_FEATURE_NAME)
10+
create_groups_table, ANNOTATED_DATA_FEATURE_NAME
11+
)
1012

1113

1214
class TestGetNextName(unittest.TestCase):
@@ -115,3 +117,14 @@ def test_get_unique_names(self):
115117
"bravo (3)"]
116118
self.assertEqual(get_unique_names(names, ["bravo", "charlie"]),
117119
["bravo (5)", "charlie (5)"])
120+
121+
def test_create_groups_table_include_unselected(self):
122+
group_indices = random.sample(range(0, len(self.zoo)), 20)
123+
selection = np.zeros(len(self.zoo), dtype=np.uint8)
124+
selection[group_indices[:10]] = 1
125+
selection[group_indices[10:]] = 2
126+
table = create_groups_table(self.zoo, selection)
127+
self.assertEqual(
128+
len(SameValue(table.domain["Selected"], "Unselected")(table)),
129+
len(self.zoo) - len(group_indices)
130+
)

Orange/widgets/visualize/owscatterplotgraph.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -851,12 +851,9 @@ def compute_colors_sel(self, keep_colors=False):
851851
_make_pen(QColor(255, 190, 0, 255),
852852
SELECTION_WIDTH + 1.)]
853853
else:
854-
# Start with the first color so that the colors of the
855-
# additional attribute in annotation (which start with 0,
856-
# unselected) will match these colors
857854
palette = ColorPaletteGenerator(number_of_colors=sels + 1)
858855
pens = [nopen] + \
859-
[_make_pen(palette[i + 1], SELECTION_WIDTH + 1.)
856+
[_make_pen(palette[i], SELECTION_WIDTH + 1.)
860857
for i in range(sels)]
861858
pen = [pens[a] for a in self.selection[self.valid_data]]
862859
else:

Orange/widgets/visualize/tests/test_owscatterplot.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ def annotations():
196196
graph.select(points[5:10])
197197
np.testing.assert_equal(selectedx(), x[:10])
198198
np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5))
199-
sel_column[5:10] = 2
199+
sel_column[:5] = 0
200+
sel_column[5:10] = 1
201+
sel_column[10:] = 2
200202
np.testing.assert_equal(annotated(), sel_column)
201203
self.assertEqual(len(annotations()), 3)
202204

@@ -232,7 +234,9 @@ def annotations():
232234
# ... then Ctrl-Shift-select (add-to-last) 10:17; we have 17:25, 30:40
233235
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
234236
graph.select(points[35:40])
235-
sel_column[30:40] = 2
237+
sel_column[:] = 2
238+
sel_column[17:25] = 0
239+
sel_column[30:40] = 1
236240
np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10))
237241
np.testing.assert_equal(annotated(), sel_column)
238242
self.assertEqual(len(annotations()), 3)

0 commit comments

Comments
 (0)