Skip to content

Commit ea87577

Browse files
authored
Merge pull request #2678 from janezd/scatterplot-selection-groups
[ENH] Add Groups column to Selected Data in Scatter plot output
2 parents 8cbe98a + b3c0a0d commit ea87577

File tree

3 files changed

+65
-38
lines changed

3 files changed

+65
-38
lines changed

Orange/widgets/utils/annotated_data.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from itertools import chain
23

34
import numpy as np
45
from Orange.data import Domain, DiscreteVariable
@@ -52,6 +53,11 @@ def get_next_name(names, name):
5253
:param name: str
5354
:return: str
5455
"""
56+
if isinstance(names, Domain):
57+
names = [
58+
var.name
59+
for var in chain(names.attributes, names.class_vars, names.metas)
60+
]
5561
indexes = get_indices(names, name)
5662
if name not in names and not indexes:
5763
return name
@@ -74,6 +80,19 @@ def get_unique_names(names, proposed):
7480
return proposed
7581

7682

83+
def _table_with_annotation_column(data, values, column_data, var_name):
84+
var = DiscreteVariable(get_next_name(data.domain, var_name), values)
85+
class_vars, metas = data.domain.class_vars, data.domain.metas
86+
if not data.domain.class_vars:
87+
class_vars += (var, )
88+
else:
89+
metas += (var, )
90+
domain = Domain(data.domain.attributes, class_vars, metas)
91+
table = data.transform(domain)
92+
table[:, var] = column_data.reshape((len(data), 1))
93+
return table
94+
95+
7796
def create_annotated_table(data, selected_indices):
7897
"""
7998
Returns data with concatenated flag column. Flag column represents
@@ -86,30 +105,23 @@ def create_annotated_table(data, selected_indices):
86105
"""
87106
if data is None:
88107
return None
89-
names = [var.name for var in data.domain.variables + data.domain.metas]
90-
name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME)
91-
domain = add_columns(data.domain, metas=[DiscreteVariable(name, ("No", "Yes"))])
92108
annotated = np.zeros((len(data), 1))
93109
if selected_indices is not None:
94110
annotated[selected_indices] = 1
95-
table = data.transform(domain)
96-
table[:, name] = annotated
97-
return table
111+
return _table_with_annotation_column(
112+
data, ("No", "Yes"), annotated, ANNOTATED_DATA_FEATURE_NAME)
98113

99114

100-
def create_groups_table(data, selection):
115+
def create_groups_table(data, selection,
116+
include_unselected=True,
117+
var_name=ANNOTATED_DATA_FEATURE_NAME):
101118
if data is None:
102119
return None
103-
names = [var.name for var in data.domain.variables + data.domain.metas]
104-
name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME)
105-
metas = data.domain.metas + (
106-
DiscreteVariable(
107-
name,
108-
["Unselected"] + ["G{}".format(i + 1)
109-
for i in range(np.max(selection))]),
110-
)
111-
domain = Domain(data.domain.attributes, data.domain.class_vars, metas)
112-
table = data.transform(domain)
113-
table.metas[:, len(data.domain.metas):] = \
114-
selection.reshape(len(data), 1)
115-
return table
120+
values = ["G{}".format(i + 1) for i in range(np.max(selection))]
121+
if include_unselected:
122+
values.insert(0, "Unselected")
123+
else:
124+
mask = np.flatnonzero(selection)
125+
data = data[mask]
126+
selection = selection[mask] - 1
127+
return _table_with_annotation_column(data, values, selection, var_name)

Orange/widgets/visualize/owscatterplot.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@
2020
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph
2121
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
2222
from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output
23-
from Orange.widgets.utils.annotated_data import (create_annotated_table,
24-
ANNOTATED_DATA_SIGNAL_NAME,
25-
create_groups_table)
23+
from Orange.widgets.utils.annotated_data import (
24+
create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME)
2625

2726

2827
class ScatterPlotVizRank(VizRankDialogAttrPair):
@@ -427,25 +426,26 @@ def selection_changed(self):
427426
self.commit()
428427

429428
def send_data(self):
430-
selected = None
431-
selection = None
432429
# TODO: Implement selection for sql data
430+
def _get_selected():
431+
if not len(selection):
432+
return None
433+
return create_groups_table(data, graph.selection, False, "Group")
434+
435+
def _get_annotated():
436+
if graph.selection is not None and np.max(graph.selection) > 1:
437+
return create_groups_table(data, graph.selection)
438+
else:
439+
return create_annotated_table(data, selection)
440+
433441
graph = self.graph
434-
if isinstance(self.data, SqlTable):
435-
selected = self.data
436-
elif self.data is not None:
437-
selection = graph.get_selection()
438-
if len(selection) > 0:
439-
selected = self.data[selection]
440-
if graph.selection is not None and np.max(graph.selection) > 1:
441-
annotated = create_groups_table(self.data, graph.selection)
442-
else:
443-
annotated = create_annotated_table(self.data, selection)
444-
self.Outputs.selected_data.send(selected)
445-
self.Outputs.annotated_data.send(annotated)
442+
data = self.data
443+
selection = graph.get_selection()
444+
self.Outputs.annotated_data.send(_get_annotated())
445+
self.Outputs.selected_data.send(_get_selected())
446446

447447
# Store current selection in a setting that is stored in workflow
448-
if selection is not None and len(selection):
448+
if len(selection):
449449
self.selection_group = list(zip(selection, graph.selection[selection]))
450450
else:
451451
self.selection_group = None

Orange/widgets/visualize/tests/test_owscatterplot.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,19 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin):
1818
def setUpClass(cls):
1919
super().setUpClass()
2020
WidgetOutputsTestMixin.init(cls)
21+
cls.same_input_output_domain = False
2122

2223
cls.signal_name = "Data"
2324
cls.signal_data = cls.data
2425

2526
def setUp(self):
2627
self.widget = self.create_widget(OWScatterPlot)
2728

29+
def _compare_selected_annotated_domains(self, selected, annotated):
30+
# Base class tests that selected.domain is a subset of annotated.domain
31+
# In scatter plot, the two domains are unrelated, so we disable the test
32+
pass
33+
2834
def test_set_data(self):
2935
# Connect iris to scatter plot
3036
self.send_signal(self.widget.Inputs.data, self.data)
@@ -154,6 +160,9 @@ def test_group_selections(self):
154160
def selectedx():
155161
return self.get_output(self.widget.Outputs.selected_data).X
156162

163+
def selected_groups():
164+
return self.get_output(self.widget.Outputs.selected_data).metas[:, 0]
165+
157166
def annotated():
158167
return self.get_output(self.widget.Outputs.annotated_data).metas
159168

@@ -163,6 +172,7 @@ def annotations():
163172
# Select 0:5
164173
graph.select(points[:5])
165174
np.testing.assert_equal(selectedx(), x[:5])
175+
np.testing.assert_equal(selected_groups(), np.zeros(5))
166176
sel_column[:5] = 1
167177
np.testing.assert_equal(annotated(), sel_column)
168178
self.assertEqual(annotations(), ["No", "Yes"])
@@ -171,6 +181,7 @@ def annotations():
171181
with self.modifiers(Qt.ShiftModifier):
172182
graph.select(points[5:10])
173183
np.testing.assert_equal(selectedx(), x[:10])
184+
np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5))
174185
sel_column[5:10] = 2
175186
np.testing.assert_equal(annotated(), sel_column)
176187
self.assertEqual(len(annotations()), 3)
@@ -180,12 +191,14 @@ def annotations():
180191
sel_column = np.zeros((len(self.data), 1))
181192
sel_column[15:20] = 1
182193
np.testing.assert_equal(selectedx(), x[15:20])
194+
np.testing.assert_equal(selected_groups(), np.zeros(5))
183195
self.assertEqual(annotations(), ["No", "Yes"])
184196

185197
# Alt-select (remove) 10:17; we have 17:20
186198
with self.modifiers(Qt.AltModifier):
187199
graph.select(points[10:17])
188200
np.testing.assert_equal(selectedx(), x[17:20])
201+
np.testing.assert_equal(selected_groups(), np.zeros(3))
189202
sel_column[15:17] = 0
190203
np.testing.assert_equal(annotated(), sel_column)
191204
self.assertEqual(annotations(), ["No", "Yes"])
@@ -194,6 +207,7 @@ def annotations():
194207
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
195208
graph.select(points[20:25])
196209
np.testing.assert_equal(selectedx(), x[17:25])
210+
np.testing.assert_equal(selected_groups(), np.zeros(8))
197211
sel_column[20:25] = 1
198212
np.testing.assert_equal(annotated(), sel_column)
199213
self.assertEqual(annotations(), ["No", "Yes"])
@@ -205,6 +219,7 @@ def annotations():
205219
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
206220
graph.select(points[35:40])
207221
sel_column[30:40] = 2
222+
np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10))
208223
np.testing.assert_equal(annotated(), sel_column)
209224
self.assertEqual(len(annotations()), 3)
210225

0 commit comments

Comments
 (0)