feat: support nested STRUCT and ARRAY data display in anywidget mode #2359
The PR adds a new 313-line module of utilities for flattening nested STRUCT and ARRAY columns for display:

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for flattening nested data structures for display."""

from __future__ import annotations

import dataclasses
from typing import cast

import pandas as pd
import pyarrow as pa


@dataclasses.dataclass(frozen=True)
class FlattenResult:
    """The result of flattening a DataFrame."""

    dataframe: pd.DataFrame
    """The flattened DataFrame."""

    row_groups: dict[str, list[int]]
    """
    A mapping from original row index to the new row indices that were created
    from it.
    """

    cleared_on_continuation: list[str]
    """A list of column names that should be cleared on continuation rows."""

    nested_columns: set[str]
    """A set of column names that were created from nested data."""


def flatten_nested_data(
    dataframe: pd.DataFrame,
) -> FlattenResult:
    """Flatten nested STRUCT and ARRAY columns for display."""
    if dataframe.empty:
        return FlattenResult(
            dataframe=dataframe.copy(),
            row_groups={},
            cleared_on_continuation=[],
            nested_columns=set(),
        )

    result_df = dataframe.copy()

    (
        struct_columns,
        array_columns,
        array_of_struct_columns,
        clear_on_continuation_cols,
        nested_originated_columns,
    ) = _classify_columns(result_df)

    result_df, array_columns = _flatten_array_of_struct_columns(
        result_df, array_of_struct_columns, array_columns, nested_originated_columns
    )

    result_df, clear_on_continuation_cols = _flatten_struct_columns(
        result_df, struct_columns, clear_on_continuation_cols, nested_originated_columns
    )

    # Now handle ARRAY columns (including the newly created ones from ARRAY of STRUCT)
    if not array_columns:
        return FlattenResult(
            dataframe=result_df,
            row_groups={},
            cleared_on_continuation=clear_on_continuation_cols,
            nested_columns=nested_originated_columns,
        )

    result_df, array_row_groups = _explode_array_columns(result_df, array_columns)
    return FlattenResult(
        dataframe=result_df,
        row_groups=array_row_groups,
        cleared_on_continuation=clear_on_continuation_cols,
        nested_columns=nested_originated_columns,
    )


def _classify_columns(
    dataframe: pd.DataFrame,
) -> tuple[list[str], list[str], list[str], list[str], set[str]]:
    """Identify all STRUCT and ARRAY columns."""
    initial_columns = list(dataframe.columns)
    struct_columns: list[str] = []
    array_columns: list[str] = []
    array_of_struct_columns: list[str] = []
    clear_on_continuation_cols: list[str] = []
    nested_originated_columns: set[str] = set()

    for col_name_raw, col_data in dataframe.items():
        col_name = str(col_name_raw)
        dtype = col_data.dtype
        if isinstance(dtype, pd.ArrowDtype):
            pa_type = dtype.pyarrow_dtype
            if pa.types.is_struct(pa_type):
                struct_columns.append(col_name)
                nested_originated_columns.add(col_name)
            elif pa.types.is_list(pa_type):
                array_columns.append(col_name)
                nested_originated_columns.add(col_name)
                if hasattr(pa_type, "value_type") and (
                    pa.types.is_struct(pa_type.value_type)
                ):
                    array_of_struct_columns.append(col_name)
            else:
                clear_on_continuation_cols.append(col_name)
        elif col_name in initial_columns:
            clear_on_continuation_cols.append(col_name)
    return (
        struct_columns,
        array_columns,
        array_of_struct_columns,
        clear_on_continuation_cols,
        nested_originated_columns,
    )


def _flatten_array_of_struct_columns(
    dataframe: pd.DataFrame,
    array_of_struct_columns: list[str],
    array_columns: list[str],
    nested_originated_columns: set[str],
) -> tuple[pd.DataFrame, list[str]]:
    """Flatten ARRAY of STRUCT columns into separate array columns for each field."""
    result_df = dataframe.copy()
    for col_name in array_of_struct_columns:
        col_data = result_df[col_name]
        pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype
        struct_type = pa_type.value_type

        # Use PyArrow to reshape the list<struct> into multiple list<field> arrays
        arrow_array = pa.array(col_data)
        offsets = arrow_array.offsets
        values = arrow_array.values  # StructArray
        flattened_fields = values.flatten()  # List[Array]

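        # E.g. (hypothetical data): [[{"a": 1}, {"a": 2}], [{"a": 3}]] has
        # offsets [0, 2, 3] and values [{"a": 1}, {"a": 2}, {"a": 3}];
        # flatten() yields one child array per struct field ([1, 2, 3] for "a").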
        new_cols_to_add = {}
        new_array_col_names = []

        # Create new columns for each struct field
        for field_idx in range(struct_type.num_fields):
            field = struct_type.field(field_idx)
            new_col_name = f"{col_name}.{field.name}"
            nested_originated_columns.add(new_col_name)
            new_array_col_names.append(new_col_name)

            # Reconstruct ListArray for this field
            # Use mask=arrow_array.is_null() to preserve nulls from the original list
            new_list_array = pa.ListArray.from_arrays(
                offsets, flattened_fields[field_idx], mask=arrow_array.is_null()
            )

            new_cols_to_add[new_col_name] = pd.Series(
                new_list_array.to_pylist(),
                dtype=pd.ArrowDtype(pa.list_(field.type)),
                index=result_df.index,
            )

        col_idx = result_df.columns.to_list().index(col_name)
        new_cols_df = pd.DataFrame(new_cols_to_add, index=result_df.index)

        result_df = pd.concat(
            [
                result_df.iloc[:, :col_idx],
                new_cols_df,
                result_df.iloc[:, col_idx + 1 :],
            ],
            axis=1,
        )

        # Update array_columns list
        array_columns.remove(col_name)
        # Add the new array columns
        array_columns.extend(new_array_col_names)
    return result_df, array_columns


def _explode_array_columns(
    dataframe: pd.DataFrame, array_columns: list[str]
) -> tuple[pd.DataFrame, dict[str, list[int]]]:
    """Explode array columns into new rows."""
    exploded_rows = []
    array_row_groups: dict[str, list[int]] = {}
    non_array_columns = dataframe.columns.drop(array_columns).tolist()
    non_array_df = dataframe[non_array_columns]

    for orig_idx in dataframe.index:
        non_array_data = non_array_df.loc[orig_idx].to_dict()
        array_values = {}
        max_len_in_row = 0
        non_na_array_found = False

        for col_name in array_columns:
            val = dataframe.loc[orig_idx, col_name]
            if val is not None and not (
                isinstance(val, list) and len(val) == 1 and pd.isna(val[0])
            ):
                array_values[col_name] = list(val)
                max_len_in_row = max(max_len_in_row, len(val))
                non_na_array_found = True
            else:
                array_values[col_name] = []

        if not non_na_array_found:
            new_row = non_array_data.copy()
            for col_name in array_columns:
                new_row[col_name] = pd.NA
            exploded_rows.append(new_row)
            orig_key = str(orig_idx)
            if orig_key not in array_row_groups:
                array_row_groups[orig_key] = []
            array_row_groups[orig_key].append(len(exploded_rows) - 1)
            continue

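        # E.g. (hypothetical data): array values {"a": [1, 2], "b": ["x"]} in
        # one original row produce two output rows, (a=1, b="x") and
        # (a=2, b=<NA>).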
        # Create one row per array element, up to max_len_in_row
        for array_idx in range(max_len_in_row):
            new_row = non_array_data.copy()

            # Add the specific array element for this index
            for col_name in array_columns:
                if array_idx < len(array_values.get(col_name, [])):
                    new_row[col_name] = array_values[col_name][array_idx]
                else:
                    new_row[col_name] = pd.NA

            exploded_rows.append(new_row)

            # Track which rows belong to which original row
            orig_key = str(orig_idx)
            if orig_key not in array_row_groups:
                array_row_groups[orig_key] = []
            array_row_groups[orig_key].append(len(exploded_rows) - 1)

    if exploded_rows:
        # Reconstruct the DataFrame to maintain original column order
        exploded_df = pd.DataFrame(exploded_rows)[dataframe.columns]
        for col in exploded_df.columns:
            # After explosion, object columns that are all-numeric (except for NAs)
            # should be converted to a numeric dtype for proper alignment.
            if exploded_df[col].dtype == "object":
                try:
                    # Use nullable integer type to preserve integers
                    exploded_df[col] = exploded_df[col].astype(pd.Int64Dtype())
                except (ValueError, TypeError):
                    # Fallback for non-integer numerics
                    try:
                        exploded_df[col] = pd.to_numeric(exploded_df[col])
                    except (ValueError, TypeError):
                        # Keep as object if not numeric
                        pass
        return exploded_df, array_row_groups
    else:
        return dataframe, array_row_groups


def _flatten_struct_columns(
    dataframe: pd.DataFrame,
    struct_columns: list[str],
    clear_on_continuation_cols: list[str],
    nested_originated_columns: set[str],
) -> tuple[pd.DataFrame, list[str]]:
    """Flatten regular STRUCT columns."""
    result_df = dataframe.copy()
    for col_name in struct_columns:
        col_data = result_df[col_name]
        if isinstance(col_data.dtype, pd.ArrowDtype):
            pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype

            # Use PyArrow to flatten the struct column without row iteration.
            # pa.array() materializes the column as a single contiguous Arrow
            # array even if the underlying data was chunked.
            arrow_array = pa.array(col_data)
            flattened_fields = arrow_array.flatten()

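            # E.g. (hypothetical): a struct<name: string, age: int64> column
            # "person" flattens into child arrays for "name" and "age", which
            # become the new columns "person.name" and "person.age".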
            new_cols_to_add = {}
            for field_idx in range(pa_type.num_fields):
                field = pa_type.field(field_idx)
                new_col_name = f"{col_name}.{field.name}"
                nested_originated_columns.add(new_col_name)
                clear_on_continuation_cols.append(new_col_name)

                # Create a new Series from the flattened array
                new_cols_to_add[new_col_name] = pd.Series(
                    flattened_fields[field_idx].to_pylist(),
                    dtype=pd.ArrowDtype(field.type),
                    index=result_df.index,
                )

            col_idx = result_df.columns.to_list().index(col_name)
            new_cols_df = pd.DataFrame(new_cols_to_add, index=result_df.index)
            result_df = pd.concat(
                [
                    result_df.iloc[:, :col_idx],
                    new_cols_df,
                    result_df.iloc[:, col_idx + 1 :],
                ],
                axis=1,
            )
    return result_df, clear_on_continuation_cols
```
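A minimal usage sketch (the column names and data here are hypothetical, not taken from the PR's tests):

```python
import pandas as pd
import pyarrow as pa

# Hypothetical input: one STRUCT column and one ARRAY column, both backed by
# Arrow dtypes as _classify_columns expects.
df = pd.DataFrame(
    {
        "person": pd.Series(
            [{"name": "Ada", "age": 36}, {"name": "Lin", "age": 41}],
            dtype=pd.ArrowDtype(
                pa.struct([("name", pa.string()), ("age", pa.int64())])
            ),
        ),
        "tags": pd.Series(
            [["x", "y"], ["z"]],
            dtype=pd.ArrowDtype(pa.list_(pa.string())),
        ),
    }
)

result = flatten_nested_data(df)
# result.dataframe has columns "person.name", "person.age", "tags", with one
# row per "tags" element; result.row_groups maps each original row index (as
# a string) to its exploded row positions, e.g. {"0": [0, 1], "1": [2]}.
```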
Review comment: Consider using some of the PyArrow compute functions to simplify these functions. For example, https://arrow.apache.org/docs/python/generated/pyarrow.compute.list_flatten.html could potentially save us from manually reconstructing arrays.
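For illustration, a sketch of the direction the reviewer is pointing at (hypothetical data, not the code the PR ended up with):

```python
import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical list<struct<a: int64, b: string>> column.
arr = pa.array(
    [[{"a": 1, "b": "x"}, {"a": 2, "b": "y"}], [{"a": 3, "b": "z"}]],
    type=pa.list_(pa.struct([("a", pa.int64()), ("b", pa.string())])),
)

# list_flatten removes one level of list nesting in a single call, rather
# than slicing offsets/values and rebuilding ListArrays by hand.
flat_structs = pc.list_flatten(arr)            # StructArray of length 3
a_values = pc.struct_field(flat_structs, "a")  # Int64Array: [1, 2, 3]
```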
Author reply: Thanks for the suggestion. I've refactored the array explosion logic to use a much faster vectorized approach with pandas.explode and merge, which removes the Python loops entirely.
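A minimal sketch of that vectorized shape (hypothetical frame and helper name; the PR's actual refactor lives in the later commits):

```python
import pandas as pd

def explode_aligned(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Explode several ragged list columns and re-attach the scalar columns."""
    pieces = []
    for col in cols:
        s = df[col].explode()
        # Position of each element within its original row, so columns of
        # different lengths align on (row, position).
        pos = s.groupby(level=0).cumcount()
        pieces.append(s.to_frame().set_index(pos, append=True))
    exploded = pd.concat(pieces, axis=1)  # outer join on (row, position)
    exploded.index.names = ["row", "pos"]
    # Join the non-array columns back, one copy per exploded row.
    return exploded.reset_index("pos").join(df.drop(columns=cols)).reset_index()

df = pd.DataFrame({"city": ["NY", "SF"], "a": [[1, 2], [3]], "b": [["x"], ["y", "z"]]})
explode_aligned(df, ["a", "b"])
# row 0 -> (a=1, b="x"), (a=2, b=NaN); row 1 -> (a=3, b="y"), (a=NaN, b="z")
```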