@@ -268,6 +268,15 @@ def from_table(cls, domain, source, row_indices=...):
268268
269269 def get_columns (row_indices , src_cols , n_rows , dtype = np .float64 ,
270270 is_sparse = False ):
271+ def match_type (x ):
272+ """ Assure that matrix and column are both dense or sparse. """
273+ if is_sparse == sp .issparse (x ):
274+ return x
275+ elif is_sparse :
276+ x = np .asarray (x )
277+ return sp .csc_matrix (x .reshape (- 1 , 1 ).astype (np .float ))
278+ else :
279+ return np .ravel (x .toarray ())
271280
272281 if not len (src_cols ):
273282 if is_sparse :
@@ -278,33 +287,23 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
278287 n_src_attrs = len (source .domain .attributes )
279288 if all (isinstance (x , Integral ) and 0 <= x < n_src_attrs
280289 for x in src_cols ):
281- return _subarray (source .X , row_indices , src_cols )
290+ return match_type ( _subarray (source .X , row_indices , src_cols ) )
282291 if all (isinstance (x , Integral ) and x < 0 for x in src_cols ):
283- arr = _subarray (source .metas , row_indices ,
284- [- 1 - x for x in src_cols ])
292+ arr = match_type ( _subarray (source .metas , row_indices ,
293+ [- 1 - x for x in src_cols ]))
285294 if arr .dtype != dtype :
286295 return arr .astype (dtype )
287296 return arr
288297 if all (isinstance (x , Integral ) and x >= n_src_attrs
289298 for x in src_cols ):
290- return _subarray (source ._Y , row_indices ,
291- [x - n_src_attrs for x in src_cols ])
299+ return match_type ( _subarray (source ._Y , row_indices ,
300+ [x - n_src_attrs for x in src_cols ]))
292301
293302 if is_sparse :
294303 a = sp .dok_matrix ((n_rows , len (src_cols )), dtype = dtype )
295304 else :
296305 a = np .empty ((n_rows , len (src_cols )), dtype = dtype )
297306
298- def match_type (x ):
299- """ Assure that matrix and column are both dense or sparse. """
300- if is_sparse == sp .issparse (x ):
301- return x
302- elif is_sparse :
303- x = np .asarray (x )
304- return sp .csc_matrix (x .reshape (- 1 , 1 ).astype (np .float ))
305- else :
306- return np .ravel (x .toarray ())
307-
308307 shared_cache = _conversion_cache
309308 for i , col in enumerate (src_cols ):
310309 if col is None :
0 commit comments