11from itertools import chain
22
33import numpy as np
4+ import scipy .sparse as sp
45
56from AnyQt .QtCore import Qt , QAbstractTableModel
67from AnyQt .QtGui import QColor
78from AnyQt .QtWidgets import QComboBox , QTableView , QSizePolicy
89
910from Orange .data import DiscreteVariable , ContinuousVariable , StringVariable , \
1011 TimeVariable , Domain
12+ from Orange .statistics .util import unique
1113from Orange .widgets import gui
1214from Orange .widgets .gui import HorizontalGridDelegate
1315from Orange .widgets .settings import ContextSetting
@@ -215,6 +217,20 @@ def get_domain(self, domain, data):
def is_missing(x):
    """Return True when the string form of ``x`` is ``"nan"`` or empty.

    The check is done on ``str(x)``, so it covers both a float NaN
    (whose string form is ``"nan"``) and the literal strings ``"nan"``
    and ``""``.
    """
    return str(x) in ("nan", "")
217219
def iter_vals(x):
    """Iterate over values of sparse or dense arrays.

    Yields ``x[i, 0]`` for every row index ``i`` in ``range(x.shape[0])``,
    i.e. the entries of the first column. ``x`` must therefore be
    two-dimensional; explicit ``[i, 0]`` indexing (rather than plain
    iteration) is what lets the same code serve both a dense
    ``numpy`` array and a ``scipy.sparse`` matrix.
    """
    for i in range(x.shape[0]):
        yield x[i, 0]
224+
def to_column(x, to_sparse, dtype=None):
    """Transform list of values to sparse/dense column array.

    Parameters
    ----------
    x : sequence
        Values to place in the column.
    to_sparse : bool
        When true, the column is returned as ``scipy.sparse.csc_matrix``;
        otherwise as a dense ``numpy`` array.
    dtype : optional
        dtype passed to ``np.array`` for the dense result. Must be left
        as ``None`` when ``to_sparse`` is true.

    Returns
    -------
    A column of shape ``(len(x), 1)``.

    Raises
    ------
    ValueError
        If ``dtype`` is given together with ``to_sparse``.
    """
    # Reject the unsupported combination up front, before spending time
    # on building a dense array that would be thrown away.
    if to_sparse and dtype is not None:
        raise ValueError('Cannot set dtype on sparse matrix.')
    column = np.array(x, dtype=dtype).reshape(-1, 1)
    if to_sparse:
        column = sp.csc_matrix(column)
    return column
233+
218234 for (name , tpe , place , _ , _ ), (orig_var , orig_plc ) in \
219235 zip (variables ,
220236 chain ([(at , Place .feature ) for at in domain .attributes ],
@@ -225,31 +241,51 @@ def is_missing(x):
225241 if orig_plc == Place .meta :
226242 col_data = data [:, orig_var ].metas
227243 elif orig_plc == Place .class_var :
228- col_data = data [:, orig_var ].Y
244+ col_data = data [:, orig_var ].Y . reshape ( - 1 , 1 )
229245 else :
230246 col_data = data [:, orig_var ].X
231- col_data = col_data . ravel ( )
247+ is_sparse = sp . issparse ( col_data )
232248 if name == orig_var .name and tpe == type (orig_var ):
233249 var = orig_var
234250 elif tpe == type (orig_var ):
235251 # change the name so that all_vars will get the correct name
236252 orig_var .name = name
237253 var = orig_var
238254 elif tpe == DiscreteVariable :
239- values = list (str (i ) for i in np . unique (col_data ) if not is_missing (i ))
255+ values = list (str (i ) for i in unique (col_data ) if not is_missing (i ))
240256 var = tpe (name , values )
241257 col_data = [np .nan if is_missing (x ) else values .index (str (x ))
242- for x in col_data ]
258+ for x in iter_vals (col_data )]
259+ col_data = to_column (col_data , is_sparse )
243260 elif tpe == StringVariable and type (orig_var ) == DiscreteVariable :
244261 var = tpe (name )
245262 col_data = [orig_var .repr_val (x ) if not np .isnan (x ) else ""
246- for x in col_data ]
263+ for x in iter_vals (col_data )]
264+ col_data = to_column (col_data , is_sparse , dtype = object )
247265 else :
248266 var = tpe (name )
249267 places [place ].append (var )
250268 cols [place ].append (col_data )
269+
270+ # merge columns for X, Y and metas
def merge(cols, assure_dense=False):
    """Horizontally stack a list of column arrays into one matrix.

    Parameters
    ----------
    cols : list
        Dense ``numpy`` arrays and/or ``scipy.sparse`` matrices to be
        stacked side by side.
    assure_dense : bool
        When true, every sparse column is converted with ``toarray()``
        first, so the result is always a dense array.

    Returns
    -------
    ``None`` for an empty list; a dense array (``np.hstack``) when no
    column is sparse; otherwise a CSR sparse matrix.
    """
    if len(cols) == 0:
        return None
    if assure_dense:
        # A no-op for columns that are already dense.
        cols = [c.toarray() if sp.issparse(c) else c for c in cols]
    if not any(sp.issparse(c) for c in cols):
        return np.hstack(cols)
    # At least one sparse column remains: promote the dense ones so the
    # whole stack is sparse, then return it in CSR form.
    cols = [c if sp.issparse(c) else sp.csc_matrix(c) for c in cols]
    return sp.hstack(cols).tocsr()
282+
283+ feats = cols [Place .feature ]
284+ X = merge (feats ) if feats else np .empty ((len (data ), 0 ))
285+ Y = merge (cols [Place .class_var ], assure_dense = True )
286+ m = merge (cols [Place .meta ], assure_dense = True )
251287 domain = Domain (* places )
252- return domain , cols
288+ return domain , [ X , Y , m ]
253289
254290 def set_domain (self , domain ):
255291 self .variables = self .parse_domain (domain )
0 commit comments