Skip to content

Commit 03eba5e

Browse files
committed
refactor: replace magic strings for col categories with a private Enum
1 parent f74f82a commit 03eba5e

File tree

4 files changed

+271
-92
lines changed

4 files changed

+271
-92
lines changed

bigframes/display/_flatten.py

Lines changed: 83 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from __future__ import annotations
2626

2727
import dataclasses
28+
import enum
2829

2930
import numpy as np
3031
import pandas as pd
@@ -37,11 +38,22 @@ class FlattenResult:
3738
"""The result of flattening a DataFrame.
3839
3940
Attributes:
40-
dataframe: The flattened DataFrame.
41+
dataframe: The flattened DataFrame. If the original DataFrame had an index
42+
(including MultiIndex), it is preserved in this flattened DataFrame,
43+
duplicated across exploded rows as needed.
4144
row_labels: A list of original row labels for each row in the flattened DataFrame.
42-
continuation_rows: A set of row indices that are continuation rows.
43-
cleared_on_continuation: A list of column names that should be cleared on continuation rows.
44-
nested_columns: A set of column names that were created from nested data.
45+
This corresponds to the original index values (stringified) and serves to
46+
visually group the exploded rows that belong to the same original row.
47+
continuation_rows: A set of row indices in the flattened DataFrame that are
48+
"continuation rows". These are additional rows created to display the
49+
2nd to Nth elements of an array. The first row of each group (not a continuation row) contains
50+
the 1st element, while these rows contain subsequent elements.
51+
cleared_on_continuation: A list of column names that should be "cleared"
52+
(displayed as empty) on continuation rows. Typically, these are
53+
scalar columns (non-array) that were replicated during the explosion
54+
process but should only be visually displayed once per original row group.
55+
nested_columns: A set of column names that were created from nested data
56+
(flattened structs or arrays).
4557
"""
4658

4759
dataframe: pd.DataFrame
@@ -51,8 +63,15 @@ class FlattenResult:
5163
nested_columns: set[str]
5264

5365

66+
class _ColumnCategory(enum.Enum):
67+
STRUCT = "struct"
68+
ARRAY = "array"
69+
ARRAY_OF_STRUCT = "array_of_struct"
70+
CLEAR = "clear"
71+
72+
5473
@dataclasses.dataclass(frozen=True)
55-
class ColumnClassification:
74+
class _ColumnClassification:
5675
"""The result of classifying columns.
5776
5877
Attributes:
@@ -176,47 +195,51 @@ def flatten_nested_data(
176195

177196
def _classify_columns(
178197
dataframe: pd.DataFrame,
179-
) -> ColumnClassification:
198+
) -> _ColumnClassification:
180199
"""Identify all STRUCT and ARRAY columns in the DataFrame.
181200
182201
Args:
183202
dataframe: The DataFrame to inspect.
184203
185204
Returns:
186-
A ColumnClassification object containing lists of column names for each category.
205+
A _ColumnClassification object containing the column names for each category (as tuples, plus a frozenset of nested-originated columns).
187206
"""
188207

189-
def get_category(dtype: pd.api.extensions.ExtensionDtype) -> str:
208+
def get_category(dtype: pd.api.extensions.ExtensionDtype) -> _ColumnCategory:
190209
pa_type = getattr(dtype, "pyarrow_dtype", None)
191210
if pa_type:
192211
if pa.types.is_struct(pa_type):
193-
return "struct"
212+
return _ColumnCategory.STRUCT
194213
if pa.types.is_list(pa_type):
195214
return (
196-
"array_of_struct"
215+
_ColumnCategory.ARRAY_OF_STRUCT
197216
if pa.types.is_struct(pa_type.value_type)
198-
else "array"
217+
else _ColumnCategory.ARRAY
199218
)
200-
return "clear"
219+
return _ColumnCategory.CLEAR
201220

202221
# Maps column names to their structural category to simplify list building.
203222
categories = {
204223
str(col): get_category(dtype) for col, dtype in dataframe.dtypes.items()
205224
}
206225

207-
return ColumnClassification(
208-
struct_columns=tuple(c for c, cat in categories.items() if cat == "struct"),
226+
return _ColumnClassification(
227+
struct_columns=tuple(
228+
c for c, cat in categories.items() if cat == _ColumnCategory.STRUCT
229+
),
209230
array_columns=tuple(
210-
c for c, cat in categories.items() if cat in ("array", "array_of_struct")
231+
c
232+
for c, cat in categories.items()
233+
if cat in (_ColumnCategory.ARRAY, _ColumnCategory.ARRAY_OF_STRUCT)
211234
),
212235
array_of_struct_columns=tuple(
213-
c for c, cat in categories.items() if cat == "array_of_struct"
236+
c for c, cat in categories.items() if cat == _ColumnCategory.ARRAY_OF_STRUCT
214237
),
215238
clear_on_continuation_cols=tuple(
216-
c for c, cat in categories.items() if cat == "clear"
239+
c for c, cat in categories.items() if cat == _ColumnCategory.CLEAR
217240
),
218241
nested_originated_columns=frozenset(
219-
c for c, cat in categories.items() if cat != "clear"
242+
c for c, cat in categories.items() if cat != _ColumnCategory.CLEAR
220243
),
221244
)
222245

@@ -357,7 +380,7 @@ def _explode_array_columns(
357380
if not array_columns:
358381
return ExplodeResult(dataframe, [], set())
359382

360-
work_df, non_array_columns, original_index_name = _prepare_explosion_dataframe(
383+
work_df, non_array_columns, index_names = _prepare_explosion_dataframe(
361384
dataframe, array_columns
362385
)
363386

@@ -386,8 +409,8 @@ def _explode_array_columns(
386409
total_rows = target_offsets[-1].as_py()
387410
if total_rows == 0:
388411
empty_df = pd.DataFrame(columns=dataframe.columns)
389-
if original_index_name:
390-
empty_df.index.name = original_index_name
412+
if index_names:
413+
empty_df = empty_df.set_index(index_names)
391414
return ExplodeResult(empty_df, [], set())
392415

393416
# parent_indices maps each result row to its original row index.
@@ -425,45 +448,69 @@ def _explode_array_columns(
425448
result_table = pa.Table.from_pydict(new_columns)
426449
result_df = result_table.to_pandas(types_mapper=pd.ArrowDtype)
427450

428-
grouping_col_name = (
429-
"_original_index" if original_index_name is None else original_index_name
430-
)
431-
row_labels = result_df[grouping_col_name].astype(str).tolist()
451+
if index_names:
452+
if len(index_names) == 1:
453+
row_labels = result_df[index_names[0]].astype(str).tolist()
454+
else:
455+
# For MultiIndex, create a tuple string representation
456+
row_labels = (
457+
result_df[index_names].apply(tuple, axis=1).astype(str).tolist()
458+
)
459+
else:
460+
row_labels = result_df["_original_index"].astype(str).tolist()
432461

433462
continuation_mask = pc.greater(row_nums, 0).to_numpy(zero_copy_only=False)
434463
continuation_rows = set(np.flatnonzero(continuation_mask).tolist())
435464

436-
result_df = result_df[dataframe.columns.tolist()]
465+
# Select columns: original columns + restored index columns (temporarily)
466+
cols_to_keep = dataframe.columns.tolist()
467+
if index_names:
468+
cols_to_keep.extend(index_names)
437469

438-
if original_index_name:
439-
result_df = result_df.set_index(original_index_name)
470+
# Keep the original columns plus the index columns. The index columns are not in
471+
# dataframe.columns because they were index levels in the original DataFrame.
472+
result_df = result_df[cols_to_keep]
473+
474+
if index_names:
475+
result_df = result_df.set_index(index_names)
440476

441477
return ExplodeResult(result_df, row_labels, continuation_rows)
442478

443479

444480
def _prepare_explosion_dataframe(
445481
dataframe: pd.DataFrame, array_columns: list[str]
446-
) -> tuple[pd.DataFrame, list[str], str | None]:
482+
) -> tuple[pd.DataFrame, list[str], list[str] | None]:
447483
"""Prepares the DataFrame for explosion by ensuring grouping columns exist."""
448-
work_df = dataframe
484+
work_df = dataframe.copy()
449485
non_array_columns = work_df.columns.drop(array_columns).tolist()
450486

451487
if not non_array_columns:
452-
work_df = work_df.copy() # Avoid modifying input
453488
# Add a temporary column to allow grouping if all columns are arrays.
454489
non_array_columns = ["_temp_grouping_col"]
455490
work_df["_temp_grouping_col"] = range(len(work_df))
456491

457-
original_index_name = None
458-
if work_df.index.name:
459-
original_index_name = work_df.index.name
492+
index_names = None
493+
if work_df.index.nlevels > 1:
494+
# Handle MultiIndex
495+
names = list(work_df.index.names)
496+
# Assign default names if None to ensure reset_index works and we can track them
497+
names = [n if n is not None else f"level_{i}" for i, n in enumerate(names)]
498+
work_df.index.names = names
499+
index_names = names
500+
work_df = work_df.reset_index()
501+
non_array_columns.extend(index_names)
502+
elif work_df.index.name is not None:
503+
# Handle named Index
504+
index_names = [work_df.index.name]
460505
work_df = work_df.reset_index()
461-
non_array_columns.append(original_index_name)
506+
non_array_columns.extend(index_names)
462507
else:
508+
# Handle default/unnamed Index
509+
# We use _original_index for tracking but don't return it as an index to restore
463510
work_df = work_df.reset_index(names=["_original_index"])
464511
non_array_columns.append("_original_index")
465512

466-
return work_df, non_array_columns, original_index_name
513+
return work_df, non_array_columns, index_names
467514

468515

469516
def _flatten_struct_columns(

bigframes/display/html.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,21 @@ def _render_table_body(
138138
right_columns: list[Any],
139139
show_ellipsis: bool,
140140
) -> str:
141-
"""Render the body of the HTML table."""
141+
"""Render the table body.
142+
143+
Args:
144+
dataframe: The flattened dataframe to render.
145+
row_labels: Optional labels for each row, used for visual grouping of exploded rows.
146+
See `bigframes.display._flatten.FlattenResult` for details.
147+
continuation_rows: Indices of the continuation rows produced by array explosion.
148+
See `bigframes.display._flatten.FlattenResult` for details.
149+
clear_on_continuation: Columns to render as empty in continuation rows.
150+
See `bigframes.display._flatten.FlattenResult` for details.
151+
nested_originated_columns: Columns created from nested data, used for alignment.
152+
left_columns: Columns to display on the left.
153+
right_columns: Columns to display on the right.
154+
show_ellipsis: Whether to show an ellipsis row.
155+
"""
142156
body_parts = [" <tbody>"]
143157
precision = options.display.precision
144158

@@ -315,8 +329,8 @@ def _get_obj_metadata(
315329

316330
def get_anywidget_bundle(
317331
obj: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
318-
include=None,
319-
exclude=None,
332+
include: typing.Container[str] | None = None,
333+
exclude: typing.Container[str] | None = None,
320334
) -> tuple[dict[str, Any], dict[str, Any]]:
321335
"""
322336
Helper method to create and return the anywidget mimebundle.
@@ -413,9 +427,9 @@ def repr_mimebundle_head(
413427

414428
def repr_mimebundle(
415429
obj: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
416-
include=None,
417-
exclude=None,
418-
):
430+
include: typing.Container[str] | None = None,
431+
exclude: typing.Container[str] | None = None,
432+
) -> dict[str, str] | tuple[dict[str, Any], dict[str, Any]] | None:
419433
"""Custom display method for IPython/Jupyter environments."""
420434
# TODO(b/467647693): Anywidget integration has been tested in Jupyter, VS Code, and
421435
# BQ Studio, but there is a known compatibility issue with Marimo that needs to be addressed.

0 commit comments

Comments
 (0)