2525from __future__ import annotations
2626
2727import dataclasses
28+ import enum
2829
2930import numpy as np
3031import pandas as pd
@@ -37,11 +38,22 @@ class FlattenResult:
3738 """The result of flattening a DataFrame.
3839
3940 Attributes:
40- dataframe: The flattened DataFrame.
41+ dataframe: The flattened DataFrame. If the original DataFrame had an index
42+ (including MultiIndex), it is preserved in this flattened DataFrame,
43+ duplicated across exploded rows as needed.
4144 row_labels: A list of original row labels for each row in the flattened DataFrame.
42- continuation_rows: A set of row indices that are continuation rows.
43- cleared_on_continuation: A list of column names that should be cleared on continuation rows.
44- nested_columns: A set of column names that were created from nested data.
45+ This corresponds to the original index values (stringified) and serves to
46+ visually group the exploded rows that belong to the same original row.
47+ continuation_rows: A set of row indices in the flattened DataFrame that are
48+ "continuation rows". These are additional rows created to display the
49+ 2nd to Nth elements of an array. The first row (index i-1) contains
50+ the 1st element, while these rows contain subsequent elements.
51+ cleared_on_continuation: A list of column names that should be "cleared"
52+ (displayed as empty) on continuation rows. Typically, these are
53+ scalar columns (non-array) that were replicated during the explosion
54+ process but should only be visually displayed once per original row group.
55+ nested_columns: A set of column names that were created from nested data
56+ (flattened structs or arrays).
4557 """
4658
4759 dataframe : pd .DataFrame
@@ -51,8 +63,15 @@ class FlattenResult:
5163 nested_columns : set [str ]
5264
5365
66+ class _ColumnCategory (enum .Enum ):
67+ STRUCT = "struct"
68+ ARRAY = "array"
69+ ARRAY_OF_STRUCT = "array_of_struct"
70+ CLEAR = "clear"
71+
72+
5473@dataclasses .dataclass (frozen = True )
55- class ColumnClassification :
74+ class _ColumnClassification :
5675 """The result of classifying columns.
5776
5877 Attributes:
@@ -176,47 +195,51 @@ def flatten_nested_data(
176195
def _classify_columns(
    dataframe: pd.DataFrame,
) -> _ColumnClassification:
    """Identify all STRUCT and ARRAY columns in the DataFrame.

    Args:
        dataframe: The DataFrame to inspect.

    Returns:
        A _ColumnClassification object containing lists of column names for each category.
    """

    def categorize(dtype: pd.api.extensions.ExtensionDtype) -> _ColumnCategory:
        # Only Arrow-backed dtypes expose `pyarrow_dtype`; anything else is a
        # plain column that should be cleared on continuation rows.
        pa_type = getattr(dtype, "pyarrow_dtype", None)
        if not pa_type:
            return _ColumnCategory.CLEAR
        if pa.types.is_struct(pa_type):
            return _ColumnCategory.STRUCT
        if pa.types.is_list(pa_type):
            if pa.types.is_struct(pa_type.value_type):
                return _ColumnCategory.ARRAY_OF_STRUCT
            return _ColumnCategory.ARRAY
        return _ColumnCategory.CLEAR

    # Maps (stringified) column names to their structural category. Using a
    # dict mirrors DataFrame column order and collapses duplicate names.
    categories = {
        str(col): categorize(dtype) for col, dtype in dataframe.dtypes.items()
    }

    # Bucket the columns in one pass; each bucket keeps original column order.
    # Array-of-struct columns belong to both array buckets.
    struct_cols: list[str] = []
    array_cols: list[str] = []
    array_of_struct_cols: list[str] = []
    clear_cols: list[str] = []
    for name, category in categories.items():
        if category is _ColumnCategory.STRUCT:
            struct_cols.append(name)
        elif category is _ColumnCategory.ARRAY:
            array_cols.append(name)
        elif category is _ColumnCategory.ARRAY_OF_STRUCT:
            array_cols.append(name)
            array_of_struct_cols.append(name)
        else:
            clear_cols.append(name)

    return _ColumnClassification(
        struct_columns=tuple(struct_cols),
        array_columns=tuple(array_cols),
        array_of_struct_columns=tuple(array_of_struct_cols),
        clear_on_continuation_cols=tuple(clear_cols),
        nested_originated_columns=frozenset(
            name
            for name, category in categories.items()
            if category is not _ColumnCategory.CLEAR
        ),
    )
222245
@@ -357,7 +380,7 @@ def _explode_array_columns(
357380 if not array_columns :
358381 return ExplodeResult (dataframe , [], set ())
359382
360- work_df , non_array_columns , original_index_name = _prepare_explosion_dataframe (
383+ work_df , non_array_columns , index_names = _prepare_explosion_dataframe (
361384 dataframe , array_columns
362385 )
363386
@@ -386,8 +409,8 @@ def _explode_array_columns(
386409 total_rows = target_offsets [- 1 ].as_py ()
387410 if total_rows == 0 :
388411 empty_df = pd .DataFrame (columns = dataframe .columns )
389- if original_index_name :
390- empty_df . index . name = original_index_name
412+ if index_names :
413+ empty_df = empty_df . set_index ( index_names )
391414 return ExplodeResult (empty_df , [], set ())
392415
393416 # parent_indices maps each result row to its original row index.
@@ -425,45 +448,69 @@ def _explode_array_columns(
425448 result_table = pa .Table .from_pydict (new_columns )
426449 result_df = result_table .to_pandas (types_mapper = pd .ArrowDtype )
427450
428- grouping_col_name = (
429- "_original_index" if original_index_name is None else original_index_name
430- )
431- row_labels = result_df [grouping_col_name ].astype (str ).tolist ()
451+ if index_names :
452+ if len (index_names ) == 1 :
453+ row_labels = result_df [index_names [0 ]].astype (str ).tolist ()
454+ else :
455+ # For MultiIndex, create a tuple string representation
456+ row_labels = (
457+ result_df [index_names ].apply (tuple , axis = 1 ).astype (str ).tolist ()
458+ )
459+ else :
460+ row_labels = result_df ["_original_index" ].astype (str ).tolist ()
432461
433462 continuation_mask = pc .greater (row_nums , 0 ).to_numpy (zero_copy_only = False )
434463 continuation_rows = set (np .flatnonzero (continuation_mask ).tolist ())
435464
436- result_df = result_df [dataframe .columns .tolist ()]
465+ # Select columns: original columns + restored index columns (temporarily)
466+ cols_to_keep = dataframe .columns .tolist ()
467+ if index_names :
468+ cols_to_keep .extend (index_names )
437469
438- if original_index_name :
439- result_df = result_df .set_index (original_index_name )
470+ # Filter columns, but allow index columns to pass through if they are not in original columns
471+ # (which they won't be if they were indices)
472+ result_df = result_df [cols_to_keep ]
473+
474+ if index_names :
475+ result_df = result_df .set_index (index_names )
440476
441477 return ExplodeResult (result_df , row_labels , continuation_rows )
442478
443479
444480def _prepare_explosion_dataframe (
445481 dataframe : pd .DataFrame , array_columns : list [str ]
446- ) -> tuple [pd .DataFrame , list [str ], str | None ]:
482+ ) -> tuple [pd .DataFrame , list [str ], list [ str ] | None ]:
447483 """Prepares the DataFrame for explosion by ensuring grouping columns exist."""
448- work_df = dataframe
484+ work_df = dataframe . copy ()
449485 non_array_columns = work_df .columns .drop (array_columns ).tolist ()
450486
451487 if not non_array_columns :
452- work_df = work_df .copy () # Avoid modifying input
453488 # Add a temporary column to allow grouping if all columns are arrays.
454489 non_array_columns = ["_temp_grouping_col" ]
455490 work_df ["_temp_grouping_col" ] = range (len (work_df ))
456491
457- original_index_name = None
458- if work_df .index .name :
459- original_index_name = work_df .index .name
492+ index_names = None
493+ if work_df .index .nlevels > 1 :
494+ # Handle MultiIndex
495+ names = list (work_df .index .names )
496+ # Assign default names if None to ensure reset_index works and we can track them
497+ names = [n if n is not None else f"level_{ i } " for i , n in enumerate (names )]
498+ work_df .index .names = names
499+ index_names = names
500+ work_df = work_df .reset_index ()
501+ non_array_columns .extend (index_names )
502+ elif work_df .index .name is not None :
503+ # Handle named Index
504+ index_names = [work_df .index .name ]
460505 work_df = work_df .reset_index ()
461- non_array_columns .append ( original_index_name )
506+ non_array_columns .extend ( index_names )
462507 else :
508+ # Handle default/unnamed Index
509+ # We use _original_index for tracking but don't return it as an index to restore
463510 work_df = work_df .reset_index (names = ["_original_index" ])
464511 non_array_columns .append ("_original_index" )
465512
466- return work_df , non_array_columns , original_index_name
513+ return work_df , non_array_columns , index_names
467514
468515
469516def _flatten_struct_columns (
0 commit comments