@@ -268,6 +268,27 @@ def from_table(cls, domain, source, row_indices=...):
268268
269269 def get_columns (row_indices , src_cols , n_rows , dtype = np .float64 ,
270270 is_sparse = False ):
271+ def match_type (x , force_1d = False ):
272+ """ Assure that matrix and column are both dense or sparse.
273+
274+ Args:
275+ x (np.ndarray, scipy.sparse): data
276+ force_1d (bool): If set, flatten resulting array to 1d.
277+
278+ Returns:
279+ array of correct density.
280+ """
281+ if is_sparse == sp .issparse (x ):
282+ return x
283+ if is_sparse :
284+ x = np .asarray (x )
285+ return sp .csc_matrix (x .reshape (- 1 , 1 ).astype (np .float ))
286+ x = x .toarray ()
287+ if force_1d :
288+ x = np .ravel (x )
289+ return x
290+
291+ match_type_1d = lambda x : match_type (x , force_1d = True )
271292
272293 if not len (src_cols ):
273294 if is_sparse :
@@ -278,33 +299,23 @@ def get_columns(row_indices, src_cols, n_rows, dtype=np.float64,
278299 n_src_attrs = len (source .domain .attributes )
279300 if all (isinstance (x , Integral ) and 0 <= x < n_src_attrs
280301 for x in src_cols ):
281- return _subarray (source .X , row_indices , src_cols )
302+ return match_type ( _subarray (source .X , row_indices , src_cols ) )
282303 if all (isinstance (x , Integral ) and x < 0 for x in src_cols ):
283- arr = _subarray (source .metas , row_indices ,
284- [- 1 - x for x in src_cols ])
304+ arr = match_type ( _subarray (source .metas , row_indices ,
305+ [- 1 - x for x in src_cols ]))
285306 if arr .dtype != dtype :
286307 return arr .astype (dtype )
287308 return arr
288309 if all (isinstance (x , Integral ) and x >= n_src_attrs
289310 for x in src_cols ):
290- return _subarray (source ._Y , row_indices ,
291- [x - n_src_attrs for x in src_cols ])
311+ return match_type ( _subarray (source ._Y , row_indices ,
312+ [x - n_src_attrs for x in src_cols ]))
292313
293314 if is_sparse :
294315 a = sp .dok_matrix ((n_rows , len (src_cols )), dtype = dtype )
295316 else :
296317 a = np .empty ((n_rows , len (src_cols )), dtype = dtype )
297318
298- def match_type (x ):
299- """ Assure that matrix and column are both dense or sparse. """
300- if is_sparse == sp .issparse (x ):
301- return x
302- elif is_sparse :
303- x = np .asarray (x )
304- return sp .csc_matrix (x .reshape (- 1 , 1 ).astype (np .float ))
305- else :
306- return np .ravel (x .toarray ())
307-
308319 shared_cache = _conversion_cache
309320 for i , col in enumerate (src_cols ):
310321 if col is None :
@@ -316,22 +327,22 @@ def match_type(x):
316327 col .compute_shared (source )
317328 shared = shared_cache [id (col .compute_shared ), id (source )]
318329 if row_indices is not ...:
319- a [:, i ] = match_type (
330+ a [:, i ] = match_type_1d (
320331 col (source , shared_data = shared )[row_indices ])
321332 else :
322- a [:, i ] = match_type (
333+ a [:, i ] = match_type_1d (
323334 col (source , shared_data = shared ))
324335 else :
325336 if row_indices is not ...:
326- a [:, i ] = match_type (col (source )[row_indices ])
337+ a [:, i ] = match_type_1d (col (source )[row_indices ])
327338 else :
328- a [:, i ] = match_type (col (source ))
339+ a [:, i ] = match_type_1d (col (source ))
329340 elif col < 0 :
330- a [:, i ] = match_type (source .metas [row_indices , - 1 - col ])
341+ a [:, i ] = match_type_1d (source .metas [row_indices , - 1 - col ])
331342 elif col < n_src_attrs :
332- a [:, i ] = match_type (source .X [row_indices , col ])
343+ a [:, i ] = match_type_1d (source .X [row_indices , col ])
333344 else :
334- a [:, i ] = match_type (
345+ a [:, i ] = match_type_1d (
335346 source ._Y [row_indices , col - n_src_attrs ])
336347
337348 if is_sparse :
0 commit comments