1616
1717from __future__ import annotations
1818
19- from typing import Callable , cast
19+ from typing import cast
2020
2121import pandas as pd
2222import pyarrow as pa
@@ -39,54 +39,12 @@ def flatten_nested_data(
3939 nested_originated_columns ,
4040 ) = _classify_columns (result_df )
4141
42- # Flatten ARRAY of STRUCT columns
43- def update_array_columns (col_name : str , new_col_names : list [str ]) -> None :
44- array_columns .remove (col_name )
45- array_columns .extend (new_col_names )
46-
47- def create_list_series (
48- original_arr : pa .Array , field_arr : pa .Array , index : pd .Index , field : pa .Field
49- ) -> pd .Series :
50- new_list_array = pa .ListArray .from_arrays (
51- original_arr .offsets , field_arr , mask = original_arr .is_null ()
52- )
53- return pd .Series (
54- new_list_array .to_pylist (),
55- dtype = pd .ArrowDtype (pa .list_ (field .type )),
56- index = index ,
57- )
58-
59- result_df = _flatten_and_replace_columns (
60- result_df ,
61- array_of_struct_columns ,
62- nested_originated_columns ,
63- get_struct_type = lambda t : t .value_type ,
64- get_field_values = lambda arr : arr .values .flatten (),
65- create_series = create_list_series ,
66- update_metadata = update_array_columns ,
42+ result_df , array_columns = _flatten_array_of_struct_columns (
43+ result_df , array_of_struct_columns , array_columns , nested_originated_columns
6744 )
6845
69- # Flatten regular STRUCT columns
70- def update_clear_on_continuation (col_name : str , new_col_names : list [str ]) -> None :
71- clear_on_continuation_cols .extend (new_col_names )
72-
73- def create_struct_series (
74- original_arr : pa .Array , field_arr : pa .Array , index : pd .Index , field : pa .Field
75- ) -> pd .Series :
76- return pd .Series (
77- field_arr .to_pylist (),
78- dtype = pd .ArrowDtype (field .type ),
79- index = index ,
80- )
81-
82- result_df = _flatten_and_replace_columns (
83- result_df ,
84- struct_columns ,
85- nested_originated_columns ,
86- get_struct_type = lambda t : t ,
87- get_field_values = lambda arr : arr .flatten (),
88- create_series = create_struct_series ,
89- update_metadata = update_clear_on_continuation ,
46+ result_df , clear_on_continuation_cols = _flatten_struct_columns (
47+ result_df , struct_columns , clear_on_continuation_cols , nested_originated_columns
9048 )
9149
9250 # Now handle ARRAY columns (including the newly created ones from ARRAY of STRUCT)
@@ -146,36 +104,45 @@ def _classify_columns(
146104 )
147105
148106
149- def _flatten_and_replace_columns (
107+ def _flatten_array_of_struct_columns (
150108 dataframe : pd .DataFrame ,
151- columns : list [str ],
109+ array_of_struct_columns : list [str ],
110+ array_columns : list [str ],
152111 nested_originated_columns : set [str ],
153- get_struct_type : Callable [[pa .DataType ], pa .DataType ],
154- get_field_values : Callable [[pa .Array ], list [pa .Array ]],
155- create_series : Callable [[pa .Array , pa .Array , pd .Index , pa .Field ], pd .Series ],
156- update_metadata : Callable [[str , list [str ]], None ],
157- ) -> pd .DataFrame :
158- """Generic helper to flatten structure-like columns and replace them in the DataFrame."""
112+ ) -> tuple [pd .DataFrame , list [str ]]:
113+ """Flatten ARRAY of STRUCT columns into separate array columns for each field."""
159114 result_df = dataframe .copy ()
160- for col_name in columns :
115+ for col_name in array_of_struct_columns :
161116 col_data = result_df [col_name ]
162117 pa_type = cast (pd .ArrowDtype , col_data .dtype ).pyarrow_dtype
163- struct_type = get_struct_type ( pa_type )
118+ struct_type = pa_type . value_type
164119
120+ # Use PyArrow to reshape the list<struct> into multiple list<field> arrays
165121 arrow_array = pa .array (col_data )
166- flattened_fields = get_field_values (arrow_array )
122+ offsets = arrow_array .offsets
123+ values = arrow_array .values # StructArray
124+ flattened_fields = values .flatten () # List[Array]
167125
168126 new_cols_to_add = {}
169- new_col_names = []
127+ new_array_col_names = []
170128
129+ # Create new columns for each struct field
171130 for field_idx in range (struct_type .num_fields ):
172131 field = struct_type .field (field_idx )
173132 new_col_name = f"{ col_name } .{ field .name } "
174133 nested_originated_columns .add (new_col_name )
175- new_col_names .append (new_col_name )
134+ new_array_col_names .append (new_col_name )
176135
177- new_cols_to_add [new_col_name ] = create_series (
178- arrow_array , flattened_fields [field_idx ], result_df .index , field
136+ # Reconstruct ListArray for this field
137+ # Use mask=arrow_array.is_null() to preserve nulls from the original list
138+ new_list_array = pa .ListArray .from_arrays (
139+ offsets , flattened_fields [field_idx ], mask = arrow_array .is_null ()
140+ )
141+
142+ new_cols_to_add [new_col_name ] = pd .Series (
143+ new_list_array .to_pylist (),
144+ dtype = pd .ArrowDtype (pa .list_ (field .type )),
145+ index = result_df .index ,
179146 )
180147
181148 col_idx = result_df .columns .to_list ().index (col_name )
@@ -190,9 +157,11 @@ def _flatten_and_replace_columns(
190157 axis = 1 ,
191158 )
192159
193- update_metadata (col_name , new_col_names )
194-
195- return result_df
160+ # Update array_columns list
161+ array_columns .remove (col_name )
162+ # Add the new array columns
163+ array_columns .extend (new_array_col_names )
164+ return result_df , array_columns
196165
197166
198167def _explode_array_columns (
@@ -271,3 +240,39 @@ def _explode_array_columns(
271240 return exploded_df , array_row_groups
272241 else :
273242 return dataframe , array_row_groups
243+
244+
245+ def _flatten_struct_columns (
246+ dataframe : pd .DataFrame ,
247+ struct_columns : list [str ],
248+ clear_on_continuation_cols : list [str ],
249+ nested_originated_columns : set [str ],
250+ ) -> tuple [pd .DataFrame , list [str ]]:
251+ """Flatten regular STRUCT columns using pandas accessor."""
252+ result_df = dataframe .copy ()
253+ for col_name in struct_columns :
254+ # Use pandas struct accessor to explode the struct column into a DataFrame of its fields
255+ exploded_struct = result_df [col_name ].struct .explode ()
256+
257+ # Rename columns to 'parent.child' format
258+ exploded_struct .columns = [
259+ f"{ col_name } .{ sub_col } " for sub_col in exploded_struct .columns
260+ ]
261+
262+ # Update metadata
263+ for new_col in exploded_struct .columns :
264+ nested_originated_columns .add (new_col )
265+ clear_on_continuation_cols .append (new_col )
266+
267+ # Replace the original struct column with the new field columns
268+ col_idx = result_df .columns .to_list ().index (col_name )
269+ result_df = pd .concat (
270+ [
271+ result_df .iloc [:, :col_idx ],
272+ exploded_struct ,
273+ result_df .iloc [:, col_idx + 1 :],
274+ ],
275+ axis = 1 ,
276+ )
277+
278+ return result_df , clear_on_continuation_cols
0 commit comments