Skip to content

Commit 0b73c0a

Browse files
committed
perf(display): optimize nested data flattening and fix js style
1 parent 5cfa8d7 commit 0b73c0a

File tree

4 files changed

+44
-39
lines changed

4 files changed

+44
-39
lines changed

bigframes/display/_flatten.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ class FlattenResult:
3030
dataframe: pd.DataFrame
3131
"""The flattened DataFrame."""
3232

33-
row_groups: dict[str, list[int]]
34-
"""
35-
A mapping from original row index to the new row indices that were created
36-
from it.
37-
"""
33+
row_labels: list[str] | None
34+
"""The original row label, as a string, for each row in the flattened DataFrame; None when the input is empty or has no array columns to explode."""
35+
36+
continuation_rows: set[int] | None
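37+
"""Indices of flattened rows that are continuation rows, i.e. every row after the first one produced from the same original row; None when the input is empty or has no array columns to explode."""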
3838

3939
cleared_on_continuation: list[str]
4040
"""A list of column names that should be cleared on continuation rows."""
@@ -50,7 +50,8 @@ def flatten_nested_data(
5050
if dataframe.empty:
5151
return FlattenResult(
5252
dataframe=dataframe.copy(),
53-
row_groups={},
53+
row_labels=None,
54+
continuation_rows=None,
5455
cleared_on_continuation=[],
5556
nested_columns=set(),
5657
)
@@ -77,15 +78,19 @@ def flatten_nested_data(
7778
if not array_columns:
7879
return FlattenResult(
7980
dataframe=result_df,
80-
row_groups={},
81+
row_labels=None,
82+
continuation_rows=None,
8183
cleared_on_continuation=clear_on_continuation_cols,
8284
nested_columns=nested_originated_columns,
8385
)
8486

85-
result_df, array_row_groups = _explode_array_columns(result_df, array_columns)
87+
result_df, row_labels, continuation_rows = _explode_array_columns(
88+
result_df, array_columns
89+
)
8690
return FlattenResult(
8791
dataframe=result_df,
88-
row_groups=array_row_groups,
92+
row_labels=row_labels,
93+
continuation_rows=continuation_rows,
8994
cleared_on_continuation=clear_on_continuation_cols,
9095
nested_columns=nested_originated_columns,
9196
)
@@ -192,10 +197,10 @@ def _flatten_array_of_struct_columns(
192197

193198
def _explode_array_columns(
194199
dataframe: pd.DataFrame, array_columns: list[str]
195-
) -> tuple[pd.DataFrame, dict[str, list[int]]]:
200+
) -> tuple[pd.DataFrame, list[str], set[int]]:
196201
"""Explode array columns into new rows."""
197202
if not array_columns:
198-
return dataframe, {}
203+
return dataframe, [], set()
199204

200205
original_cols = dataframe.columns.tolist()
201206
work_df = dataframe
@@ -243,7 +248,7 @@ def _explode_array_columns(
243248

244249
if not exploded_dfs:
245250
# This should not be reached if array_columns is not empty
246-
return dataframe, {}
251+
return dataframe, [], set()
247252

248253
# Merge the exploded columns
249254
merged_df = exploded_dfs[0]
@@ -260,22 +265,20 @@ def _explode_array_columns(
260265
drop=True
261266
)
262267

263-
# Create row groups
264-
array_row_groups = {}
268+
# Generate row labels and continuation mask efficiently
265269
grouping_col_name = (
266270
"_original_index" if original_index_name is None else original_index_name
267271
)
268-
if grouping_col_name in merged_df.columns:
269-
for orig_idx, group in merged_df.groupby(grouping_col_name):
270-
array_row_groups[str(orig_idx)] = group.index.tolist()
272+
row_labels = merged_df[grouping_col_name].astype(str).tolist()
273+
continuation_rows = set(merged_df.index[merged_df["_row_num"] > 0])
271274

272275
# Restore original columns
273276
result_df = merged_df[original_cols]
274277

275278
if original_index_name:
276279
result_df = result_df.set_index(original_index_name)
277280

278-
return result_df, array_row_groups
281+
return result_df, row_labels, continuation_rows
279282

280283

281284
def _flatten_struct_columns(

bigframes/display/html.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def render_html(
6060
table_html_parts.append(
6161
_render_table_body(
6262
flatten_result.dataframe,
63-
flatten_result.row_groups,
63+
flatten_result.row_labels,
64+
flatten_result.continuation_rows,
6465
flatten_result.cleared_on_continuation,
6566
flatten_result.nested_columns,
6667
)
@@ -87,7 +88,8 @@ def _render_table_header(dataframe: pd.DataFrame, orderable_columns: list[str])
8788

8889
def _render_table_body(
8990
dataframe: pd.DataFrame,
90-
array_row_groups: dict[str, list[int]],
91+
row_labels: list[str] | None,
92+
continuation_rows: set[int] | None,
9193
clear_on_continuation: list[str],
9294
nested_originated_columns: set[str],
9395
) -> str:
@@ -99,14 +101,15 @@ def _render_table_body(
99101
row_class = ""
100102
orig_row_idx = None
101103
is_continuation = False
102-
for orig_key, row_indices in array_row_groups.items():
103-
if i in row_indices and row_indices[0] != i:
104-
row_class = "array-continuation"
105-
orig_row_idx = orig_key
106-
is_continuation = True
107-
break
108-
109-
if row_class:
104+
105+
if row_labels:
106+
orig_row_idx = row_labels[i]
107+
108+
if continuation_rows and i in continuation_rows:
109+
is_continuation = True
110+
row_class = "array-continuation"
111+
112+
if orig_row_idx is not None:
110113
body_parts.append(
111114
f' <tr class="{row_class}" data-orig-row="{orig_row_idx}">'
112115
)

tests/js/table_widget.test.js

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import { jest } from '@jest/globals';
2222

23-
/**
23+
/*
2424
* Test suite for the TableWidget frontend component.
2525
*/
2626
describe('TableWidget', () => {
@@ -31,7 +31,7 @@ describe('TableWidget', () => {
3131
/** @type {Function} */
3232
let render;
3333

34-
/**
34+
/*
3535
* Sets up the test environment before each test.
3636
* This includes resetting modules, creating a DOM element,
3737
* and mocking the widget model.
@@ -58,7 +58,7 @@ describe('TableWidget', () => {
5858
expect(render).toBeDefined();
5959
});
6060

61-
/**
61+
/*
6262
* Tests for the render function of the widget.
6363
*/
6464
describe('render', () => {
@@ -91,7 +91,7 @@ describe('TableWidget', () => {
9191
expect(el.querySelector('div:nth-child(3)')).not.toBeNull();
9292
});
9393

94-
/**
94+
/*
9595
* Verifies that clicking a sortable column header triggers a sort action
9696
* with the correct parameters.
9797
*/
@@ -220,7 +220,7 @@ describe('TableWidget', () => {
220220
expect(indicator2.textContent).toBe('●');
221221
});
222222

223-
/**
223+
/*
224224
* Tests that holding the Shift key while clicking a column header
225225
* adds the new column to the existing sort context for multi-column sorting.
226226
*/
@@ -362,7 +362,7 @@ describe('TableWidget', () => {
362362
expect(headers[1].textContent).toBe('value');
363363
});
364364

365-
/**
365+
/*
366366
* Verifies that hovering over a cell in a group of flattened rows
367367
* (i.e., rows originating from the same nested data structure)
368368
* adds a hover class to all cells in that group.

tests/unit/display/test_html.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,11 @@ def test_flatten_nested_data_explodes_arrays():
190190

191191
result = flatten_nested_data(array_data)
192192
flattened = result.dataframe
193-
groups = result.row_groups
193+
row_labels = result.row_labels
194+
continuation_rows = result.continuation_rows
194195
nested_originated_columns = result.nested_columns
195196

196197
assert len(flattened) == 5 # 3 + 2 array elements
197-
assert "0" in groups # First original row
198-
assert len(groups["0"]) == 3 # Three array elements
199-
assert "1" in groups
200-
assert len(groups["1"]) == 2
198+
assert row_labels == ["0", "0", "0", "1", "1"]
199+
assert continuation_rows == {1, 2, 4}
201200
assert "array_col" in nested_originated_columns

0 commit comments

Comments
 (0)