@@ -758,7 +758,8 @@ def _top_level_groupby(df, options: FitOptions):
758758 df_gb = type (df )()
759759 ignore_index = True
760760 df_gb [cat_col_selector_str ] = _concat (
761- [df [col ] for col in cat_col_selector .names ], ignore_index
761+ [_maybe_flatten_list_column (col , df )[col ] for col in cat_col_selector .names ],
762+ ignore_index ,
762763 )
763764 cat_col_selector = ColumnSelector ([cat_col_selector_str ])
764765 else :
@@ -795,9 +796,7 @@ def _top_level_groupby(df, options: FitOptions):
795796
796797 # Perform groupby and flatten column index
797798 # (flattening provides better cudf/pd support)
798- if is_list_col (cat_col_selector , df_gb ):
799- # handle list columns by encoding the list values
800- df_gb = dispatch .flatten_list_column (df_gb [cat_col_selector .names [0 ]])
799+ df_gb = _maybe_flatten_list_column (cat_col_selector .names [0 ], df_gb )
801800 # NOTE: groupby(..., dropna=False) requires pandas>=1.1.0
802801 gb = df_gb .groupby (cat_col_selector .names , dropna = False ).agg (agg_dict )
803802 gb .columns = [
@@ -1414,6 +1413,15 @@ def is_list_col(col_selector, df):
14141413 return has_lists
14151414
14161415
def _maybe_flatten_list_column(col: str, df):
    """Flatten column ``col`` of ``df`` when it holds list values.

    Parameters
    ----------
    col : str
        Name of the column to inspect.
    df
        Dataframe-like object containing ``col``.

    Returns
    -------
    ``df`` itself, untouched, when ``is_list_col`` reports that ``col`` is
    not a list column. Otherwise the result of
    ``dispatch.flatten_list_column`` applied to that single column -- note
    that only the selected column is handed to the flattening routine, not
    the whole of ``df``.
    """
    # ``is_list_col`` expects a selector rather than a bare column name.
    col_selector = ColumnSelector([col])

    # Nothing to flatten: pass the input back "as is".
    if not is_list_col(col_selector, df):
        return df

    list_column = df[col_selector.names[0]]
    return dispatch.flatten_list_column(list_column)
1423+
1424+
14171425def _hash_bucket (df , num_buckets , col , encode_type = "joint" ):
14181426 if encode_type == "joint" :
14191427 nb = num_buckets [col [0 ]]
0 commit comments