Skip to content

Commit 9274a20

Browse files
authored
Merge pull request #2152 from jerneju/value-scatterplot
[FIX] Scatter Plot: dealing with scipy sparse matrix
2 parents 8e13aeb + cceeee3 commit 9274a20

File tree

3 files changed

+78
-15
lines changed

3 files changed

+78
-15
lines changed

Orange/widgets/visualize/owscatterplot.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy.sparse as sp
23

34
from AnyQt.QtCore import Qt, QTimer
45
from AnyQt.QtGui import (
@@ -394,7 +395,8 @@ def set_subset_data(self, subset_data):
394395

395396
# called when all signals are received, so the graph is updated only once
396397
def handleNewSignals(self):
397-
self.graph.new_data(self.data_metas_X, self.subset_data)
398+
self.graph.new_data(self.sparse_to_dense(self.data_metas_X),
399+
self.sparse_to_dense(self.subset_data))
398400
if self.attribute_selection_list and \
399401
all(attr in self.graph.domain
400402
for attr in self.attribute_selection_list):
@@ -407,6 +409,37 @@ def handleNewSignals(self):
407409
self.apply_selection()
408410
self.unconditional_commit()
409411

412+
def prepare_data(self):
413+
"""
414+
Only when dealing with sparse matrices.
415+
GH-2152
416+
"""
417+
self.graph.new_data(self.sparse_to_dense(self.data_metas_X),
418+
self.sparse_to_dense(self.subset_data),
419+
new=False)
420+
421+
def sparse_to_dense(self, input_data=None):
422+
self.vizrank_button.setEnabled(not (self.data and self.data.is_sparse()))
423+
if input_data is None or not input_data.is_sparse():
424+
return input_data
425+
keys = []
426+
attrs = {self.attr_x,
427+
self.attr_y,
428+
self.graph.attr_color,
429+
self.graph.attr_shape,
430+
self.graph.attr_size,
431+
self.graph.attr_label}
432+
for i, attr in enumerate(input_data.domain):
433+
if attr in attrs:
434+
keys.append(i)
435+
new_domain = input_data.domain.select_columns(keys)
436+
dmx = Table.from_table(new_domain, input_data)
437+
dmx.X = dmx.X.toarray()
438+
# TODO: remove once we make sure Y is always dense.
439+
if sp.issparse(dmx.Y):
440+
dmx.Y = dmx.Y.toarray()
441+
return dmx
442+
410443
def apply_selection(self):
411444
"""Apply selection saved in workflow."""
412445
if self.data is not None and self.selection is not None:
@@ -441,12 +474,14 @@ def set_attr(self, attr_x, attr_y):
441474
self.update_attr()
442475

443476
def update_attr(self):
477+
self.prepare_data()
444478
self.update_graph()
445479
self.cb_class_density.setEnabled(self.graph.can_draw_density())
446480
self.cb_reg_line.setEnabled(self.graph.can_draw_regresssion_line())
447481
self.send_features()
448482

449483
def update_colors(self):
484+
self.prepare_data()
450485
self.cb_class_density.setEnabled(self.graph.can_draw_density())
451486

452487
def update_density(self):

Orange/widgets/visualize/owscatterplotgraph.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -591,17 +591,18 @@ def update_tooltip(self, modifiers):
591591
text = self.tiptexts.get(int(modifiers), self.tiptexts[0])
592592
self.tip_textitem.setHtml(text)
593593

594-
def new_data(self, data, subset_data=None, **args):
595-
self.plot_widget.clear()
596-
self.remove_legend()
594+
def new_data(self, data, subset_data=None, new=True, **args):
595+
if new:
596+
self.plot_widget.clear()
597+
self.remove_legend()
597598

598-
self.density_img = None
599-
self.scatterplot_item = None
600-
self.scatterplot_item_sel = None
601-
self.reg_line_item = None
602-
self.labels = []
603-
self.selection = None
604-
self.valid_data = None
599+
self.density_img = None
600+
self.scatterplot_item = None
601+
self.scatterplot_item_sel = None
602+
self.reg_line_item = None
603+
self.labels = []
604+
self.selection = None
605+
self.valid_data = None
605606

606607
self.subset_indices = set(e.id for e in subset_data) if subset_data else None
607608

@@ -776,13 +777,15 @@ def compute_sizes(self):
776777
return size_data
777778

778779
def update_sizes(self):
780+
self.master.prepare_data()
781+
self.update_point_size()
782+
783+
def update_point_size(self):
779784
if self.scatterplot_item:
780785
size_data = self.compute_sizes()
781786
self.scatterplot_item.setSize(size_data)
782787
self.scatterplot_item_sel.setSize(size_data + SELECTION_WIDTH)
783788

784-
update_point_size = update_sizes
785-
786789
def get_color_index(self):
787790
if self.attr_color is None:
788791
return -1
@@ -907,6 +910,9 @@ def make_pen(color, width):
907910

908911
def update_colors(self, keep_colors=False):
909912
self.master.update_colors()
913+
self.update_alpha_value(keep_colors)
914+
915+
def update_alpha_value(self, keep_colors=False):
910916
if self.scatterplot_item:
911917
pen_data, brush_data = self.compute_colors(keep_colors)
912918
pen_data_sel, brush_data_sel = self.compute_colors_sel(keep_colors)
@@ -922,8 +928,6 @@ def update_colors(self, keep_colors=False):
922928
elif self.density_img:
923929
self.plot_widget.removeItem(self.density_img)
924930

925-
update_alpha_value = update_colors
926-
927931
def create_labels(self):
928932
for x, y in zip(*self.scatterplot_item.getData()):
929933
ti = TextItem()
@@ -937,6 +941,7 @@ def update_labels(self):
937941
for label in self.labels:
938942
label.setText("")
939943
return
944+
self.assure_attribute_present(self.attr_label)
940945
if not self.labels:
941946
self.create_labels()
942947
label_column = self.data.get_column_view(self.attr_label)[0]
@@ -972,11 +977,16 @@ def compute_symbols(self):
972977
return shape_data
973978

974979
def update_shapes(self):
980+
self.assure_attribute_present(self.attr_shape)
975981
if self.scatterplot_item:
976982
shape_data = self.compute_symbols()
977983
self.scatterplot_item.setSymbol(shape_data)
978984
self.make_legend()
979985

986+
def assure_attribute_present(self, attr):
987+
if attr not in self.data.domain:
988+
self.master.prepare_data()
989+
980990
def update_grid(self):
981991
self.plot_widget.showGrid(x=self.show_grid, y=self.show_grid)
982992

Orange/widgets/visualize/tests/test_owscatterplot.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# pylint: disable=missing-docstring
33
from unittest.mock import MagicMock
44
import numpy as np
5+
import scipy.sparse as sp
56

67
from AnyQt.QtCore import QRectF, Qt
78

@@ -268,6 +269,23 @@ def test_set_strings_settings(self):
268269
self.assertEqual(w.graph.attr_shape.name, "iris")
269270
self.assertEqual(w.graph.attr_size.name, "petal width")
270271

272+
def test_sparse(self):
273+
"""
274+
Test sparse data.
275+
GH-2152
276+
GH-2157
277+
"""
278+
table = Table("iris")
279+
table.X = sp.csr_matrix(table.X)
280+
self.assertTrue(sp.issparse(table.X))
281+
table.Y = sp.csr_matrix(table._Y) # pylint: disable=protected-access
282+
self.assertTrue(sp.issparse(table.Y))
283+
self.send_signal("Data", table)
284+
self.widget.set_subset_data(table[:30])
285+
data = self.get_output("Data")
286+
self.assertTrue(data.is_sparse())
287+
self.assertEqual(len(data.domain), 5)
288+
271289

272290
if __name__ == "__main__":
273291
import unittest

0 commit comments

Comments
 (0)