Skip to content

Commit 15bdf54

Browse files
committed
refactor: simplify flattening logic in _flatten.py
1 parent dfe5fec commit 15bdf54

File tree

1 file changed

+99
-74
lines changed

1 file changed

+99
-74
lines changed

bigframes/display/_flatten.py

Lines changed: 99 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -155,38 +155,37 @@ def _classify_columns(
155155
Returns:
156156
A ColumnClassification object containing lists of column names for each category.
157157
"""
158-
initial_columns = list(dataframe.columns)
159-
struct_columns: list[str] = []
160-
array_columns: list[str] = []
161-
array_of_struct_columns: list[str] = []
162-
clear_on_continuation_cols: list[str] = []
163-
nested_originated_columns: set[str] = set()
164-
165-
for col_name_raw, col_data in dataframe.items():
166-
col_name = str(col_name_raw)
167-
dtype = col_data.dtype
168-
if isinstance(dtype, pd.ArrowDtype):
169-
pa_type = dtype.pyarrow_dtype
170-
if pa.types.is_struct(pa_type):
171-
struct_columns.append(col_name)
172-
nested_originated_columns.add(col_name)
173-
elif pa.types.is_list(pa_type):
174-
array_columns.append(col_name)
175-
nested_originated_columns.add(col_name)
176-
if hasattr(pa_type, "value_type") and (
177-
pa.types.is_struct(pa_type.value_type)
178-
):
179-
array_of_struct_columns.append(col_name)
180-
else:
181-
clear_on_continuation_cols.append(col_name)
182-
elif col_name in initial_columns:
183-
clear_on_continuation_cols.append(col_name)
158+
# Maps column names to their structural category to simplify list building.
159+
categories: dict[str, str] = {}
160+
161+
for col, dtype in dataframe.dtypes.items():
162+
col_name = str(col)
163+
pa_type = getattr(dtype, "pyarrow_dtype", None)
164+
165+
if not pa_type:
166+
categories[col_name] = "clear"
167+
elif pa.types.is_struct(pa_type):
168+
categories[col_name] = "struct"
169+
elif pa.types.is_list(pa_type):
170+
is_struct_array = pa.types.is_struct(pa_type.value_type)
171+
categories[col_name] = "array_of_struct" if is_struct_array else "array"
172+
else:
173+
categories[col_name] = "clear"
174+
184175
return ColumnClassification(
185-
struct_columns=struct_columns,
186-
array_columns=array_columns,
187-
array_of_struct_columns=array_of_struct_columns,
188-
clear_on_continuation_cols=clear_on_continuation_cols,
189-
nested_originated_columns=nested_originated_columns,
176+
struct_columns=[c for c, cat in categories.items() if cat == "struct"],
177+
array_columns=[
178+
c for c, cat in categories.items() if cat in ("array", "array_of_struct")
179+
],
180+
array_of_struct_columns=[
181+
c for c, cat in categories.items() if cat == "array_of_struct"
182+
],
183+
clear_on_continuation_cols=[
184+
c for c, cat in categories.items() if cat == "clear"
185+
],
186+
nested_originated_columns={
187+
c for c, cat in categories.items() if cat != "clear"
188+
},
190189
)
191190

192191

@@ -198,10 +197,6 @@ def _flatten_array_of_struct_columns(
198197
) -> tuple[pd.DataFrame, list[str]]:
199198
"""Flatten ARRAY of STRUCT columns into separate ARRAY columns for each field.
200199
201-
For example, an ARRAY<STRUCT<a INT64, b STRING>> column named 'items' will be
202-
converted into two ARRAY columns: 'items.a' (ARRAY<INT64>) and 'items.b' (ARRAY<STRING>).
203-
This allows us to treat them as standard ARRAY columns for the subsequent explosion step.
204-
205200
Args:
206201
dataframe: The DataFrame to process.
207202
array_of_struct_columns: List of column names that are ARRAYs of STRUCTs.
@@ -214,56 +209,86 @@ def _flatten_array_of_struct_columns(
214209
result_df = dataframe.copy()
215210
for col_name in array_of_struct_columns:
216211
col_data = result_df[col_name]
217-
pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype
218-
struct_type = pa_type.value_type
219-
220-
# Use PyArrow to reshape the list<struct> into multiple list<field> arrays
212+
# Ensure we have a PyArrow array (pa.array handles pandas Series conversion)
221213
arrow_array = pa.array(col_data)
222-
offsets = arrow_array.offsets
223-
values = arrow_array.values # StructArray
224-
flattened_fields = values.flatten() # List[Array]
225-
226-
new_cols_to_add = {}
227-
new_array_col_names = []
228214

229-
# Create new columns for each struct field
230-
for field_idx in range(struct_type.num_fields):
231-
field = struct_type.field(field_idx)
232-
new_col_name = f"{col_name}.{field.name}"
233-
nested_originated_columns.add(new_col_name)
234-
new_array_col_names.append(new_col_name)
215+
# Transpose List<Struct<...>> to {field: List<field_type>}
216+
new_arrays = _transpose_list_of_structs(arrow_array)
235217

236-
# Reconstruct ListArray for this field. This transforms the
237-
# array<struct<f1, f2>> into separate array<f1> and array<f2> columns.
238-
new_list_array = pa.ListArray.from_arrays(
239-
offsets, flattened_fields[field_idx], mask=arrow_array.is_null()
240-
)
241-
242-
new_cols_to_add[new_col_name] = pd.Series(
243-
new_list_array,
244-
dtype=pd.ArrowDtype(pa.list_(field.type)),
245-
index=result_df.index,
246-
)
218+
new_cols_df = pd.DataFrame(
219+
{
220+
f"{col_name}.{field_name}": pd.Series(
221+
arr, dtype=pd.ArrowDtype(arr.type), index=result_df.index
222+
)
223+
for field_name, arr in new_arrays.items()
224+
}
225+
)
247226

248-
col_idx = result_df.columns.to_list().index(col_name)
249-
new_cols_df = pd.DataFrame(new_cols_to_add, index=result_df.index)
227+
# Track the new columns
228+
for new_col in new_cols_df.columns:
229+
nested_originated_columns.add(new_col)
250230

251-
result_df = pd.concat(
252-
[
253-
result_df.iloc[:, :col_idx],
254-
new_cols_df,
255-
result_df.iloc[:, col_idx + 1 :],
256-
],
257-
axis=1,
258-
)
231+
# Update the DataFrame
232+
result_df = _replace_column_in_df(result_df, col_name, new_cols_df)
259233

260234
# Update array_columns list
261235
array_columns.remove(col_name)
262-
# Add the new array columns
263-
array_columns.extend(new_array_col_names)
236+
array_columns.extend(new_cols_df.columns.tolist())
237+
264238
return result_df, array_columns
265239

266240

241+
def _transpose_list_of_structs(arrow_array: pa.ListArray) -> dict[str, pa.ListArray]:
    """Transposes a ListArray of Structs into multiple ListArrays of fields.

    Args:
        arrow_array: A PyArrow ListArray where the value type is a Struct.

    Returns:
        A dictionary mapping each struct field name to a new ListArray holding
        that field's values (one entry per field in the struct).
    """
    struct_type = arrow_array.type.value_type
    list_offsets = arrow_array.offsets
    null_mask = arrow_array.is_null()
    # arrow_array.values is the underlying StructArray; flattening it yields
    # one child array per struct field, effectively peeling off the struct layer.
    field_arrays = arrow_array.values.flatten()

    # Rebuild a ListArray per field from the shared offsets and validity,
    # transforming List<Struct<A, B>> into List<A> and List<B>.
    return {
        struct_type.field(idx).name: pa.ListArray.from_arrays(
            list_offsets, field_arrays[idx], mask=null_mask
        )
        for idx in range(struct_type.num_fields)
    }
266+
267+
268+
def _replace_column_in_df(
269+
dataframe: pd.DataFrame, col_name: str, new_cols: pd.DataFrame
270+
) -> pd.DataFrame:
271+
"""Replaces a column in a DataFrame with a set of new columns at the same position.
272+
273+
Args:
274+
dataframe: The original DataFrame.
275+
col_name: The name of the column to replace.
276+
new_cols: A DataFrame containing the new columns to insert.
277+
278+
Returns:
279+
A new DataFrame with the substitution made.
280+
"""
281+
col_idx = dataframe.columns.to_list().index(col_name)
282+
return pd.concat(
283+
[
284+
dataframe.iloc[:, :col_idx],
285+
new_cols,
286+
dataframe.iloc[:, col_idx + 1 :],
287+
],
288+
axis=1,
289+
)
290+
291+
267292
def _explode_array_columns(
268293
dataframe: pd.DataFrame, array_columns: list[str]
269294
) -> ExplodeResult:

0 commit comments

Comments
 (0)