Skip to content

Commit c553c5f

Browse files
committed
DomainEditor: Support sparse data
1 parent 480f116 commit c553c5f

File tree

2 files changed

+51
-15
lines changed

2 files changed

+51
-15
lines changed

Orange/widgets/data/owfile.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,6 @@ def apply_domain_edit(self):
403403
if self.data is not None:
404404
domain, cols = self.domain_editor.get_domain(self.data.domain, self.data)
405405
X, y, m = cols
406-
X = np.array(X).T if len(X) else np.empty((len(self.data), 0))
407-
y = np.array(y).T if len(y) else None
408-
dtpe = object if any(isinstance(m, StringVariable)
409-
for m in domain.metas) else float
410-
m = np.array(m, dtype=dtpe).T if len(m) else None
411406
table = Table.from_numpy(domain, X, y, m, self.data.W)
412407
table.name = self.data.name
413408
table.ids = np.array(self.data.ids)

Orange/widgets/utils/domaineditor.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from itertools import chain
22

33
import numpy as np
4+
import scipy.sparse as sp
45

56
from AnyQt.QtCore import Qt, QAbstractTableModel
67
from AnyQt.QtGui import QColor
78
from AnyQt.QtWidgets import QComboBox, QTableView, QSizePolicy
89

910
from Orange.data import DiscreteVariable, ContinuousVariable, StringVariable, \
1011
TimeVariable, Domain
12+
from Orange.statistics.util import unique
1113
from Orange.widgets import gui
1214
from Orange.widgets.gui import HorizontalGridDelegate
1315
from Orange.widgets.settings import ContextSetting
@@ -196,6 +198,39 @@ def __init__(self, widget):
196198
self.place_delegate = PlaceDelegate(self, VarTableModel.places)
197199
self.setItemDelegateForColumn(Column.place, self.place_delegate)
198200

201+
@staticmethod
202+
def _is_missing(x):
203+
return str(x) in ("nan", "")
204+
205+
@staticmethod
206+
def _iter_vals(x):
207+
"""Iterate over values of sparse or dense arrays."""
208+
for i in range(x.shape[0]):
209+
yield x[i, 0]
210+
211+
@staticmethod
212+
def _to_column(x, to_sparse, dtype=None):
213+
"""Transform list of values to sparse/dense column array."""
214+
x = np.array(x, dtype=dtype).reshape(-1, 1)
215+
if to_sparse:
216+
if dtype is not None:
217+
raise ValueError('Cannot set dtype on sparse matrix.')
218+
x = sp.csc_matrix(x)
219+
return x
220+
221+
@staticmethod
222+
def _merge(cols, force_dense=False):
223+
if len(cols) == 0:
224+
return None
225+
226+
all_dense = not any(sp.issparse(c) for c in cols)
227+
if all_dense:
228+
return np.hstack(cols)
229+
if force_dense:
230+
return np.hstack([c.toarray() if sp.issparse(c) else c for c in cols])
231+
sparse_cols = [c if sp.issparse(c) else sp.csc_matrix(c) for c in cols]
232+
return sp.hstack(sparse_cols).tocsr()
233+
199234
def get_domain(self, domain, data):
200235
"""Create domain (and dataset) from changes made in the widget.
201236
@@ -212,44 +247,50 @@ def get_domain(self, domain, data):
212247
places = [[], [], []] # attributes, class_vars, metas
213248
cols = [[], [], []] # Xcols, Ycols, Mcols
214249

215-
def is_missing(x):
216-
return str(x) in ("nan", "")
217-
218250
for (name, tpe, place, _, _), (orig_var, orig_plc) in \
219251
zip(variables,
220252
chain([(at, Place.feature) for at in domain.attributes],
221253
[(cl, Place.class_var) for cl in domain.class_vars],
222254
[(mt, Place.meta) for mt in domain.metas])):
223255
if place == Place.skip:
224256
continue
257+
225258
if orig_plc == Place.meta:
226259
col_data = data[:, orig_var].metas
227260
elif orig_plc == Place.class_var:
228-
col_data = data[:, orig_var].Y
261+
col_data = data[:, orig_var].Y.reshape(-1, 1)
229262
else:
230263
col_data = data[:, orig_var].X
231-
col_data = col_data.ravel()
264+
is_sparse = sp.issparse(col_data)
232265
if name == orig_var.name and tpe == type(orig_var):
233266
var = orig_var
234267
elif tpe == type(orig_var):
235268
# change the name so that all_vars will get the correct name
236269
orig_var.name = name
237270
var = orig_var
238271
elif tpe == DiscreteVariable:
239-
values = list(str(i) for i in np.unique(col_data) if not is_missing(i))
272+
values = list(str(i) for i in unique(col_data) if not self._is_missing(i))
240273
var = tpe(name, values)
241-
col_data = [np.nan if is_missing(x) else values.index(str(x))
242-
for x in col_data]
274+
col_data = [np.nan if self._is_missing(x) else values.index(str(x))
275+
for x in self._iter_vals(col_data)]
276+
col_data = self._to_column(col_data, is_sparse)
243277
elif tpe == StringVariable and type(orig_var) == DiscreteVariable:
244278
var = tpe(name)
245279
col_data = [orig_var.repr_val(x) if not np.isnan(x) else ""
246-
for x in col_data]
280+
for x in self._iter_vals(col_data)]
281+
col_data = self._to_column(col_data, is_sparse, dtype=object)
247282
else:
248283
var = tpe(name)
249284
places[place].append(var)
250285
cols[place].append(col_data)
286+
287+
# merge columns for X, Y and metas
288+
feats = cols[Place.feature]
289+
X = self._merge(feats) or np.empty((len(data), 0))
290+
Y = self._merge(cols[Place.class_var], force_dense=True)
291+
m = self._merge(cols[Place.meta], force_dense=True)
251292
domain = Domain(*places)
252-
return domain, cols
293+
return domain, [X, Y, m]
253294

254295
def set_domain(self, domain):
255296
self.variables = self.parse_domain(domain)

0 commit comments

Comments
 (0)