Skip to content

Commit ca1a0bd

Browse files
committed
OWScatterPlot: Add column with groups to selected data output
1 parent 6afb4cf commit ca1a0bd

File tree

3 files changed

+46
-30
lines changed

3 files changed

+46
-30
lines changed

Orange/widgets/utils/annotated_data.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,22 @@ def create_annotated_table(data, selected_indices):
103103
return table
104104

105105

106-
def create_groups_table(data, selection):
106+
def create_groups_table(data, selection,
107+
include_unselected=True,
108+
var_name=ANNOTATED_DATA_FEATURE_NAME):
107109
if data is None:
108110
return None
109-
names = [var.name for var in data.domain.variables + data.domain.metas]
110-
name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME)
111-
metas = data.domain.metas + (
112-
DiscreteVariable(
113-
name,
114-
["Unselected"] + ["G{}".format(i + 1)
115-
for i in range(np.max(selection))]),
116-
)
111+
values = ["G{}".format(i + 1) for i in range(np.max(selection))]
112+
if include_unselected:
113+
values.insert(0, "Unselected")
114+
else:
115+
mask = np.flatnonzero(selection)
116+
data = data[mask]
117+
selection = selection[mask] - 1
118+
119+
var_name = get_next_name(data.domain, var_name)
120+
metas = data.domain.metas + (DiscreteVariable(var_name, values), )
117121
domain = Domain(data.domain.attributes, data.domain.class_vars, metas)
118122
table = data.transform(domain)
119-
table.metas[:, len(data.domain.metas):] = \
120-
selection.reshape(len(data), 1)
123+
table.metas[:, len(data.domain.metas):] = selection.reshape(len(data), 1)
121124
return table

Orange/widgets/visualize/owscatterplot.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph
2222
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
2323
from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output
24-
from Orange.widgets.utils.annotated_data import (create_annotated_table,
25-
ANNOTATED_DATA_SIGNAL_NAME,
26-
create_groups_table)
24+
from Orange.widgets.utils.annotated_data import (
25+
create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME,
26+
get_next_name)
2727

2828

2929
class ScatterPlotVizRank(VizRankDialogAttrPair):
@@ -428,25 +428,25 @@ def selection_changed(self):
428428
self.commit()
429429

430430
def send_data(self):
431-
selected = None
432-
selection = None
433-
# TODO: Implement selection for sql data
431+
def _get_selected():
432+
if not len(selection):
433+
return None
434+
return create_groups_table(data, graph.selection, False, "Group")
435+
436+
def _get_annotated():
437+
if graph.selection is not None and np.max(graph.selection) > 1:
438+
return create_groups_table(data, graph.selection)
439+
else:
440+
return create_annotated_table(data, selection)
441+
434442
graph = self.graph
435-
if isinstance(self.data, SqlTable):
436-
selected = self.data
437-
elif self.data is not None:
438-
selection = graph.get_selection()
439-
if len(selection) > 0:
440-
selected = self.data[selection]
441-
if graph.selection is not None and np.max(graph.selection) > 1:
442-
annotated = create_groups_table(self.data, graph.selection)
443-
else:
444-
annotated = create_annotated_table(self.data, selection)
445-
self.Outputs.selected_data.send(selected)
446-
self.Outputs.annotated_data.send(annotated)
443+
data = self.data
444+
selection = graph.get_selection()
445+
self.Outputs.annotated_data.send(_get_annotated())
446+
self.Outputs.selected_data.send(_get_selected())
447447

448448
# Store current selection in a setting that is stored in workflow
449-
if selection is not None and len(selection):
449+
if len(selection):
450450
self.selection_group = list(zip(selection, graph.selection[selection]))
451451
else:
452452
self.selection_group = None

Orange/widgets/visualize/tests/test_owscatterplot.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin):
1818
def setUpClass(cls):
1919
super().setUpClass()
2020
WidgetOutputsTestMixin.init(cls)
21+
cls.same_input_output_domain = False
2122

2223
cls.signal_name = "Data"
2324
cls.signal_data = cls.data
2425

2526
def setUp(self):
2627
self.widget = self.create_widget(OWScatterPlot)
2728

29+
def _compare_selected_annotated_domains(self, selected, annotated):
30+
pass
31+
2832
def test_set_data(self):
2933
# Connect iris to scatter plot
3034
self.send_signal(self.widget.Inputs.data, self.data)
@@ -154,6 +158,9 @@ def test_group_selections(self):
154158
def selectedx():
155159
return self.get_output(self.widget.Outputs.selected_data).X
156160

161+
def selected_groups():
162+
return self.get_output(self.widget.Outputs.selected_data).metas[:, 0]
163+
157164
def annotated():
158165
return self.get_output(self.widget.Outputs.annotated_data).metas
159166

@@ -163,6 +170,7 @@ def annotations():
163170
# Select 0:5
164171
graph.select(points[:5])
165172
np.testing.assert_equal(selectedx(), x[:5])
173+
np.testing.assert_equal(selected_groups(), np.zeros(5))
166174
sel_column[:5] = 1
167175
np.testing.assert_equal(annotated(), sel_column)
168176
self.assertEqual(annotations(), ["No", "Yes"])
@@ -171,6 +179,7 @@ def annotations():
171179
with self.modifiers(Qt.ShiftModifier):
172180
graph.select(points[5:10])
173181
np.testing.assert_equal(selectedx(), x[:10])
182+
np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5))
174183
sel_column[5:10] = 2
175184
np.testing.assert_equal(annotated(), sel_column)
176185
self.assertEqual(len(annotations()), 3)
@@ -180,12 +189,14 @@ def annotations():
180189
sel_column = np.zeros((len(self.data), 1))
181190
sel_column[15:20] = 1
182191
np.testing.assert_equal(selectedx(), x[15:20])
192+
np.testing.assert_equal(selected_groups(), np.zeros(5))
183193
self.assertEqual(annotations(), ["No", "Yes"])
184194

185195
# Alt-select (remove) 10:17; we have 17:20
186196
with self.modifiers(Qt.AltModifier):
187197
graph.select(points[10:17])
188198
np.testing.assert_equal(selectedx(), x[17:20])
199+
np.testing.assert_equal(selected_groups(), np.zeros(3))
189200
sel_column[15:17] = 0
190201
np.testing.assert_equal(annotated(), sel_column)
191202
self.assertEqual(annotations(), ["No", "Yes"])
@@ -194,6 +205,7 @@ def annotations():
194205
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
195206
graph.select(points[20:25])
196207
np.testing.assert_equal(selectedx(), x[17:25])
208+
np.testing.assert_equal(selected_groups(), np.zeros(8))
197209
sel_column[20:25] = 1
198210
np.testing.assert_equal(annotated(), sel_column)
199211
self.assertEqual(annotations(), ["No", "Yes"])
@@ -205,6 +217,7 @@ def annotations():
205217
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
206218
graph.select(points[35:40])
207219
sel_column[30:40] = 2
220+
np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10))
208221
np.testing.assert_equal(annotated(), sel_column)
209222
self.assertEqual(len(annotations()), 3)
210223

0 commit comments

Comments
 (0)