Skip to content

Commit 21b2504

Browse files
committed
DomainEditor: Support sparse data
1 parent a7ca8be commit 21b2504

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

Orange/widgets/data/owfile.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,6 @@ def apply_domain_edit(self):
403403
if self.data is not None:
404404
domain, cols = self.domain_editor.get_domain(self.data.domain, self.data)
405405
X, y, m = cols
406-
X = np.array(X).T if len(X) else np.empty((len(self.data), 0))
407-
y = np.array(y).T if len(y) else None
408-
dtpe = object if any(isinstance(m, StringVariable)
409-
for m in domain.metas) else float
410-
m = np.array(m, dtype=dtpe).T if len(m) else None
411406
table = Table.from_numpy(domain, X, y, m, self.data.W)
412407
table.name = self.data.name
413408
table.ids = np.array(self.data.ids)

Orange/widgets/utils/domaineditor.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from itertools import chain
22

33
import numpy as np
4+
import scipy.sparse as sp
45

56
from AnyQt.QtCore import Qt, QAbstractTableModel
67
from AnyQt.QtGui import QColor
78
from AnyQt.QtWidgets import QComboBox, QTableView, QSizePolicy
89

910
from Orange.data import DiscreteVariable, ContinuousVariable, StringVariable, \
1011
TimeVariable, Domain
12+
from Orange.statistics.util import unique
1113
from Orange.widgets import gui
1214
from Orange.widgets.gui import HorizontalGridDelegate
1315
from Orange.widgets.settings import ContextSetting
@@ -215,6 +217,20 @@ def get_domain(self, domain, data):
215217
def is_missing(x):
216218
return str(x) in ("nan", "")
217219

220+
def iter_vals(x):
221+
"""Iterate over values of sparse or dense arrays."""
222+
for i in range(x.shape[0]):
223+
yield x[i, 0]
224+
225+
def to_column(x, to_sparse, dtype=None):
226+
"""Transform list of values to sparse/dense column array."""
227+
x = np.array(x, dtype=dtype).reshape(-1, 1)
228+
if to_sparse:
229+
if dtype is not None:
230+
raise ValueError('Cannot set dtype on sparse matrix.')
231+
x = sp.csc_matrix(x)
232+
return x
233+
218234
for (name, tpe, place, _, _), (orig_var, orig_plc) in \
219235
zip(variables,
220236
chain([(at, Place.feature) for at in domain.attributes],
@@ -225,31 +241,51 @@ def is_missing(x):
225241
if orig_plc == Place.meta:
226242
col_data = data[:, orig_var].metas
227243
elif orig_plc == Place.class_var:
228-
col_data = data[:, orig_var].Y
244+
col_data = data[:, orig_var].Y.reshape(-1, 1)
229245
else:
230246
col_data = data[:, orig_var].X
231-
col_data = col_data.ravel()
247+
is_sparse = sp.issparse(col_data)
232248
if name == orig_var.name and tpe == type(orig_var):
233249
var = orig_var
234250
elif tpe == type(orig_var):
235251
# change the name so that all_vars will get the correct name
236252
orig_var.name = name
237253
var = orig_var
238254
elif tpe == DiscreteVariable:
239-
values = list(str(i) for i in np.unique(col_data) if not is_missing(i))
255+
values = list(str(i) for i in unique(col_data) if not is_missing(i))
240256
var = tpe(name, values)
241257
col_data = [np.nan if is_missing(x) else values.index(str(x))
242-
for x in col_data]
258+
for x in iter_vals(col_data)]
259+
col_data = to_column(col_data, is_sparse)
243260
elif tpe == StringVariable and type(orig_var) == DiscreteVariable:
244261
var = tpe(name)
245262
col_data = [orig_var.repr_val(x) if not np.isnan(x) else ""
246-
for x in col_data]
263+
for x in iter_vals(col_data)]
264+
col_data = to_column(col_data, is_sparse, dtype=object)
247265
else:
248266
var = tpe(name)
249267
places[place].append(var)
250268
cols[place].append(col_data)
269+
270+
# merge columns for X, Y and metas
271+
def merge(cols, assure_dense=False):
272+
if len(cols) == 0:
273+
return None
274+
if assure_dense and any(sp.issparse(c) for c in cols):
275+
cols = [c.toarray() if sp.issparse(c) else c for c in cols]
276+
if not any(sp.issparse(c) for c in cols):
277+
return np.hstack(cols)
278+
if not all(sp.issparse(c) for c in cols):
279+
cols = [c if sp.issparse(c) else sp.csc_matrix(c)
280+
for c in cols]
281+
return sp.hstack(cols).tocsr()
282+
283+
feats = cols[Place.feature]
284+
X = merge(feats) if feats else np.empty((len(data), 0))
285+
Y = merge(cols[Place.class_var], assure_dense=True)
286+
m = merge(cols[Place.meta], assure_dense=True)
251287
domain = Domain(*places)
252-
return domain, cols
288+
return domain, [X, Y, m]
253289

254290
def set_domain(self, domain):
255291
self.variables = self.parse_domain(domain)

0 commit comments

Comments
 (0)