Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class Outputs:
annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)
features = Output("Features", Table, dynamic=False)

settings_version = 2
settingsHandler = DomainContextHandler()

auto_send_selection = Setting(True)
Expand All @@ -125,7 +126,7 @@ class Outputs:

attr_x = ContextSetting(None)
attr_y = ContextSetting(None)
selection = Setting(None, schema_only=True)
selection_group = Setting(None, schema_only=True)

graph = SettingProvider(OWScatterPlotGraph)

Expand Down Expand Up @@ -399,10 +400,11 @@ def sparse_to_dense(self, input_data=None):

def apply_selection(self):
"""Apply selection saved in workflow."""
if self.data is not None and self.selection is not None:
if self.data is not None and self.selection_group is not None:
self.graph.selection = np.zeros(len(self.data), dtype=np.uint8)
self.selection = [x for x in self.selection if x < len(self.data)]
self.graph.selection[self.selection] = 1
self.selection_group = [x for x in self.selection_group if x[0] < len(self.data)]
selection_array = np.array(self.selection_group).T
self.graph.selection[selection_array[0]] = selection_array[1]
self.graph.update_colors(keep_colors=True)

@Inputs.features
Expand Down Expand Up @@ -476,8 +478,10 @@ def send_data(self):
self.Outputs.annotated_data.send(annotated)

# Store current selection in a setting that is stored in workflow
if self.selection is not None and len(selection):
self.selection = list(selection)
if selection is not None and len(selection):
self.selection_group = list(zip(selection, graph.selection[selection]))
else:
self.selection_group = None

def send_features(self):
features = None
Expand Down Expand Up @@ -518,6 +522,11 @@ def onDeleteWidget(self):
self.graph.plot_widget.getViewBox().deleteLater()
self.graph.plot_widget.clear()

@classmethod
def migrate_settings(cls, settings, version):
if version < 2 and "selection" in settings and settings["selection"]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if version < 2 and settings.get(selection):

settings["selection_group"] = [(a, 1) for a in settings["selection"]]


def main(argv=None):
import sys
Expand Down
16 changes: 14 additions & 2 deletions Orange/widgets/visualize/tests/test_owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,16 @@ def test_none_data(self):
self.send_signal(self.widget.Inputs.data, table)
self.widget.reset_graph_data()

def test_saving_selection(self):
self.send_signal(self.widget.Inputs.data, self.data) # iris
self.widget.graph.select_by_rectangle(QRectF(4, 3, 3, 1))
selected_inds = np.flatnonzero(self.widget.graph.selection)
settings = self.widget.settingsHandler.pack_data(self.widget)
np.testing.assert_equal(selected_inds, [i for i, g in settings["selection_group"]])

def test_points_selection(self):
# Opening widget with saved selection should restore it
self.widget.selection = list(range(50))
self.widget.selection_group = [(i, 1) for i in range(50)]
self.send_signal(self.widget.Inputs.data, self.data) # iris
selected_data = self.get_output(self.widget.Outputs.selected_data)
self.assertEqual(len(selected_data), 50)
Expand All @@ -238,10 +245,15 @@ def test_points_selection(self):
selected_data = self.get_output(self.widget.Outputs.selected_data)
self.assertIsNone(selected_data)

def test_migrate_selection(self):
settings = dict(selection=list(range(2)))
OWScatterPlot.migrate_settings(settings, 0)
self.assertEqual(settings["selection_group"], [(0, 1), (1, 1)])

def test_invalid_points_selection(self):
# if selection contains rows that are not present in the current
# dataset, widget should select what can be selected.
self.widget.selection = list(range(50))
self.widget.selection_group = [(i, 1) for i in range(50)]
self.send_signal(self.widget.Inputs.data, self.data[:10])
selected_data = self.get_output(self.widget.Outputs.selected_data)
self.assertEqual(len(selected_data), 10)
Expand Down