Skip to content

Commit 6e8eab1

Browse files
authored
Merge pull request #1456 from nikicc/fix-ensure-copy
[FIX] Table: Fix ensure_copy for sparse matrices
2 parents dc8cc2c + 75bbe05 commit 6e8eab1

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

Orange/data/table.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -912,15 +912,20 @@ def is_copy(self):
912912

913913
def ensure_copy(self):
914914
"""
915-
Ensure that the table owns its data; copy arrays when necessary
915+
Ensure that the table owns its data; copy arrays when necessary.
916916
"""
917-
if self.X.base is not None:
917+
def is_view(x):
918+
# Sparse matrices don't have views like numpy arrays. Since indexing on
919+
# them creates copies in constructor we can skip this check here.
920+
return not sp.issparse(x) and x.base is not None
921+
922+
if is_view(self.X):
918923
self.X = self.X.copy()
919-
if self._Y.base is not None:
924+
if is_view(self._Y):
920925
self._Y = self._Y.copy()
921-
if self.metas.base is not None:
926+
if is_view(self.metas):
922927
self.metas = self.metas.copy()
923-
if self.W.base is not None:
928+
if is_view(self.W):
924929
self.W = self.W.copy()
925930

926931
def copy(self):

Orange/tests/test_table.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -649,6 +649,19 @@ def test_copy(self):
649649
self.assertFalse(np.all(t.Y == copy.Y))
650650
self.assertFalse(np.all(t.metas == copy.metas))
651651

652+
def test_copy_sparse(self):
653+
t = data.Table('iris')
654+
t.X = csr_matrix(t.X)
655+
copy = t.copy()
656+
657+
self.assertEqual((t.X != copy.X).nnz, 0) # sparse matrices match by content
658+
np.testing.assert_equal(t.Y, copy.Y)
659+
np.testing.assert_equal(t.metas, copy.metas)
660+
661+
self.assertNotEqual(id(t.X), id(copy.X))
662+
self.assertNotEqual(id(t._Y), id(copy._Y))
663+
self.assertNotEqual(id(t.metas), id(copy.metas))
664+
652665
def test_concatenate(self):
653666
d1 = data.Domain([data.ContinuousVariable('a1')])
654667
t1 = data.Table.from_numpy(d1, [[1],

0 commit comments

Comments
 (0)