@@ -43,6 +43,40 @@ class FlattenResult:
4343 """A set of column names that were created from nested data."""
4444
4545
@dataclasses.dataclass(frozen=True)
class ColumnClassification:
    """Column names of a DataFrame, grouped by the kind of nested data they hold.

    This is the value returned by ``_classify_columns``.

    NOTE: ``frozen=True`` only stops the fields from being rebound to other
    objects; the lists and the set stored in them remain ordinary mutable
    containers.

    Attributes:
        struct_columns: Columns that are STRUCTs.
        array_columns: Columns that are ARRAYs.
        array_of_struct_columns: Columns that are ARRAYs of STRUCTs.
        clear_on_continuation_cols: Columns that should be cleared on
            continuation rows.
        nested_originated_columns: Columns that were created from nested data.
    """

    struct_columns: list[str]
    array_columns: list[str]
    array_of_struct_columns: list[str]
    clear_on_continuation_cols: list[str]
    nested_originated_columns: set[str]
64+
65+
@dataclasses.dataclass(frozen=True)
class ExplodeResult:
    """Outcome of exploding array columns into additional rows.

    This is the value returned by ``_explode_array_columns``. Instances are
    built positionally there, so the field order below is part of the
    interface and must not be rearranged.

    Attributes:
        dataframe: The exploded DataFrame.
        row_labels: Labels for the rows.
        continuation_rows: Indices of continuation rows.
    """

    dataframe: pd.DataFrame
    row_labels: list[str]
    continuation_rows: set[int]
78+
79+
4680def flatten_nested_data (
4781 dataframe : pd .DataFrame ,
4882) -> FlattenResult :
@@ -58,13 +92,16 @@ def flatten_nested_data(
5892
5993 result_df = dataframe .copy ()
6094
61- (
62- struct_columns ,
63- array_columns ,
64- array_of_struct_columns ,
65- clear_on_continuation_cols ,
66- nested_originated_columns ,
67- ) = _classify_columns (result_df )
95+ classification = _classify_columns (result_df )
96+ # Extract lists to allow modification
97+ # TODO(b/469966526): The modification of these lists in place by subsequent functions
98+ # (e.g. _flatten_array_of_struct_columns removing items from array_columns) suggests
99+ # that the data flow here could be cleaner, but keeping it as is for now.
100+ struct_columns = classification .struct_columns
101+ array_columns = classification .array_columns
102+ array_of_struct_columns = classification .array_of_struct_columns
103+ clear_on_continuation_cols = classification .clear_on_continuation_cols
104+ nested_originated_columns = classification .nested_originated_columns
68105
69106 result_df , array_columns = _flatten_array_of_struct_columns (
70107 result_df , array_of_struct_columns , array_columns , nested_originated_columns
@@ -84,21 +121,19 @@ def flatten_nested_data(
84121 nested_columns = nested_originated_columns ,
85122 )
86123
87- result_df , row_labels , continuation_rows = _explode_array_columns (
88- result_df , array_columns
89- )
124+ explode_result = _explode_array_columns (result_df , array_columns )
90125 return FlattenResult (
91- dataframe = result_df ,
92- row_labels = row_labels ,
93- continuation_rows = continuation_rows ,
126+ dataframe = explode_result . dataframe ,
127+ row_labels = explode_result . row_labels ,
128+ continuation_rows = explode_result . continuation_rows ,
94129 cleared_on_continuation = clear_on_continuation_cols ,
95130 nested_columns = nested_originated_columns ,
96131 )
97132
98133
99134def _classify_columns (
100135 dataframe : pd .DataFrame ,
101- ) -> tuple [ list [ str ], list [ str ], list [ str ], list [ str ], set [ str ]] :
136+ ) -> ColumnClassification :
102137 """Identify all STRUCT and ARRAY columns."""
103138 initial_columns = list (dataframe .columns )
104139 struct_columns : list [str ] = []
@@ -126,12 +161,12 @@ def _classify_columns(
126161 clear_on_continuation_cols .append (col_name )
127162 elif col_name in initial_columns :
128163 clear_on_continuation_cols .append (col_name )
129- return (
130- struct_columns ,
131- array_columns ,
132- array_of_struct_columns ,
133- clear_on_continuation_cols ,
134- nested_originated_columns ,
164+ return ColumnClassification (
165+ struct_columns = struct_columns ,
166+ array_columns = array_columns ,
167+ array_of_struct_columns = array_of_struct_columns ,
168+ clear_on_continuation_cols = clear_on_continuation_cols ,
169+ nested_originated_columns = nested_originated_columns ,
135170 )
136171
137172
@@ -197,10 +232,10 @@ def _flatten_array_of_struct_columns(
197232
198233def _explode_array_columns (
199234 dataframe : pd .DataFrame , array_columns : list [str ]
200- ) -> tuple [ pd . DataFrame , list [ str ], set [ int ]] :
235+ ) -> ExplodeResult :
201236 """Explode array columns into new rows."""
202237 if not array_columns :
203- return dataframe , [], set ()
238+ return ExplodeResult ( dataframe , [], set () )
204239
205240 original_cols = dataframe .columns .tolist ()
206241 work_df = dataframe
@@ -248,7 +283,7 @@ def _explode_array_columns(
248283
249284 if not exploded_dfs :
250285 # This should not be reached if array_columns is not empty
251- return dataframe , [], set ()
286+ return ExplodeResult ( dataframe , [], set () )
252287
253288 # Merge the exploded columns
254289 merged_df = exploded_dfs [0 ]
@@ -278,7 +313,7 @@ def _explode_array_columns(
278313 if original_index_name :
279314 result_df = result_df .set_index (original_index_name )
280315
281- return result_df , row_labels , continuation_rows
316+ return ExplodeResult ( result_df , row_labels , continuation_rows )
282317
283318
284319def _flatten_struct_columns (
0 commit comments