@@ -287,18 +287,17 @@ def _set_inputs(self, inputs):
287287 def __call__ (self , df ):
288288 if self .col_names is not None :
289289 # if col_names is a list, return a DataFrame, else return a Series
290- if isinstance ( self ._col_names , list ):
291- dtypes = df . dtypes [ self . _col_names ]
292- columns = parse_index (pd . Index ( self . _col_names ) , store_data = True )
290+ dtype = df . dtypes [ self ._col_names ]
291+ if isinstance ( dtype , pd . Series ):
292+ columns = parse_index (dtype . index , store_data = True )
293293 return self .new_dataframe (
294294 [df ],
295- shape = (df .shape [0 ], len (self . _col_names )),
296- dtypes = dtypes ,
295+ shape = (df .shape [0 ], len (dtype )),
296+ dtypes = dtype ,
297297 index_value = df .index_value ,
298298 columns_value = columns ,
299299 )
300300 else :
301- dtype = df .dtypes [self ._col_names ]
302301 return self .new_series (
303302 [df ],
304303 shape = (df .shape [0 ],),
@@ -439,8 +438,8 @@ def tile_with_columns(cls, op):
439438 in_df = op .inputs [0 ]
440439 out_df = op .outputs [0 ]
441440 col_names = op .col_names
442- if not isinstance (col_names , list ):
443- column_index = calc_columns_index (col_names , in_df )
441+ if not isinstance (out_df , DATAFRAME_TYPE ):
442+ column_index = calc_columns_index (col_names , in_df )[ 0 ]
444443 out_chunks = []
445444 dtype = in_df .dtypes [col_names ]
446445 for i in range (in_df .chunk_shape [0 ]):
@@ -471,7 +470,10 @@ def tile_with_columns(cls, op):
471470 # When chunk columns are ['c1', 'c2', 'c3'], ['c4', 'c5'],
472471 # selected columns are ['c2', 'c3', 'c4', 'c2'], `column_splits` will be
473472 # [(['c2', 'c3'], 0), ('c4', 1), ('c2', 0)].
473+ if not isinstance (col_names , _list_like_types ):
474+ col_names = [col_names ]
474475 selected_index = [calc_columns_index (col , in_df ) for col in col_names ]
476+ selected_index = list (itertools .chain .from_iterable (selected_index ))
475477 condition = np .where (np .diff (selected_index ))[0 ] + 1
476478 column_splits = np .split (col_names , condition )
477479 column_indexes = np .split (selected_index , condition )
@@ -482,19 +484,21 @@ def tile_with_columns(cls, op):
482484 zip (column_splits , column_indexes )
483485 ):
484486 dtypes = in_df .dtypes [columns ]
485- column_nsplits .append (len (columns ))
487+ column_nsplits .append (len (dtypes ))
486488 for j in range (in_df .chunk_shape [0 ]):
487489 c = in_df .cix [(j , column_idx [0 ])]
488490 index_op = DataFrameIndex (
489491 col_names = list (columns ), output_types = [OutputType .dataframe ]
490492 )
491493 out_chunk = index_op .new_chunk (
492494 [c ],
493- shape = (c .shape [0 ], len (columns )),
495+ shape = (c .shape [0 ], len (dtypes )),
494496 index = (j , i ),
495497 dtypes = dtypes ,
496498 index_value = c .index_value ,
497- columns_value = parse_index (pd .Index (columns ), store_data = True ),
499+ columns_value = parse_index (
500+ pd .Index (dtypes .index ), store_data = True
501+ ),
498502 )
499503 out_chunks [j ].append (out_chunk )
500504 out_chunks = [item for cl in out_chunks for item in cl ]
@@ -525,8 +529,6 @@ def execute(cls, ctx, op):
525529 mask = op .mask
526530 if hasattr (mask , "reindex_like" ):
527531 mask = mask .reindex_like (df ).fillna (False )
528- if mask .ndim == 2 :
529- mask = mask [df .columns .tolist ()]
530532 ctx [op .outputs [0 ].key ] = df [mask ]
531533
532534 @classmethod
0 commit comments