Skip to content

Commit f32a53f

Browse files
committed
feat: use dataclass for flatten_nested_data
1 parent 0a88b10 commit f32a53f

File tree

3 files changed

+53
-25
lines changed

3 files changed

+53
-25
lines changed

bigframes/display/_flatten.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,44 @@
1616

1717
from __future__ import annotations
1818

19+
import dataclasses
1920
from typing import cast
2021

2122
import pandas as pd
2223
import pyarrow as pa
2324

2425

26+
@dataclasses.dataclass(frozen=True)
27+
class FlattenResult:
28+
"""The result of flattening a DataFrame."""
29+
30+
dataframe: pd.DataFrame
31+
"""The flattened DataFrame."""
32+
33+
row_groups: dict[str, list[int]]
34+
"""
35+
A mapping from original row index to the new row indices that were created
36+
from it.
37+
"""
38+
39+
cleared_on_continuation: list[str]
40+
"""A list of column names that should be cleared on continuation rows."""
41+
42+
nested_columns: set[str]
43+
"""A set of column names that were created from nested data."""
44+
45+
2546
def flatten_nested_data(
2647
dataframe: pd.DataFrame,
27-
) -> tuple[pd.DataFrame, dict[str, list[int]], list[str], set[str]]:
48+
) -> FlattenResult:
2849
"""Flatten nested STRUCT and ARRAY columns for display."""
2950
if dataframe.empty:
30-
return dataframe.copy(), {}, [], set()
51+
return FlattenResult(
52+
dataframe=dataframe.copy(),
53+
row_groups={},
54+
cleared_on_continuation=[],
55+
nested_columns=set(),
56+
)
3157

3258
result_df = dataframe.copy()
3359

@@ -49,19 +75,19 @@ def flatten_nested_data(
4975

5076
# Now handle ARRAY columns (including the newly created ones from ARRAY of STRUCT)
5177
if not array_columns:
52-
return (
53-
result_df,
54-
{},
55-
clear_on_continuation_cols,
56-
nested_originated_columns,
78+
return FlattenResult(
79+
dataframe=result_df,
80+
row_groups={},
81+
cleared_on_continuation=clear_on_continuation_cols,
82+
nested_columns=nested_originated_columns,
5783
)
5884

5985
result_df, array_row_groups = _explode_array_columns(result_df, array_columns)
60-
return (
61-
result_df,
62-
array_row_groups,
63-
clear_on_continuation_cols,
64-
nested_originated_columns,
86+
return FlattenResult(
87+
dataframe=result_df,
88+
row_groups=array_row_groups,
89+
cleared_on_continuation=clear_on_continuation_cols,
90+
nested_columns=nested_originated_columns,
6591
)
6692

6793

bigframes/display/html.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,20 @@ def render_html(
4949
) -> str:
5050
"""Render a pandas DataFrame to HTML with specific styling and nested data support."""
5151
# Flatten nested data first
52-
(
53-
flattened_df,
54-
array_row_groups,
55-
clear_on_continuation,
56-
nested_originated_columns,
57-
) = _flatten.flatten_nested_data(dataframe)
52+
flatten_result = _flatten.flatten_nested_data(dataframe)
5853

5954
orderable_columns = orderable_columns or []
6055
classes = "dataframe table table-striped table-hover"
6156
table_html_parts = [f'<table border="1" class="{classes}" id="{table_id}">']
62-
table_html_parts.append(_render_table_header(flattened_df, orderable_columns))
57+
table_html_parts.append(
58+
_render_table_header(flatten_result.dataframe, orderable_columns)
59+
)
6360
table_html_parts.append(
6461
_render_table_body(
65-
flattened_df,
66-
array_row_groups,
67-
clear_on_continuation,
68-
nested_originated_columns,
62+
flatten_result.dataframe,
63+
flatten_result.row_groups,
64+
flatten_result.cleared_on_continuation,
65+
flatten_result.nested_columns,
6966
)
7067
)
7168
table_html_parts.append("</table>")

tests/unit/display/test_html.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ def test_flatten_nested_data_flattens_structs():
165165
}
166166
)
167167

168-
flattened, _, _, nested_originated_columns = flatten_nested_data(struct_data)
168+
result = flatten_nested_data(struct_data)
169+
flattened = result.dataframe
170+
nested_originated_columns = result.nested_columns
169171

170172
assert "struct_col.name" in flattened.columns
171173
assert "struct_col.age" in flattened.columns
@@ -186,7 +188,10 @@ def test_flatten_nested_data_explodes_arrays():
186188
}
187189
)
188190

189-
flattened, groups, _, nested_originated_columns = flatten_nested_data(array_data)
191+
result = flatten_nested_data(array_data)
192+
flattened = result.dataframe
193+
groups = result.row_groups
194+
nested_originated_columns = result.nested_columns
190195

191196
assert len(flattened) == 5 # 3 + 2 array elements
192197
assert "0" in groups # First original row

0 commit comments

Comments
 (0)