Skip to content

Commit e5f5128

Browse files
committed
OWScatterPlot: Add column with groups to selected data output
1 parent 962f4c2 commit e5f5128

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

Orange/widgets/utils/annotated_data.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,22 @@ def create_annotated_table(data, selected_indices):
103103
return table
104104

105105

106-
def create_groups_table(data, selection):
106+
def create_groups_table(data, selection,
107+
include_unselected=True,
108+
var_name=ANNOTATED_DATA_FEATURE_NAME):
107109
if data is None:
108110
return None
109-
names = [var.name for var in data.domain.variables + data.domain.metas]
110-
name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME)
111-
metas = data.domain.metas + (
112-
DiscreteVariable(
113-
name,
114-
["Unselected"] + ["G{}".format(i + 1)
115-
for i in range(np.max(selection))]),
116-
)
111+
values = ["G{}".format(i + 1) for i in range(np.max(selection))]
112+
if include_unselected:
113+
values.insert(0, "Unselected")
114+
else:
115+
mask = np.flatnonzero(selection)
116+
data = data[mask]
117+
selection = selection[mask] - 1
118+
119+
var_name = get_next_name(data.domain, var_name)
120+
metas = data.domain.metas + (DiscreteVariable(var_name, values), )
117121
domain = Domain(data.domain.attributes, data.domain.class_vars, metas)
118122
table = data.transform(domain)
119-
table.metas[:, len(data.domain.metas):] = \
120-
selection.reshape(len(data), 1)
123+
table.metas[:, len(data.domain.metas):] = selection.reshape(len(data), 1)
121124
return table

Orange/widgets/visualize/owscatterplot.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph
2222
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
2323
from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output
24-
from Orange.widgets.utils.annotated_data import (create_annotated_table,
25-
ANNOTATED_DATA_SIGNAL_NAME,
26-
create_groups_table)
24+
from Orange.widgets.utils.annotated_data import (
25+
create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME)
2726

2827

2928
class ScatterPlotVizRank(VizRankDialogAttrPair):
@@ -428,25 +427,26 @@ def selection_changed(self):
428427
self.commit()
429428

430429
def send_data(self):
431-
selected = None
432-
selection = None
433430
# TODO: Implement selection for sql data
431+
def _get_selected():
432+
if not len(selection):
433+
return None
434+
return create_groups_table(data, graph.selection, False, "Group")
435+
436+
def _get_annotated():
437+
if graph.selection is not None and np.max(graph.selection) > 1:
438+
return create_groups_table(data, graph.selection)
439+
else:
440+
return create_annotated_table(data, selection)
441+
434442
graph = self.graph
435-
if isinstance(self.data, SqlTable):
436-
selected = self.data
437-
elif self.data is not None:
438-
selection = graph.get_selection()
439-
if len(selection) > 0:
440-
selected = self.data[selection]
441-
if graph.selection is not None and np.max(graph.selection) > 1:
442-
annotated = create_groups_table(self.data, graph.selection)
443-
else:
444-
annotated = create_annotated_table(self.data, selection)
445-
self.Outputs.selected_data.send(selected)
446-
self.Outputs.annotated_data.send(annotated)
443+
data = self.data
444+
selection = graph.get_selection()
445+
self.Outputs.annotated_data.send(_get_annotated())
446+
self.Outputs.selected_data.send(_get_selected())
447447

448448
# Store current selection in a setting that is stored in workflow
449-
if selection is not None and len(selection):
449+
if len(selection):
450450
self.selection_group = list(zip(selection, graph.selection[selection]))
451451
else:
452452
self.selection_group = None

Orange/widgets/visualize/tests/test_owscatterplot.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,19 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin):
1818
def setUpClass(cls):
1919
super().setUpClass()
2020
WidgetOutputsTestMixin.init(cls)
21+
cls.same_input_output_domain = False
2122

2223
cls.signal_name = "Data"
2324
cls.signal_data = cls.data
2425

2526
def setUp(self):
2627
self.widget = self.create_widget(OWScatterPlot)
2728

29+
def _compare_selected_annotated_domains(self, selected, annotated):
30+
# Base class tests that selected.domain is a subset of annotated.domain
31+
# In scatter plot, the two domains are unrelated, so we disable the test
32+
pass
33+
2834
def test_set_data(self):
2935
# Connect iris to scatter plot
3036
self.send_signal(self.widget.Inputs.data, self.data)
@@ -154,6 +160,9 @@ def test_group_selections(self):
154160
def selectedx():
155161
return self.get_output(self.widget.Outputs.selected_data).X
156162

163+
def selected_groups():
164+
return self.get_output(self.widget.Outputs.selected_data).metas[:, 0]
165+
157166
def annotated():
158167
return self.get_output(self.widget.Outputs.annotated_data).metas
159168

@@ -163,6 +172,7 @@ def annotations():
163172
# Select 0:5
164173
graph.select(points[:5])
165174
np.testing.assert_equal(selectedx(), x[:5])
175+
np.testing.assert_equal(selected_groups(), np.zeros(5))
166176
sel_column[:5] = 1
167177
np.testing.assert_equal(annotated(), sel_column)
168178
self.assertEqual(annotations(), ["No", "Yes"])
@@ -171,6 +181,7 @@ def annotations():
171181
with self.modifiers(Qt.ShiftModifier):
172182
graph.select(points[5:10])
173183
np.testing.assert_equal(selectedx(), x[:10])
184+
np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5))
174185
sel_column[5:10] = 2
175186
np.testing.assert_equal(annotated(), sel_column)
176187
self.assertEqual(len(annotations()), 3)
@@ -180,12 +191,14 @@ def annotations():
180191
sel_column = np.zeros((len(self.data), 1))
181192
sel_column[15:20] = 1
182193
np.testing.assert_equal(selectedx(), x[15:20])
194+
np.testing.assert_equal(selected_groups(), np.zeros(5))
183195
self.assertEqual(annotations(), ["No", "Yes"])
184196

185197
# Alt-select (remove) 10:17; we have 17:20
186198
with self.modifiers(Qt.AltModifier):
187199
graph.select(points[10:17])
188200
np.testing.assert_equal(selectedx(), x[17:20])
201+
np.testing.assert_equal(selected_groups(), np.zeros(3))
189202
sel_column[15:17] = 0
190203
np.testing.assert_equal(annotated(), sel_column)
191204
self.assertEqual(annotations(), ["No", "Yes"])
@@ -194,6 +207,7 @@ def annotations():
194207
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
195208
graph.select(points[20:25])
196209
np.testing.assert_equal(selectedx(), x[17:25])
210+
np.testing.assert_equal(selected_groups(), np.zeros(8))
197211
sel_column[20:25] = 1
198212
np.testing.assert_equal(annotated(), sel_column)
199213
self.assertEqual(annotations(), ["No", "Yes"])
@@ -205,6 +219,7 @@ def annotations():
205219
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
206220
graph.select(points[35:40])
207221
sel_column[30:40] = 2
222+
np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10))
208223
np.testing.assert_equal(annotated(), sel_column)
209224
self.assertEqual(len(annotations()), 3)
210225

0 commit comments

Comments
 (0)