Commit 6d28d28

fix: resolve bug in _classify_columns logic and enable functional updates
1 parent 59c3a2a commit 6d28d28
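
Note: the "functional updates" named in the commit message refer to the standard-library pattern visible in the diff below: ColumnClassification becomes a frozen dataclass, and callers build updated copies with dataclasses.replace instead of mutating fields in place. A minimal sketch of that pattern, using a hypothetical Classification class rather than the real ColumnClassification:

import dataclasses


@dataclasses.dataclass(frozen=True)
class Classification:
    # Hypothetical stand-in for ColumnClassification, trimmed to two fields.
    array_columns: tuple[str, ...]
    nested_columns: frozenset[str]


c1 = Classification(array_columns=("a",), nested_columns=frozenset())
# Assigning to a field raises dataclasses.FrozenInstanceError, so callers
# build a new instance that shares the unchanged fields instead:
c2 = dataclasses.replace(c1, array_columns=("a", "b"))
assert c1.array_columns == ("a",)
assert c2.array_columns == ("a", "b")

Immutable field types (tuple, frozenset) keep the frozen instance genuinely read-only; a frozen dataclass that holds a list is still mutable through the list's own methods.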

File tree: 1 file changed (+72, -74 lines)

bigframes/display/_flatten.py: 72 additions & 74 deletions
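
Background for the list/set to tuple/frozenset field changes in the hunks below (a general note about dataclasses.replace, not a claim about the exact bug being fixed): replace performs a shallow copy, so a "copied" instance shares any mutable field values with the original, and mutating one silently changes the other.

import dataclasses


@dataclasses.dataclass
class Info:
    # Hypothetical example class, not the real ColumnClassification.
    array_columns: list[str]


original = Info(array_columns=["a"])
copied = dataclasses.replace(original)  # shallow copy: the list object is shared
copied.array_columns.append("b")
print(original.array_columns)  # ['a', 'b'] -- the mutation leaked back

Tuple and frozenset fields rule out that aliasing, which is what lets the helpers in this diff return new collections instead of mutating their arguments.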
@@ -50,7 +50,7 @@ class FlattenResult:
     nested_columns: set[str]
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(frozen=True)
 class ColumnClassification:
     """The result of classifying columns.
 
@@ -62,11 +62,11 @@ class ColumnClassification:
         nested_originated_columns: Columns that were created from nested data.
     """
 
-    struct_columns: list[str]
-    array_columns: list[str]
-    array_of_struct_columns: list[str]
-    clear_on_continuation_cols: list[str]
-    nested_originated_columns: set[str]
+    struct_columns: tuple[str, ...]
+    array_columns: tuple[str, ...]
+    array_of_struct_columns: tuple[str, ...]
+    clear_on_continuation_cols: tuple[str, ...]
+    nested_originated_columns: frozenset[str]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -107,41 +107,50 @@ def flatten_nested_data(
     result_df = dataframe.copy()
 
     classification = _classify_columns(result_df)
-    # Create a mutable structure to track column changes during flattening.
-    # _flatten_array_of_struct_columns modifies the array_columns list.
-    columns_info = dataclasses.replace(classification)
 
-    result_df, columns_info.array_columns = _flatten_array_of_struct_columns(
+    # Process ARRAY-of-STRUCT columns into multiple ARRAY columns (one per struct field).
+    result_df, array_cols, nested_cols = _flatten_array_of_struct_columns(
         result_df,
-        columns_info.array_of_struct_columns,
-        columns_info.array_columns,
-        columns_info.nested_originated_columns,
+        classification.array_of_struct_columns,
+        classification.array_columns,
+        classification.nested_originated_columns,
+    )
+    classification = dataclasses.replace(
+        classification, array_columns=array_cols, nested_originated_columns=nested_cols
     )
 
-    result_df, columns_info.clear_on_continuation_cols = _flatten_struct_columns(
+    # Flatten top-level STRUCT columns into separate columns.
+    result_df, clear_cols, nested_cols = _flatten_struct_columns(
         result_df,
-        columns_info.struct_columns,
-        columns_info.clear_on_continuation_cols,
-        columns_info.nested_originated_columns,
+        classification.struct_columns,
+        classification.clear_on_continuation_cols,
+        classification.nested_originated_columns,
+    )
+    classification = dataclasses.replace(
+        classification,
+        clear_on_continuation_cols=clear_cols,
+        nested_originated_columns=nested_cols,
     )
 
     # Now handle ARRAY columns (including the newly created ones from ARRAY of STRUCT)
-    if not columns_info.array_columns:
+    if not classification.array_columns:
         return FlattenResult(
             dataframe=result_df,
             row_labels=None,
             continuation_rows=None,
-            cleared_on_continuation=columns_info.clear_on_continuation_cols,
-            nested_columns=columns_info.nested_originated_columns,
+            cleared_on_continuation=list(classification.clear_on_continuation_cols),
+            nested_columns=set(classification.nested_originated_columns),
         )
 
-    explode_result = _explode_array_columns(result_df, columns_info.array_columns)
+    explode_result = _explode_array_columns(
+        result_df, list(classification.array_columns)
+    )
     return FlattenResult(
         dataframe=explode_result.dataframe,
         row_labels=explode_result.row_labels,
         continuation_rows=explode_result.continuation_rows,
-        cleared_on_continuation=columns_info.clear_on_continuation_cols,
-        nested_columns=columns_info.nested_originated_columns,
+        cleared_on_continuation=list(classification.clear_on_continuation_cols),
+        nested_columns=set(classification.nested_originated_columns),
     )
 
 
@@ -174,40 +183,43 @@ def _classify_columns(
             categories[col_name] = "clear"
 
     return ColumnClassification(
-        struct_columns=[c for c, cat in categories.items() if cat == "struct"],
-        array_columns=[
+        struct_columns=tuple(c for c, cat in categories.items() if cat == "struct"),
+        array_columns=tuple(
             c for c, cat in categories.items() if cat in ("array", "array_of_struct")
-        ],
-        array_of_struct_columns=[
+        ),
+        array_of_struct_columns=tuple(
             c for c, cat in categories.items() if cat == "array_of_struct"
-        ],
-        clear_on_continuation_cols=[
+        ),
+        clear_on_continuation_cols=tuple(
            c for c, cat in categories.items() if cat == "clear"
-        ],
-        nested_originated_columns={
+        ),
+        nested_originated_columns=frozenset(
            c for c, cat in categories.items() if cat != "clear"
-        },
+        ),
    )
 
 
 def _flatten_array_of_struct_columns(
     dataframe: pd.DataFrame,
-    array_of_struct_columns: list[str],
-    array_columns: list[str],
-    nested_originated_columns: set[str],
-) -> tuple[pd.DataFrame, list[str]]:
+    array_of_struct_columns: tuple[str, ...],
+    array_columns: tuple[str, ...],
+    nested_originated_columns: frozenset[str],
+) -> tuple[pd.DataFrame, tuple[str, ...], frozenset[str]]:
     """Flatten ARRAY of STRUCT columns into separate ARRAY columns for each field.
 
     Args:
         dataframe: The DataFrame to process.
-        array_of_struct_columns: List of column names that are ARRAYs of STRUCTs.
-        array_columns: The main list of ARRAY columns to be updated.
-        nested_originated_columns: Set of columns tracked as originating from nested data.
+        array_of_struct_columns: Column names that are ARRAYs of STRUCTs.
+        array_columns: The main sequence of ARRAY columns to be updated.
+        nested_originated_columns: Columns tracked as originating from nested data.
 
     Returns:
-        A tuple containing the modified DataFrame and the updated list of array columns.
+        A tuple containing the modified DataFrame, updated array columns, and updated nested columns.
     """
     result_df = dataframe.copy()
+    current_array_columns = list(array_columns)
+    current_nested_columns = set(nested_originated_columns)
+
     for col_name in array_of_struct_columns:
         col_data = result_df[col_name]
         # Ensure we have a PyArrow array (pa.array handles pandas Series conversion)
@@ -225,18 +237,13 @@ def _flatten_array_of_struct_columns(
             }
         )
 
-        # Track the new columns
-        for new_col in new_cols_df.columns:
-            nested_originated_columns.add(new_col)
-
-        # Update the DataFrame
+        current_nested_columns.update(new_cols_df.columns)
         result_df = _replace_column_in_df(result_df, col_name, new_cols_df)
 
-        # Update array_columns list
-        array_columns.remove(col_name)
-        array_columns.extend(new_cols_df.columns.tolist())
+        current_array_columns.remove(col_name)
+        current_array_columns.extend(new_cols_df.columns.tolist())
 
-    return result_df, array_columns
+    return result_df, tuple(current_array_columns), frozenset(current_nested_columns)
 
 
 def _transpose_list_of_structs(arrow_array: pa.ListArray) -> dict[str, pa.ListArray]:
@@ -395,27 +402,25 @@ def _explode_array_columns(
 
 
 def _flatten_struct_columns(
     dataframe: pd.DataFrame,
-    struct_columns: list[str],
-    clear_on_continuation_cols: list[str],
-    nested_originated_columns: set[str],
-) -> tuple[pd.DataFrame, list[str]]:
+    struct_columns: tuple[str, ...],
+    clear_on_continuation_cols: tuple[str, ...],
+    nested_originated_columns: frozenset[str],
+) -> tuple[pd.DataFrame, tuple[str, ...], frozenset[str]]:
     """Flatten regular STRUCT columns into separate columns.
 
-    A STRUCT column 'user' with fields 'name' and 'age' becomes 'user.name'
-    and 'user.age'.
-
     Args:
         dataframe: The DataFrame to process.
-        struct_columns: List of STRUCT columns to flatten.
-        clear_on_continuation_cols: List of columns to clear on continuation,
-            which will be updated with the new flattened columns.
-        nested_originated_columns: Set of columns tracked as originating from nested data.
+        struct_columns: STRUCT columns to flatten.
+        clear_on_continuation_cols: Columns to clear on continuation.
+        nested_originated_columns: Columns tracked as originating from nested data.
 
     Returns:
-        A tuple containing the modified DataFrame and the updated list of
-        columns to clear on continuation.
+        A tuple containing the modified DataFrame, updated clear columns, and updated nested columns.
     """
     result_df = dataframe.copy()
+    current_clear_cols = list(clear_on_continuation_cols)
+    current_nested_cols = set(nested_originated_columns)
+
     for col_name in struct_columns:
         col_data = result_df[col_name]
         if isinstance(col_data.dtype, pd.ArrowDtype):
@@ -430,8 +435,8 @@ def _flatten_struct_columns(
         for field_idx in range(pa_type.num_fields):
             field = pa_type.field(field_idx)
             new_col_name = f"{col_name}.{field.name}"
-            nested_originated_columns.add(new_col_name)
-            clear_on_continuation_cols.append(new_col_name)
+            current_nested_cols.add(new_col_name)
+            current_clear_cols.append(new_col_name)
 
             # Create a new Series from the flattened array
             new_cols_to_add[new_col_name] = pd.Series(
@@ -440,14 +445,7 @@ def _flatten_struct_columns(
                 index=result_df.index,
             )
 
-        col_idx = result_df.columns.to_list().index(col_name)
         new_cols_df = pd.DataFrame(new_cols_to_add, index=result_df.index)
-        result_df = pd.concat(
-            [
-                result_df.iloc[:, :col_idx],
-                new_cols_df,
-                result_df.iloc[:, col_idx + 1 :],
-            ],
-            axis=1,
-        )
-    return result_df, clear_on_continuation_cols
+        result_df = _replace_column_in_df(result_df, col_name, new_cols_df)
+
+    return result_df, tuple(current_clear_cols), frozenset(current_nested_cols)
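
For context, the transformation _flatten_struct_columns performs (per the removed docstring example, a STRUCT column 'user' with fields 'name' and 'age' becomes 'user.name' and 'user.age') looks roughly like the standalone pandas/pyarrow sketch below. It is illustrative only: the sample data and the struct_array variable are invented, it appends the new columns at the end instead of using the module's _replace_column_in_df helper to keep them in place, and it skips the dtype handling the real code does.

import pandas as pd
import pyarrow as pa

# A STRUCT column "user" with fields "name" and "age", Arrow-backed.
struct_type = pa.struct([("name", pa.string()), ("age", pa.int64())])
struct_array = pa.array(
    [{"name": "ada", "age": 36}, {"name": "bob", "age": 41}], type=struct_type
)
df = pd.DataFrame({"id": [1, 2]})
df["user"] = pd.Series(pd.arrays.ArrowExtensionArray(struct_array), index=df.index)

# Flattening replaces "user" with one column per struct field.
new_cols = {
    f"user.{struct_type.field(i).name}": pd.Series(
        pd.arrays.ArrowExtensionArray(struct_array.field(i)), index=df.index
    )
    for i in range(struct_type.num_fields)
}
flattened = pd.concat(
    [df.drop(columns=["user"]), pd.DataFrame(new_cols, index=df.index)], axis=1
)
print(flattened.columns.tolist())  # ['id', 'user.name', 'user.age']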
