Skip to content

Commit 3affd92

Browse files
committed
refactor: reuse pandas struct.explode()
1 parent 63e4a3c commit 3affd92

File tree

2 files changed

+179
-181
lines changed

2 files changed

+179
-181
lines changed

bigframes/display/_flatten.py

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Callable, cast
19+
from typing import cast
2020

2121
import pandas as pd
2222
import pyarrow as pa
@@ -39,54 +39,12 @@ def flatten_nested_data(
3939
nested_originated_columns,
4040
) = _classify_columns(result_df)
4141

42-
# Flatten ARRAY of STRUCT columns
43-
def update_array_columns(col_name: str, new_col_names: list[str]) -> None:
44-
array_columns.remove(col_name)
45-
array_columns.extend(new_col_names)
46-
47-
def create_list_series(
48-
original_arr: pa.Array, field_arr: pa.Array, index: pd.Index, field: pa.Field
49-
) -> pd.Series:
50-
new_list_array = pa.ListArray.from_arrays(
51-
original_arr.offsets, field_arr, mask=original_arr.is_null()
52-
)
53-
return pd.Series(
54-
new_list_array.to_pylist(),
55-
dtype=pd.ArrowDtype(pa.list_(field.type)),
56-
index=index,
57-
)
58-
59-
result_df = _flatten_and_replace_columns(
60-
result_df,
61-
array_of_struct_columns,
62-
nested_originated_columns,
63-
get_struct_type=lambda t: t.value_type,
64-
get_field_values=lambda arr: arr.values.flatten(),
65-
create_series=create_list_series,
66-
update_metadata=update_array_columns,
42+
result_df, array_columns = _flatten_array_of_struct_columns(
43+
result_df, array_of_struct_columns, array_columns, nested_originated_columns
6744
)
6845

69-
# Flatten regular STRUCT columns
70-
def update_clear_on_continuation(col_name: str, new_col_names: list[str]) -> None:
71-
clear_on_continuation_cols.extend(new_col_names)
72-
73-
def create_struct_series(
74-
original_arr: pa.Array, field_arr: pa.Array, index: pd.Index, field: pa.Field
75-
) -> pd.Series:
76-
return pd.Series(
77-
field_arr.to_pylist(),
78-
dtype=pd.ArrowDtype(field.type),
79-
index=index,
80-
)
81-
82-
result_df = _flatten_and_replace_columns(
83-
result_df,
84-
struct_columns,
85-
nested_originated_columns,
86-
get_struct_type=lambda t: t,
87-
get_field_values=lambda arr: arr.flatten(),
88-
create_series=create_struct_series,
89-
update_metadata=update_clear_on_continuation,
46+
result_df, clear_on_continuation_cols = _flatten_struct_columns(
47+
result_df, struct_columns, clear_on_continuation_cols, nested_originated_columns
9048
)
9149

9250
# Now handle ARRAY columns (including the newly created ones from ARRAY of STRUCT)
@@ -146,36 +104,45 @@ def _classify_columns(
146104
)
147105

148106

149-
def _flatten_and_replace_columns(
107+
def _flatten_array_of_struct_columns(
150108
dataframe: pd.DataFrame,
151-
columns: list[str],
109+
array_of_struct_columns: list[str],
110+
array_columns: list[str],
152111
nested_originated_columns: set[str],
153-
get_struct_type: Callable[[pa.DataType], pa.DataType],
154-
get_field_values: Callable[[pa.Array], list[pa.Array]],
155-
create_series: Callable[[pa.Array, pa.Array, pd.Index, pa.Field], pd.Series],
156-
update_metadata: Callable[[str, list[str]], None],
157-
) -> pd.DataFrame:
158-
"""Generic helper to flatten structure-like columns and replace them in the DataFrame."""
112+
) -> tuple[pd.DataFrame, list[str]]:
113+
"""Flatten ARRAY of STRUCT columns into separate array columns for each field."""
159114
result_df = dataframe.copy()
160-
for col_name in columns:
115+
for col_name in array_of_struct_columns:
161116
col_data = result_df[col_name]
162117
pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype
163-
struct_type = get_struct_type(pa_type)
118+
struct_type = pa_type.value_type
164119

120+
# Use PyArrow to reshape the list<struct> into multiple list<field> arrays
165121
arrow_array = pa.array(col_data)
166-
flattened_fields = get_field_values(arrow_array)
122+
offsets = arrow_array.offsets
123+
values = arrow_array.values # StructArray
124+
flattened_fields = values.flatten() # List[Array]
167125

168126
new_cols_to_add = {}
169-
new_col_names = []
127+
new_array_col_names = []
170128

129+
# Create new columns for each struct field
171130
for field_idx in range(struct_type.num_fields):
172131
field = struct_type.field(field_idx)
173132
new_col_name = f"{col_name}.{field.name}"
174133
nested_originated_columns.add(new_col_name)
175-
new_col_names.append(new_col_name)
134+
new_array_col_names.append(new_col_name)
176135

177-
new_cols_to_add[new_col_name] = create_series(
178-
arrow_array, flattened_fields[field_idx], result_df.index, field
136+
# Reconstruct ListArray for this field
137+
# Use mask=arrow_array.is_null() to preserve nulls from the original list
138+
new_list_array = pa.ListArray.from_arrays(
139+
offsets, flattened_fields[field_idx], mask=arrow_array.is_null()
140+
)
141+
142+
new_cols_to_add[new_col_name] = pd.Series(
143+
new_list_array.to_pylist(),
144+
dtype=pd.ArrowDtype(pa.list_(field.type)),
145+
index=result_df.index,
179146
)
180147

181148
col_idx = result_df.columns.to_list().index(col_name)
@@ -190,9 +157,11 @@ def _flatten_and_replace_columns(
190157
axis=1,
191158
)
192159

193-
update_metadata(col_name, new_col_names)
194-
195-
return result_df
160+
# Update array_columns list
161+
array_columns.remove(col_name)
162+
# Add the new array columns
163+
array_columns.extend(new_array_col_names)
164+
return result_df, array_columns
196165

197166

198167
def _explode_array_columns(
@@ -271,3 +240,39 @@ def _explode_array_columns(
271240
return exploded_df, array_row_groups
272241
else:
273242
return dataframe, array_row_groups
243+
244+
245+
def _flatten_struct_columns(
246+
dataframe: pd.DataFrame,
247+
struct_columns: list[str],
248+
clear_on_continuation_cols: list[str],
249+
nested_originated_columns: set[str],
250+
) -> tuple[pd.DataFrame, list[str]]:
251+
"""Flatten regular STRUCT columns using pandas accessor."""
252+
result_df = dataframe.copy()
253+
for col_name in struct_columns:
254+
# Use pandas struct accessor to explode the struct column into a DataFrame of its fields
255+
exploded_struct = result_df[col_name].struct.explode()
256+
257+
# Rename columns to 'parent.child' format
258+
exploded_struct.columns = [
259+
f"{col_name}.{sub_col}" for sub_col in exploded_struct.columns
260+
]
261+
262+
# Update metadata
263+
for new_col in exploded_struct.columns:
264+
nested_originated_columns.add(new_col)
265+
clear_on_continuation_cols.append(new_col)
266+
267+
# Replace the original struct column with the new field columns
268+
col_idx = result_df.columns.to_list().index(col_name)
269+
result_df = pd.concat(
270+
[
271+
result_df.iloc[:, :col_idx],
272+
exploded_struct,
273+
result_df.iloc[:, col_idx + 1 :],
274+
],
275+
axis=1,
276+
)
277+
278+
return result_df, clear_on_continuation_cols

0 commit comments

Comments
 (0)