@@ -50,7 +50,7 @@ class FlattenResult:
5050 nested_columns : set [str ]
5151
5252
53- @dataclasses .dataclass
53+ @dataclasses .dataclass ( frozen = True )
5454class ColumnClassification :
5555 """The result of classifying columns.
5656
@@ -62,11 +62,11 @@ class ColumnClassification:
6262 nested_originated_columns: Columns that were created from nested data.
6363 """
6464
65- struct_columns : list [str ]
66- array_columns : list [str ]
67- array_of_struct_columns : list [str ]
68- clear_on_continuation_cols : list [str ]
69- nested_originated_columns : set [str ]
65+ struct_columns : tuple [str , ... ]
66+ array_columns : tuple [str , ... ]
67+ array_of_struct_columns : tuple [str , ... ]
68+ clear_on_continuation_cols : tuple [str , ... ]
69+ nested_originated_columns : frozenset [str ]
7070
7171
7272@dataclasses .dataclass (frozen = True )
@@ -107,41 +107,50 @@ def flatten_nested_data(
107107 result_df = dataframe .copy ()
108108
109109 classification = _classify_columns (result_df )
110- # Create a mutable structure to track column changes during flattening.
111- # _flatten_array_of_struct_columns modifies the array_columns list.
112- columns_info = dataclasses .replace (classification )
113110
114- result_df , columns_info .array_columns = _flatten_array_of_struct_columns (
111+ # Process ARRAY-of-STRUCT columns into multiple ARRAY columns (one per struct field).
112+ result_df , array_cols , nested_cols = _flatten_array_of_struct_columns (
115113 result_df ,
116- columns_info .array_of_struct_columns ,
117- columns_info .array_columns ,
118- columns_info .nested_originated_columns ,
114+ classification .array_of_struct_columns ,
115+ classification .array_columns ,
116+ classification .nested_originated_columns ,
117+ )
118+ classification = dataclasses .replace (
119+ classification , array_columns = array_cols , nested_originated_columns = nested_cols
119120 )
120121
121- result_df , columns_info .clear_on_continuation_cols = _flatten_struct_columns (
122+ # Flatten top-level STRUCT columns into separate columns.
123+ result_df , clear_cols , nested_cols = _flatten_struct_columns (
122124 result_df ,
123- columns_info .struct_columns ,
124- columns_info .clear_on_continuation_cols ,
125- columns_info .nested_originated_columns ,
125+ classification .struct_columns ,
126+ classification .clear_on_continuation_cols ,
127+ classification .nested_originated_columns ,
128+ )
129+ classification = dataclasses .replace (
130+ classification ,
131+ clear_on_continuation_cols = clear_cols ,
132+ nested_originated_columns = nested_cols ,
126133 )
127134
128135 # Now handle ARRAY columns (including the newly created ones from ARRAY of STRUCT)
129- if not columns_info .array_columns :
136+ if not classification .array_columns :
130137 return FlattenResult (
131138 dataframe = result_df ,
132139 row_labels = None ,
133140 continuation_rows = None ,
134- cleared_on_continuation = columns_info .clear_on_continuation_cols ,
135- nested_columns = columns_info .nested_originated_columns ,
141+ cleared_on_continuation = list ( classification .clear_on_continuation_cols ) ,
142+ nested_columns = set ( classification .nested_originated_columns ) ,
136143 )
137144
138- explode_result = _explode_array_columns (result_df , columns_info .array_columns )
145+ explode_result = _explode_array_columns (
146+ result_df , list (classification .array_columns )
147+ )
139148 return FlattenResult (
140149 dataframe = explode_result .dataframe ,
141150 row_labels = explode_result .row_labels ,
142151 continuation_rows = explode_result .continuation_rows ,
143- cleared_on_continuation = columns_info .clear_on_continuation_cols ,
144- nested_columns = columns_info .nested_originated_columns ,
152+ cleared_on_continuation = list ( classification .clear_on_continuation_cols ) ,
153+ nested_columns = set ( classification .nested_originated_columns ) ,
145154 )
146155
147156
@@ -174,40 +183,43 @@ def _classify_columns(
174183 categories [col_name ] = "clear"
175184
176185 return ColumnClassification (
177- struct_columns = [ c for c , cat in categories .items () if cat == "struct" ] ,
178- array_columns = [
186+ struct_columns = tuple ( c for c , cat in categories .items () if cat == "struct" ) ,
187+ array_columns = tuple (
179188 c for c , cat in categories .items () if cat in ("array" , "array_of_struct" )
180- ] ,
181- array_of_struct_columns = [
189+ ) ,
190+ array_of_struct_columns = tuple (
182191 c for c , cat in categories .items () if cat == "array_of_struct"
183- ] ,
184- clear_on_continuation_cols = [
192+ ) ,
193+ clear_on_continuation_cols = tuple (
185194 c for c , cat in categories .items () if cat == "clear"
186- ] ,
187- nested_originated_columns = {
195+ ) ,
196+ nested_originated_columns = frozenset (
188197 c for c , cat in categories .items () if cat != "clear"
189- } ,
198+ ) ,
190199 )
191200
192201
193202def _flatten_array_of_struct_columns (
194203 dataframe : pd .DataFrame ,
195- array_of_struct_columns : list [str ],
196- array_columns : list [str ],
197- nested_originated_columns : set [str ],
198- ) -> tuple [pd .DataFrame , list [str ]]:
204+ array_of_struct_columns : tuple [str , ... ],
205+ array_columns : tuple [str , ... ],
206+ nested_originated_columns : frozenset [str ],
207+ ) -> tuple [pd .DataFrame , tuple [ str , ...], frozenset [str ]]:
199208 """Flatten ARRAY of STRUCT columns into separate ARRAY columns for each field.
200209
201210 Args:
202211 dataframe: The DataFrame to process.
203- array_of_struct_columns: List of column names that are ARRAYs of STRUCTs.
204- array_columns: The main list of ARRAY columns to be updated.
205- nested_originated_columns: Set of columns tracked as originating from nested data.
212+ array_of_struct_columns: Column names that are ARRAYs of STRUCTs.
213+ array_columns: The main sequence of ARRAY columns to be updated.
214+ nested_originated_columns: Columns tracked as originating from nested data.
206215
207216 Returns:
208- A tuple containing the modified DataFrame and the updated list of array columns.
217+ A tuple containing the modified DataFrame, updated array columns, and updated nested columns.
209218 """
210219 result_df = dataframe .copy ()
220+ current_array_columns = list (array_columns )
221+ current_nested_columns = set (nested_originated_columns )
222+
211223 for col_name in array_of_struct_columns :
212224 col_data = result_df [col_name ]
213225 # Ensure we have a PyArrow array (pa.array handles pandas Series conversion)
@@ -225,18 +237,13 @@ def _flatten_array_of_struct_columns(
225237 }
226238 )
227239
228- # Track the new columns
229- for new_col in new_cols_df .columns :
230- nested_originated_columns .add (new_col )
231-
232- # Update the DataFrame
240+ current_nested_columns .update (new_cols_df .columns )
233241 result_df = _replace_column_in_df (result_df , col_name , new_cols_df )
234242
235- # Update array_columns list
236- array_columns .remove (col_name )
237- array_columns .extend (new_cols_df .columns .tolist ())
243+ current_array_columns .remove (col_name )
244+ current_array_columns .extend (new_cols_df .columns .tolist ())
238245
239- return result_df , array_columns
246+ return result_df , tuple ( current_array_columns ), frozenset ( current_nested_columns )
240247
241248
242249def _transpose_list_of_structs (arrow_array : pa .ListArray ) -> dict [str , pa .ListArray ]:
@@ -395,27 +402,25 @@ def _explode_array_columns(
395402
396403def _flatten_struct_columns (
397404 dataframe : pd .DataFrame ,
398- struct_columns : list [str ],
399- clear_on_continuation_cols : list [str ],
400- nested_originated_columns : set [str ],
401- ) -> tuple [pd .DataFrame , list [str ]]:
405+ struct_columns : tuple [str , ... ],
406+ clear_on_continuation_cols : tuple [str , ... ],
407+ nested_originated_columns : frozenset [str ],
408+ ) -> tuple [pd .DataFrame , tuple [ str , ...], frozenset [str ]]:
402409 """Flatten regular STRUCT columns into separate columns.
403410
404- A STRUCT column 'user' with fields 'name' and 'age' becomes 'user.name'
405- and 'user.age'.
406-
407411 Args:
408412 dataframe: The DataFrame to process.
409- struct_columns: List of STRUCT columns to flatten.
410- clear_on_continuation_cols: List of columns to clear on continuation,
411- which will be updated with the new flattened columns.
412- nested_originated_columns: Set of columns tracked as originating from nested data.
413+ struct_columns: STRUCT columns to flatten.
414+ clear_on_continuation_cols: Columns to clear on continuation.
415+ nested_originated_columns: Columns tracked as originating from nested data.
413416
414417 Returns:
415- A tuple containing the modified DataFrame and the updated list of
416- columns to clear on continuation.
418+ A tuple containing the modified DataFrame, updated clear columns, and updated nested columns.
417419 """
418420 result_df = dataframe .copy ()
421+ current_clear_cols = list (clear_on_continuation_cols )
422+ current_nested_cols = set (nested_originated_columns )
423+
419424 for col_name in struct_columns :
420425 col_data = result_df [col_name ]
421426 if isinstance (col_data .dtype , pd .ArrowDtype ):
@@ -430,8 +435,8 @@ def _flatten_struct_columns(
430435 for field_idx in range (pa_type .num_fields ):
431436 field = pa_type .field (field_idx )
432437 new_col_name = f"{ col_name } .{ field .name } "
433- nested_originated_columns .add (new_col_name )
434- clear_on_continuation_cols .append (new_col_name )
438+ current_nested_cols .add (new_col_name )
439+ current_clear_cols .append (new_col_name )
435440
436441 # Create a new Series from the flattened array
437442 new_cols_to_add [new_col_name ] = pd .Series (
@@ -440,14 +445,7 @@ def _flatten_struct_columns(
440445 index = result_df .index ,
441446 )
442447
443- col_idx = result_df .columns .to_list ().index (col_name )
444448 new_cols_df = pd .DataFrame (new_cols_to_add , index = result_df .index )
445- result_df = pd .concat (
446- [
447- result_df .iloc [:, :col_idx ],
448- new_cols_df ,
449- result_df .iloc [:, col_idx + 1 :],
450- ],
451- axis = 1 ,
452- )
453- return result_df , clear_on_continuation_cols
449+ result_df = _replace_column_in_df (result_df , col_name , new_cols_df )
450+
451+ return result_df , tuple (current_clear_cols ), frozenset (current_nested_cols )
0 commit comments