Skip to content

Commit 554b885

Browse files
authored
Merge pull request #2305 from jerneju/sparse-merge
[FIX] Merge: work with sparse
2 parents d1352ea + e461dc3 commit 554b885

File tree

3 files changed: +55 -18 lines changed

3 files changed: +55 -18 lines changed

Orange/data/table.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -57,15 +57,15 @@ def __init__(self, table, row_index):
5757
self.id = table.ids[row_index]
5858
self._x = table.X[row_index]
5959
if sp.issparse(self._x):
60-
self.sparse_x = self._x
60+
self.sparse_x = sp.csr_matrix(self._x)
6161
self._x = np.asarray(self._x.todense())[0]
6262
self._y = table._Y[row_index]
6363
if sp.issparse(self._y):
64-
self.sparse_y = self._y
64+
self.sparse_y = sp.csr_matrix(self._y)
6565
self._y = np.asarray(self._y.todense())[0]
6666
self._metas = table.metas[row_index]
6767
if sp.issparse(self._metas):
68-
self.sparse_metas = self._metas
68+
self.sparse_metas = sp.csr_matrix(self._metas)
6969
self._metas = np.asarray(self._metas.todense())[0]
7070

7171
@property

Orange/widgets/data/owmergedata.py

Lines changed: 26 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,11 @@
44
from AnyQt.QtWidgets import QApplication, QStyle, QSizePolicy
55

66
import numpy as np
7+
import scipy.sparse as sp
78

89
import Orange
910
from Orange.data import StringVariable, ContinuousVariable
11+
from Orange.data.util import hstack
1012
from Orange.widgets import widget, gui, settings
1113
from Orange.widgets.utils import itemmodels
1214
from Orange.widgets.utils.sql import check_sql_input
@@ -362,20 +364,29 @@ def _join_table_by_indices(self, reduced_extra, indices):
362364
def _join_array_by_indices(left, right, indices, string_cols=None):
363365
"""Join (horizontally) two arrays, taking pairs of rows given in indices
364366
"""
365-
tpe = object if object in (left.dtype, right.dtype) else left.dtype
366-
left_width, right_width = left.shape[1], right.shape[1]
367-
arr = np.full((indices.shape[1], left_width + right_width), np.nan, tpe)
368-
if string_cols:
369-
arr[:, string_cols] = ""
370-
for indices, to_change, lookup in (
371-
(indices[0], arr[:, :left_width], left),
372-
(indices[1], arr[:, left_width:], right)):
373-
known = indices != -1
374-
to_change[known] = lookup[indices[known]]
375-
return arr
376-
377-
378-
def test():
367+
def prepare(arr, inds, str_cols):
368+
try:
369+
newarr = arr[inds]
370+
except IndexError:
371+
newarr = np.full_like(arr, np.nan)
372+
else:
373+
empty = np.full(arr.shape[1], np.nan)
374+
if str_cols:
375+
assert arr.dtype == object
376+
empty = empty.astype(object)
377+
empty[str_cols] = ''
378+
newarr[inds == -1] = empty
379+
return newarr
380+
381+
left_width = left.shape[1]
382+
str_left = [i for i in string_cols or () if i < left_width]
383+
str_right = [i - left_width for i in string_cols or () if i >= left_width]
384+
res = hstack((prepare(left, indices[0], str_left),
385+
prepare(right, indices[1], str_right)))
386+
return res
387+
388+
389+
def main():
379390
app = QApplication([])
380391
w = OWMergeData()
381392
data = Orange.data.Table("tests/data-gender-region")
@@ -388,4 +399,4 @@ def test():
388399

389400

390401
if __name__ == "__main__":
391-
test()
402+
main()

Orange/widgets/data/tests/test_owmergedata.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from itertools import chain
44

55
import numpy as np
6+
import scipy.sparse as sp
67

78
from Orange.data import Table, Domain, DiscreteVariable, StringVariable
89
from Orange.widgets.data.owmergedata import OWMergeData, INSTANCEID, INDEX
@@ -425,3 +426,28 @@ def test_best_match(self):
425426
self.assertEqual(self.widget.attr_merge_extra, zoo_images.domain[-1])
426427
self.assertEqual(self.widget.attr_combine_data, zoo.domain[-1])
427428
self.assertEqual(self.widget.attr_combine_extra, zoo_images.domain[-1])
429+
430+
def test_sparse(self):
431+
"""
432+
Merge should work with sparse.
433+
GH-2295
434+
GH-2155
435+
"""
436+
data = Table("iris")[::25]
437+
data_ed_dense = Table("titanic")[::300]
438+
data_ed_sparse = Table("titanic")[::300]
439+
data_ed_sparse.X = sp.csr_matrix(data_ed_sparse.X)
440+
self.send_signal("Data", data)
441+
442+
self.send_signal("Extra Data", data_ed_dense)
443+
output_dense = self.get_output("Data")
444+
self.assertFalse(sp.issparse(output_dense.X))
445+
self.assertFalse(output_dense.is_sparse())
446+
447+
self.send_signal("Extra Data", data_ed_sparse)
448+
output_sparse = self.get_output("Data")
449+
self.assertTrue(sp.issparse(output_sparse.X))
450+
self.assertTrue(output_sparse.is_sparse())
451+
452+
output_sparse.X = output_sparse.X.toarray()
453+
self.assertTablesEqual(output_dense, output_sparse)

0 commit comments

Comments
 (0)