
Commit fc122a5

refactor(anywidget): optimize and style cleanup for flatten logic
- Replaced Python-based row explosion with optimized PyArrow computation for nested arrays.
- Cleaned up comments in _flatten.py to strictly adhere to the Google Python Style Guide (focused on "why", removed redundant "what").
- Renamed a variable for clarity.
- Verified changes with Python unit tests and JavaScript frontend tests.
1 parent 2de5a3c commit fc122a5
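
For context rather than as part of the commit: the core of the new explode path is to compute each row's padded output length in PyArrow, turn the cumulative sum of those lengths into list offsets, and recover a gather map with pc.list_parent_indices. Below is a minimal, self-contained sketch of that trick; the tags data and the int32 offsets are illustrative choices for the example, not code copied from _flatten.py.

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical input: one list column with a ragged shape and a null row.
tags = pa.array([[1, 2, 3], [4], None], type=pa.list_(pa.int64()))

lens = pc.list_value_length(tags)             # [3, 1, null]
lens = pc.if_else(pc.is_null(lens), 1, lens)  # a null list still emits one row
lens = lens.cast(pa.int32())                  # ListArray offsets must be int32

# Offsets [0, 3, 4, 5]: output rows offsets[i]..offsets[i+1] come from input row i.
offsets = pa.concat_arrays([pa.array([0], type=pa.int32()), pc.cumulative_sum(lens)])
total_rows = offsets[-1].as_py()

# A throwaway list array exists only to carry the offsets, so that
# list_parent_indices yields, for every output row, its source row.
dummy = pa.ListArray.from_arrays(offsets, pa.nulls(total_rows, type=pa.null()))
parents = pc.list_parent_indices(dummy)

# Position of each output row within its source list.
row_nums = pc.subtract(pa.array(range(total_rows)), offsets.take(parents))

print(parents.to_pylist())   # [0, 0, 0, 1, 2]
print(row_nums.to_pylist())  # [0, 1, 2, 0, 0]

With parents in hand, every scalar column explodes in a single vectorized column.take(parents) gather instead of per-row Python work.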

File tree

1 file changed (+89, -66 lines)

bigframes/display/_flatten.py

Lines changed: 89 additions & 66 deletions
@@ -27,8 +27,10 @@
 import dataclasses
 from typing import cast

+import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc


 @dataclasses.dataclass(frozen=True)
@@ -356,90 +358,114 @@ def _explode_array_columns(
     if not array_columns:
         return ExplodeResult(dataframe, [], set())

-    # Group by all non-array columns to maintain context.
-    # _row_num tracks the index within the exploded array to synchronize multiple
-    # arrays. Continuation rows (index > 0) are tracked for display clearing.
-    original_cols = dataframe.columns.tolist()
-    work_df = dataframe
+    work_df, non_array_columns, original_index_name = _prepare_explosion_dataframe(
+        dataframe, array_columns
+    )

-    non_array_columns = work_df.columns.drop(array_columns).tolist()
-    if not non_array_columns:
-        work_df = work_df.copy()  # Avoid modifying input
-        # Add a temporary column to allow grouping if all columns are arrays.
-        # This ensures we can still group by "original row" even if there are no scalar columns.
-        non_array_columns = ["_temp_grouping_col"]
-        work_df["_temp_grouping_col"] = range(len(work_df))
+    if work_df.empty:
+        return ExplodeResult(dataframe, [], set())

-    # Preserve original index
-    if work_df.index.name:
-        original_index_name = work_df.index.name
-        work_df = work_df.reset_index()
-        non_array_columns.append(original_index_name)
-    else:
-        original_index_name = None
-        work_df = work_df.reset_index(names=["_original_index"])
-        non_array_columns.append("_original_index")
+    table = pa.Table.from_pandas(work_df)
+    arrays = [table.column(col).combine_chunks() for col in array_columns]
+    lengths = []
+    for arr in arrays:
+        row_lengths = pc.list_value_length(arr)
+        # Treat null lists as length 1 to match pandas explode behavior for scalars.
+        row_lengths = pc.if_else(
+            pc.is_null(row_lengths, nan_is_null=True), 1, row_lengths
+        )
+        lengths.append(row_lengths)

-    exploded_dfs = []
-    for col in array_columns:
-        # Explode each array column individually
-        col_series = work_df[col]
-        target_dtype = None
-        if isinstance(col_series.dtype, pd.ArrowDtype):
-            pa_type = col_series.dtype.pyarrow_dtype
-            if pa.types.is_list(pa_type):
-                target_dtype = pd.ArrowDtype(pa_type.value_type)
-                # Use to_list() to avoid pandas attempting to create a 2D numpy
-                # array if the list elements have the same length.
-                col_series = pd.Series(
-                    col_series.to_list(), index=col_series.index, dtype=object
-                )
+    if not lengths:
+        return ExplodeResult(dataframe, [], set())

-        exploded = work_df[non_array_columns].assign(**{col: col_series}).explode(col)
+    max_lens = lengths[0] if len(lengths) == 1 else pc.max_element_wise(*lengths)
+    max_lens = max_lens.cast(pa.int64())
+    current_offsets = pc.cumulative_sum(max_lens)
+    target_offsets = pa.concat_arrays([pa.array([0], type=pa.int64()), current_offsets])

-        if target_dtype is not None:
-            # Re-cast to arrow dtype if possible
-            exploded[col] = exploded[col].astype(target_dtype)
+    total_rows = target_offsets[-1].as_py()
+    if total_rows == 0:
+        empty_df = pd.DataFrame(columns=dataframe.columns)
+        if original_index_name:
+            empty_df.index.name = original_index_name
+        return ExplodeResult(empty_df, [], set())

-        # Track position in the array for alignment
-        exploded["_row_num"] = exploded.groupby(non_array_columns).cumcount()
-        exploded_dfs.append(exploded)
+    # parent_indices maps each result row to its original row index.
+    dummy_values = pa.nulls(total_rows, type=pa.null())
+    dummy_list_array = pa.ListArray.from_arrays(target_offsets, dummy_values)
+    parent_indices = pc.list_parent_indices(dummy_list_array)

-    if not exploded_dfs:
-        # This should not be reached if array_columns is not empty
-        return ExplodeResult(dataframe, [], set())
+    range_k = pa.array(range(total_rows))
+    starts = target_offsets.take(parent_indices)
+    row_nums = pc.subtract(range_k, starts)

-    # Merge the exploded columns
-    merged_df = exploded_dfs[0]
-    for i in range(1, len(exploded_dfs)):
-        merged_df = pd.merge(
-            merged_df,
-            exploded_dfs[i],
-            on=non_array_columns + ["_row_num"],
-            how="outer",
-        )
+    new_columns = {}
+    for col_name in non_array_columns:
+        new_columns[col_name] = table.column(col_name).take(parent_indices)

-    # Restore original column order and sort
-    merged_df = merged_df.sort_values(non_array_columns + ["_row_num"]).reset_index(
-        drop=True
-    )
+    for col_name, arr in zip(array_columns, arrays):
+        actual_lens_scattered = pc.list_value_length(arr).take(parent_indices)
+        valid_mask = pc.less(row_nums, actual_lens_scattered)
+        starts_scattered = arr.offsets.take(parent_indices)
+
+        # safe_mask ensures we don't access out of bounds even if masked out.
+        safe_mask = pc.fill_null(valid_mask, False)
+        candidate_indices = pc.add(starts_scattered, row_nums)
+        safe_indices = pc.if_else(safe_mask, candidate_indices, 0)
+
+        if len(arr.values) == 0:
+            final_values = pa.nulls(total_rows, type=arr.type.value_type)
+        else:
+            taken_values = arr.values.take(safe_indices)
+            final_values = pc.if_else(safe_mask, taken_values, None)
+
+        new_columns[col_name] = final_values
+
+    result_df = pa.Table.from_pydict(new_columns).to_pandas()

-    # Generate row labels and continuation mask efficiently
     grouping_col_name = (
         "_original_index" if original_index_name is None else original_index_name
     )
-    row_labels = merged_df[grouping_col_name].astype(str).tolist()
-    continuation_rows = set(merged_df.index[merged_df["_row_num"] > 0])
+    row_labels = result_df[grouping_col_name].astype(str).tolist()

-    # Restore original columns
-    result_df = merged_df[original_cols]
+    # The continuation_mask is a boolean mask where row_num > 0.
+    continuation_mask = pc.greater(row_nums, 0).to_numpy(zero_copy_only=False)
+    continuation_rows = set(np.flatnonzero(continuation_mask))
+
+    result_df = result_df[dataframe.columns.tolist()]

     if original_index_name:
         result_df = result_df.set_index(original_index_name)

     return ExplodeResult(result_df, row_labels, continuation_rows)


+def _prepare_explosion_dataframe(
+    dataframe: pd.DataFrame, array_columns: list[str]
+) -> tuple[pd.DataFrame, list[str], str | None]:
+    """Prepares the DataFrame for explosion by ensuring grouping columns exist."""
+    work_df = dataframe
+    non_array_columns = work_df.columns.drop(array_columns).tolist()
+
+    if not non_array_columns:
+        work_df = work_df.copy()  # Avoid modifying input
+        # Add a temporary column to allow grouping if all columns are arrays.
+        non_array_columns = ["_temp_grouping_col"]
+        work_df["_temp_grouping_col"] = range(len(work_df))
+
+    original_index_name = None
+    if work_df.index.name:
+        original_index_name = work_df.index.name
+        work_df = work_df.reset_index()
+        non_array_columns.append(original_index_name)
+    else:
+        work_df = work_df.reset_index(names=["_original_index"])
+        non_array_columns.append("_original_index")
+
+    return work_df, non_array_columns, original_index_name
+
+
 def _flatten_struct_columns(
     dataframe: pd.DataFrame,
     struct_columns: tuple[str, ...],
@@ -466,8 +492,6 @@ def _flatten_struct_columns(
         if isinstance(col_data.dtype, pd.ArrowDtype):
             pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype

-            # Use PyArrow to flatten the struct column without row iteration
-            # combine_chunks() ensures we have a single array if it was chunked
             arrow_array = pa.array(col_data)
             flattened_fields = arrow_array.flatten()

@@ -478,7 +502,6 @@ def _flatten_struct_columns(
             current_nested_cols.add(new_col_name)
             current_clear_cols.append(new_col_name)

-            # Create a new Series from the flattened array
             new_cols_to_add[new_col_name] = pd.Series(
                 flattened_fields[field_idx],
                 dtype=pd.ArrowDtype(field.type),
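
A hedged illustration, again not from the commit itself, of the masked gather the diff applies to the list columns: indices for padded-out slots are clamped to 0 so values.take() can never read out of bounds, and those slots are nulled back out afterwards. The two-row arr literal is invented for the example; parents and row_nums mirror a padded width of 2.

import pyarrow as pa
import pyarrow.compute as pc

arr = pa.array([[1, 2], [30]], type=pa.list_(pa.int64()))
parents = pa.array([0, 0, 1, 1])   # both input rows padded to width 2
row_nums = pa.array([0, 1, 0, 1])

actual_lens = pc.list_value_length(arr).take(parents)  # [2, 2, 1, 1]
valid = pc.less(row_nums, actual_lens)                 # [T, T, T, F]
safe = pc.fill_null(valid, False)  # null lists would surface as nulls here

starts = arr.offsets.take(parents)                   # [0, 0, 2, 2]
idx = pc.if_else(safe, pc.add(starts, row_nums), 0)  # clamp masked slots to 0
vals = pc.if_else(safe, arr.values.take(idx), None)  # then null them back out

print(vals.to_pylist())  # [1, 2, 30, None]

Clamping instead of filtering keeps every step a fixed-size vectorized kernel, which is what lets the new code drop the row-by-row explode-and-merge loop entirely.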
