Skip to content

Commit 5524ed9

Browse files
authored
Merge pull request #2245 from nikicc/edit-domain-sparse
[FIX] Support Sparse Data in Domain Editor
2 parents 1d37f63 + bc60360 commit 5524ed9

File tree

4 files changed

+82
-21
lines changed

4 files changed

+82
-21
lines changed

Orange/statistics/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def nanmean(x):
290290
return np.nansum(x.data) / n_values
291291

292292

293-
def unique(x, return_counts=True):
293+
def unique(x, return_counts=False):
294294
""" Equivalent of np.unique that supports sparse or dense matrices. """
295295
if not sp.issparse(x):
296296
return np.unique(x, return_counts=return_counts)

Orange/widgets/data/owfile.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,6 @@ def apply_domain_edit(self):
403403
if self.data is not None:
404404
domain, cols = self.domain_editor.get_domain(self.data.domain, self.data)
405405
X, y, m = cols
406-
X = np.array(X).T if len(X) else np.empty((len(self.data), 0))
407-
y = np.array(y).T if len(y) else None
408-
dtpe = object if any(isinstance(m, StringVariable)
409-
for m in domain.metas) else float
410-
m = np.array(m, dtype=dtpe).T if len(m) else None
411406
table = Table.from_numpy(domain, X, y, m, self.data.W)
412407
table.name = self.data.name
413408
table.ids = np.array(self.data.ids)

Orange/widgets/data/tests/test_owfile.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
# pylint: disable=missing-docstring
33
from os import path, remove
44
from unittest.mock import Mock
5+
import pickle
6+
import tempfile
7+
58

69
import numpy as np
10+
import scipy.sparse as sp
711

812
from AnyQt.QtCore import QMimeData, QPoint, Qt, QUrl
913
from AnyQt.QtGui import QDragEnterEvent, QDropEvent
@@ -195,3 +199,19 @@ def test_check_datetime_disabled(self):
195199
for i in range(4):
196200
vartype_delegate.setEditorData(combo, idx(i))
197201
self.assertEqual(combo.count(), counts[i])
202+
203+
def test_domain_edit_on_sparse_data(self):
204+
iris = Table("iris")
205+
iris.X = sp.csr_matrix(iris.X)
206+
207+
f = tempfile.NamedTemporaryFile(suffix='.pickle', delete=False)
208+
pickle.dump(iris, f)
209+
f.close()
210+
211+
self.widget.add_path(f.name)
212+
self.widget.load_data()
213+
214+
output = self.get_output("Data")
215+
self.assertIsInstance(output, Table)
216+
self.assertEqual(iris.X.shape, output.X.shape)
217+
self.assertTrue(sp.issparse(output.X))

Orange/widgets/utils/domaineditor.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from itertools import chain
22

33
import numpy as np
4+
import scipy.sparse as sp
45

56
from AnyQt.QtCore import Qt, QAbstractTableModel
67
from AnyQt.QtGui import QColor
78
from AnyQt.QtWidgets import QComboBox, QTableView, QSizePolicy
89

910
from Orange.data import DiscreteVariable, ContinuousVariable, StringVariable, \
1011
TimeVariable, Domain
12+
from Orange.statistics.util import unique
1113
from Orange.widgets import gui
1214
from Orange.widgets.gui import HorizontalGridDelegate
1315
from Orange.widgets.settings import ContextSetting
@@ -196,6 +198,37 @@ def __init__(self, widget):
196198
self.place_delegate = PlaceDelegate(self, VarTableModel.places)
197199
self.setItemDelegateForColumn(Column.place, self.place_delegate)
198200

201+
@staticmethod
202+
def _is_missing(x):
203+
return str(x) in ("nan", "")
204+
205+
@staticmethod
206+
def _iter_vals(x):
207+
"""Iterate over values of sparse or dense arrays."""
208+
for i in range(x.shape[0]):
209+
yield x[i, 0]
210+
211+
@staticmethod
212+
def _to_column(x, to_sparse, dtype=None):
213+
"""Transform list of values to sparse/dense column array."""
214+
x = np.array(x, dtype=dtype).reshape(-1, 1)
215+
if to_sparse:
216+
x = sp.csc_matrix(x)
217+
return x
218+
219+
@staticmethod
220+
def _merge(cols, force_dense=False):
221+
if len(cols) == 0:
222+
return None
223+
224+
all_dense = not any(sp.issparse(c) for c in cols)
225+
if all_dense:
226+
return np.hstack(cols)
227+
if force_dense:
228+
return np.hstack([c.toarray() if sp.issparse(c) else c for c in cols])
229+
sparse_cols = [c if sp.issparse(c) else sp.csc_matrix(c) for c in cols]
230+
return sp.hstack(sparse_cols).tocsr()
231+
199232
def get_domain(self, domain, data):
200233
"""Create domain (and dataset) from changes made in the widget.
201234
@@ -212,44 +245,57 @@ def get_domain(self, domain, data):
212245
places = [[], [], []] # attributes, class_vars, metas
213246
cols = [[], [], []] # Xcols, Ycols, Mcols
214247

215-
def is_missing(x):
216-
return str(x) in ("nan", "")
217-
218248
for (name, tpe, place, _, _), (orig_var, orig_plc) in \
219249
zip(variables,
220250
chain([(at, Place.feature) for at in domain.attributes],
221251
[(cl, Place.class_var) for cl in domain.class_vars],
222252
[(mt, Place.meta) for mt in domain.metas])):
223253
if place == Place.skip:
224254
continue
225-
if orig_plc == Place.meta:
226-
col_data = data[:, orig_var].metas
227-
elif orig_plc == Place.class_var:
228-
col_data = data[:, orig_var].Y
229-
else:
230-
col_data = data[:, orig_var].X
231-
col_data = col_data.ravel()
255+
256+
col_data = self._get_column(data, orig_var, orig_plc)
257+
is_sparse = sp.issparse(col_data)
232258
if name == orig_var.name and tpe == type(orig_var):
233259
var = orig_var
234260
elif tpe == type(orig_var):
235261
# change the name so that all_vars will get the correct name
236262
orig_var.name = name
237263
var = orig_var
238264
elif tpe == DiscreteVariable:
239-
values = list(str(i) for i in np.unique(col_data) if not is_missing(i))
265+
values = list(str(i) for i in unique(col_data) if not self._is_missing(i))
240266
var = tpe(name, values)
241-
col_data = [np.nan if is_missing(x) else values.index(str(x))
242-
for x in col_data]
267+
col_data = [np.nan if self._is_missing(x) else values.index(str(x))
268+
for x in self._iter_vals(col_data)]
269+
col_data = self._to_column(col_data, is_sparse)
243270
elif tpe == StringVariable and type(orig_var) == DiscreteVariable:
244271
var = tpe(name)
245272
col_data = [orig_var.repr_val(x) if not np.isnan(x) else ""
246-
for x in col_data]
273+
for x in self._iter_vals(col_data)]
274+
# don't obey sparsity for StringVariable since they are
275+
# in metas which are transformed to dense below
276+
col_data = self._to_column(col_data, False, dtype=object)
247277
else:
248278
var = tpe(name)
249279
places[place].append(var)
250280
cols[place].append(col_data)
281+
282+
# merge columns for X, Y and metas
283+
feats = cols[Place.feature]
284+
X = self._merge(feats) if len(feats) else np.empty((len(data), 0))
285+
Y = self._merge(cols[Place.class_var], force_dense=True)
286+
m = self._merge(cols[Place.meta], force_dense=True)
251287
domain = Domain(*places)
252-
return domain, cols
288+
return domain, [X, Y, m]
289+
290+
def _get_column(self, data, source_var, source_place):
291+
""" Extract column from data and preserve sparsity. """
292+
if source_place == Place.meta:
293+
col_data = data[:, source_var].metas
294+
elif source_place == Place.class_var:
295+
col_data = data[:, source_var].Y.reshape(-1, 1)
296+
else:
297+
col_data = data[:, source_var].X
298+
return col_data
253299

254300
def set_domain(self, domain):
255301
self.variables = self.parse_domain(domain)

0 commit comments

Comments
 (0)