Skip to content

Commit 2c86913

Browse files
committed
OWDataProjectionWidget: Fix reloading sparse data
1 parent 3ce2111 commit 2c86913

File tree

4 files changed

+48
-5
lines changed

4 files changed

+48
-5
lines changed

Orange/data/util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,18 @@ def hstack(arrays):
9595
return np.hstack(arrays)
9696

9797

98+
def array_equal(a1, a2):
99+
"""array_equal that supports sparse and dense arrays with missing values"""
100+
if a1.shape != a2.shape:
101+
return False
102+
i1, j1, v1 = sp.find(a1)
103+
i2, j2, v2 = sp.find(a2)
104+
a1 = v1 if sp.issparse(a1) else a1[i2, j2]
105+
a2 = v2 if sp.issparse(a2) else a2[i1, j1]
106+
index_equal = set(zip(i1, j1)) == set(zip(i2, j2))
107+
return index_equal and np.allclose(a1, a2, equal_nan=True)
108+
109+
98110
def assure_array_dense(a):
99111
if sp.issparse(a):
100112
a = a.toarray()

Orange/tests/test_util.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from Orange.util import export_globals, flatten, deprecated, try_, deepgetattr, \
99
OrangeDeprecationWarning
1010
from Orange.data import Table
11-
from Orange.data.util import vstack, hstack
11+
from Orange.data.util import vstack, hstack, array_equal
1212
from Orange.statistics.util import stats
13+
from Orange.tests.test_statistics import dense_sparse
1314

1415
SOMETHING = 0xf00babe
1516

@@ -118,3 +119,28 @@ def test_stats_sparse(self):
118119
data = Table("iris")
119120
sparse_x = sp.csr_matrix(data.X)
120121
self.assertTrue(stats(data.X).all() == stats(sparse_x).all())
122+
123+
@dense_sparse
124+
def test_array_equal(self, array):
125+
a1 = array([[0., 2.], [3., np.nan]])
126+
a2 = array([[0., 2.], [3., np.nan]])
127+
self.assertTrue(array_equal(a1, a2))
128+
129+
a3 = np.array([[0., 2.], [3., np.nan]])
130+
self.assertTrue(array_equal(a1, a3))
131+
self.assertTrue(array_equal(a3, a1))
132+
133+
@dense_sparse
134+
def test_array_not_equal(self, array):
135+
a1 = array([[0., 2.], [3., np.nan]])
136+
a2 = array([[0., 2.], [4., np.nan]])
137+
self.assertFalse(array_equal(a1, a2))
138+
139+
a3 = array([[0., 2.], [3., np.nan], [4., 5.]])
140+
self.assertFalse(array_equal(a1, a3))
141+
142+
def test_csc_array_equal(self):
143+
a1 = sp.csc_matrix(([1, 4, 5], ([0, 0, 1], [0, 2, 2])), shape=(2, 3))
144+
a2 = sp.csc_matrix(([5, 1, 4], ([1, 0, 0], [2, 0, 2])), shape=(2, 3))
145+
a2[0, 1] = 0 # explicitly setting to 0
146+
self.assertTrue(array_equal(a1, a2))

Orange/widgets/visualize/tests/test_owprojectionwidget.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,10 @@ def test_get_coordinates_data(self):
173173
self.assertEqual(len(self.widget.get_coordinates_data()[0]), 9)
174174
self.widget.valid_data = np.zeros((10,), dtype=bool)
175175
self.assertIsNone(self.widget.get_coordinates_data()[0])
176+
177+
def test_sparse_data_reload(self):
178+
table = Table("heart_disease").to_sparse()
179+
self.widget.setup_plot = Mock()
180+
self.send_signal(self.widget.Inputs.data, table)
181+
self.send_signal(self.widget.Inputs.data, table)
182+
self.widget.setup_plot.assert_called_once()

Orange/widgets/visualize/utils/widget.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from Orange.data import (
99
Table, ContinuousVariable, Domain, Variable, StringVariable
1010
)
11-
from Orange.data.util import get_unique_names
11+
from Orange.data.util import get_unique_names, array_equal
1212
from Orange.data.sql.table import SqlTable
1313
from Orange.preprocess.preprocess import Preprocess, ApplyDomain
1414
from Orange.statistics.util import bincount
@@ -432,9 +432,7 @@ def set_data(self, data):
432432
self.use_context()
433433
self._invalidated = not (
434434
data_existed and self.data is not None and
435-
effective_data.X.shape == self.effective_data.X.shape and
436-
np.allclose(effective_data.X,
437-
self.effective_data.X, equal_nan=True))
435+
array_equal(effective_data.X, self.effective_data.X))
438436
if self._invalidated:
439437
self.clear()
440438
self.enable_controls()

0 commit comments

Comments
 (0)