|
25 | 25 | from __future__ import annotations |
26 | 26 |
|
27 | 27 | import dataclasses |
28 | | -from typing import cast |
29 | 28 |
|
30 | 29 | import numpy as np |
31 | 30 | import pandas as pd |
32 | 31 | import pyarrow as pa |
33 | | -import pyarrow.compute as pc |
| 32 | +import pyarrow.compute as pc # type: ignore |
34 | 33 |
|
35 | 34 |
|
36 | 35 | @dataclasses.dataclass(frozen=True) |
@@ -431,9 +430,8 @@ def _explode_array_columns( |
431 | 430 | ) |
432 | 431 | row_labels = result_df[grouping_col_name].astype(str).tolist() |
433 | 432 |
|
434 | | - # The continuation_mask is a boolean mask where row_num > 0. |
435 | 433 | continuation_mask = pc.greater(row_nums, 0).to_numpy(zero_copy_only=False) |
436 | | - continuation_rows = set(np.flatnonzero(continuation_mask)) |
| 434 | + continuation_rows = set(np.flatnonzero(continuation_mask).tolist()) |
437 | 435 |
|
438 | 436 | result_df = result_df[dataframe.columns.tolist()] |
439 | 437 |
|
@@ -485,33 +483,43 @@ def _flatten_struct_columns( |
485 | 483 | Returns: |
486 | 484 | A FlattenStructsResult containing the updated DataFrame and columns. |
487 | 485 | """ |
488 | | - result_df = dataframe.copy() |
| 486 | + if not struct_columns: |
| 487 | + return FlattenStructsResult( |
| 488 | + dataframe=dataframe.copy(), |
| 489 | + clear_on_continuation_cols=clear_on_continuation_cols, |
| 490 | + nested_originated_columns=nested_originated_columns, |
| 491 | + ) |
| 492 | + |
| 493 | + # Convert to PyArrow table for efficient flattening |
| 494 | + table = pa.Table.from_pandas(dataframe, preserve_index=False) |
| 495 | + |
489 | 496 | current_clear_cols = list(clear_on_continuation_cols) |
490 | 497 | current_nested_cols = set(nested_originated_columns) |
491 | 498 |
|
| 499 | + # Record the "parent.child" column names each struct column will produce so the nested/clear-on-continuation metadata can be updated |
492 | 500 | for col_name in struct_columns: |
493 | | - col_data = result_df[col_name] |
494 | | - if isinstance(col_data.dtype, pd.ArrowDtype): |
495 | | - pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype |
| 501 | + idx = table.schema.get_field_index(col_name) |
| 502 | + if idx == -1: |
| 503 | + continue |
| 504 | + |
| 505 | + field = table.schema.field(idx) |
| 506 | + if pa.types.is_struct(field.type): |
| 507 | + for i in range(field.type.num_fields): |
| 508 | + child_field = field.type.field(i) |
| 509 | + new_col_name = f"{col_name}.{child_field.name}" |
| 510 | + current_nested_cols.add(new_col_name) |
| 511 | + current_clear_cols.append(new_col_name) |
| 512 | + |
| 513 | + # Expand all struct columns into "parent.child" columns. |
| 514 | + flattened_table = table.flatten() |
| 515 | + |
| 516 | + # Convert back to pandas, using ArrowDtype to preserve types and ignoring metadata |
| 517 | + # to avoid issues with stale struct type info. |
| 518 | + result_df = flattened_table.to_pandas( |
| 519 | + types_mapper=pd.ArrowDtype, ignore_metadata=True |
| 520 | + ) |
496 | 521 |
|
497 | | - arrow_array = pa.array(col_data) |
498 | | - flattened_fields = arrow_array.flatten() |
499 | | - |
500 | | - new_cols_to_add = {} |
501 | | - for field_idx in range(pa_type.num_fields): |
502 | | - field = pa_type.field(field_idx) |
503 | | - new_col_name = f"{col_name}.{field.name}" |
504 | | - current_nested_cols.add(new_col_name) |
505 | | - current_clear_cols.append(new_col_name) |
506 | | - |
507 | | - new_cols_to_add[new_col_name] = pd.Series( |
508 | | - flattened_fields[field_idx], |
509 | | - dtype=pd.ArrowDtype(field.type), |
510 | | - index=result_df.index, |
511 | | - ) |
512 | | - |
513 | | - new_cols_df = pd.DataFrame(new_cols_to_add, index=result_df.index) |
514 | | - result_df = _replace_column_in_df(result_df, col_name, new_cols_df) |
| 522 | + result_df.index = dataframe.index |
515 | 523 |
|
516 | 524 | return FlattenStructsResult( |
517 | 525 | dataframe=result_df, |
|
0 commit comments