Skip to content

Commit c54a56a

Browse files
committed
OWFile: Add test for loading sparse data
1 parent 21b2504 commit c54a56a

File tree

2 files changed

+56
-33
lines changed

2 files changed

+56
-33
lines changed

Orange/widgets/data/tests/test_owfile.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,12 @@
22
# pylint: disable=missing-docstring
33
from os import path, remove
44
from unittest.mock import Mock
5+
import pickle
6+
import tempfile
7+
58

69
import numpy as np
10+
import scipy.sparse as sp
711

812
from AnyQt.QtCore import QMimeData, QPoint, Qt, QUrl
913
from AnyQt.QtGui import QDragEnterEvent, QDropEvent
@@ -195,3 +199,19 @@ def test_check_datetime_disabled(self):
195199
for i in range(4):
196200
vartype_delegate.setEditorData(combo, idx(i))
197201
self.assertEqual(combo.count(), counts[i])
202+
203+
def test_domain_edit_on_sparse_data(self):
    """Check that a pickled table with a sparse ``X`` loads as sparse.

    The iris table's ``X`` is converted to CSR, pickled to a temporary
    file and loaded through the widget.  The "Data" output must be a
    ``Table`` whose ``X`` has the same shape and is still sparse.
    """
    iris = Table("iris")
    iris.X = sp.csr_matrix(iris.X)

    # delete=False keeps the file on disk after it is closed so that the
    # widget can reopen it by name; it is removed explicitly below.
    f = tempfile.NamedTemporaryFile(suffix='.pickle', delete=False)
    try:
        with f:
            pickle.dump(iris, f)

        self.widget.add_path(f.name)
        self.widget.load_data()

        output = self.get_output("Data")
        self.assertIsInstance(output, Table)
        self.assertEqual(iris.X.shape, output.X.shape)
        self.assertTrue(sp.issparse(output.X))
    finally:
        # Do not leave the temporary pickle behind, even if the test fails.
        remove(f.name)

Orange/widgets/utils/domaineditor.py

Lines changed: 36 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,35 @@ def __init__(self, widget):
198198
self.place_delegate = PlaceDelegate(self, VarTableModel.places)
199199
self.setItemDelegateForColumn(Column.place, self.place_delegate)
200200

201+
@staticmethod
def _iter_vals(x):
    """Yield the first-column value of every row of ``x``.

    Values are read with two-dimensional ``x[row, 0]`` indexing, so the
    same code handles sparse and dense arrays.
    """
    yield from (x[row, 0] for row in range(x.shape[0]))
206+
207+
@staticmethod
def _to_column(x, to_sparse, dtype=None):
    """Turn a list of values into a single-column sparse/dense array.

    The values are converted with ``np.array(x, dtype=dtype)`` and
    reshaped to ``(-1, 1)``.  When ``to_sparse`` is true the column is
    returned as a ``scipy.sparse.csc_matrix``; otherwise the dense array
    is returned.

    Raises ``ValueError`` if ``to_sparse`` is true and ``dtype`` is given.
    """
    column = np.array(x, dtype=dtype).reshape(-1, 1)
    if not to_sparse:
        return column
    if dtype is not None:
        raise ValueError('Cannot set dtype on sparse matrix.')
    return sp.csc_matrix(column)
216+
217+
@staticmethod
def _merge(cols, assure_dense=False):
    """Horizontally stack column blocks into one matrix.

    Returns ``None`` when ``cols`` is empty.  With ``assure_dense`` set,
    sparse blocks are first converted with ``toarray()``.  If no block is
    sparse at that point the result is ``np.hstack(cols)``; otherwise any
    dense blocks are wrapped in ``csc_matrix`` and the result is
    ``sp.hstack(...)`` converted to CSR.
    """
    if len(cols) == 0:
        return None

    def has_sparse(blocks):
        return any(sp.issparse(block) for block in blocks)

    if assure_dense and has_sparse(cols):
        cols = [block.toarray() if sp.issparse(block) else block
                for block in cols]

    if has_sparse(cols):
        # Mixed input: promote the dense blocks so sp.hstack sees only
        # sparse operands.
        promoted = [block if sp.issparse(block) else sp.csc_matrix(block)
                    for block in cols]
        return sp.hstack(promoted).tocsr()
    return np.hstack(cols)
229+
201230
def get_domain(self, domain, data):
202231
"""Create domain (and dataset) from changes made in the widget.
203232
@@ -217,20 +246,6 @@ def get_domain(self, domain, data):
217246
def is_missing(x):
218247
return str(x) in ("nan", "")
219248

220-
def iter_vals(x):
221-
"""Iterate over values of sparse or dense arrays."""
222-
for i in range(x.shape[0]):
223-
yield x[i, 0]
224-
225-
def to_column(x, to_sparse, dtype=None):
226-
"""Transform list of values to sparse/dense column array."""
227-
x = np.array(x, dtype=dtype).reshape(-1, 1)
228-
if to_sparse:
229-
if dtype is not None:
230-
raise ValueError('Cannot set dtype on sparse matrix.')
231-
x = sp.csc_matrix(x)
232-
return x
233-
234249
for (name, tpe, place, _, _), (orig_var, orig_plc) in \
235250
zip(variables,
236251
chain([(at, Place.feature) for at in domain.attributes],
@@ -255,35 +270,23 @@ def to_column(x, to_sparse, dtype=None):
255270
values = list(str(i) for i in unique(col_data) if not is_missing(i))
256271
var = tpe(name, values)
257272
col_data = [np.nan if is_missing(x) else values.index(str(x))
258-
for x in iter_vals(col_data)]
259-
col_data = to_column(col_data, is_sparse)
273+
for x in self._iter_vals(col_data)]
274+
col_data = self._to_column(col_data, is_sparse)
260275
elif tpe == StringVariable and type(orig_var) == DiscreteVariable:
261276
var = tpe(name)
262277
col_data = [orig_var.repr_val(x) if not np.isnan(x) else ""
263-
for x in iter_vals(col_data)]
264-
col_data = to_column(col_data, is_sparse, dtype=object)
278+
for x in self._iter_vals(col_data)]
279+
col_data = self._to_column(col_data, is_sparse, dtype=object)
265280
else:
266281
var = tpe(name)
267282
places[place].append(var)
268283
cols[place].append(col_data)
269284

270285
# merge columns for X, Y and metas
271-
def merge(cols, assure_dense=False):
272-
if len(cols) == 0:
273-
return None
274-
if assure_dense and any(sp.issparse(c) for c in cols):
275-
cols = [c.toarray() if sp.issparse(c) else c for c in cols]
276-
if not any(sp.issparse(c) for c in cols):
277-
return np.hstack(cols)
278-
if not all(sp.issparse(c) for c in cols):
279-
cols = [c if sp.issparse(c) else sp.csc_matrix(c)
280-
for c in cols]
281-
return sp.hstack(cols).tocsr()
282-
283286
feats = cols[Place.feature]
284-
X = merge(feats) if feats else np.empty((len(data), 0))
285-
Y = merge(cols[Place.class_var], assure_dense=True)
286-
m = merge(cols[Place.meta], assure_dense=True)
287+
X = self._merge(feats) if feats else np.empty((len(data), 0))
288+
Y = self._merge(cols[Place.class_var], assure_dense=True)
289+
m = self._merge(cols[Place.meta], assure_dense=True)
287290
domain = Domain(*places)
288291
return domain, [X, Y, m]
289292

0 commit comments

Comments
 (0)