@@ -27,8 +27,10 @@
 import dataclasses
 from typing import cast

+import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc


 @dataclasses.dataclass(frozen=True)
@@ -356,90 +358,114 @@ def _explode_array_columns(
     if not array_columns:
         return ExplodeResult(dataframe, [], set())

-    # Group by all non-array columns to maintain context.
-    # _row_num tracks the index within the exploded array to synchronize multiple
-    # arrays. Continuation rows (index > 0) are tracked for display clearing.
-    original_cols = dataframe.columns.tolist()
-    work_df = dataframe
+    work_df, non_array_columns, original_index_name = _prepare_explosion_dataframe(
+        dataframe, array_columns
+    )

-    non_array_columns = work_df.columns.drop(array_columns).tolist()
-    if not non_array_columns:
-        work_df = work_df.copy()  # Avoid modifying input
-        # Add a temporary column to allow grouping if all columns are arrays.
-        # This ensures we can still group by "original row" even if there are no scalar columns.
-        non_array_columns = ["_temp_grouping_col"]
-        work_df["_temp_grouping_col"] = range(len(work_df))
+    if work_df.empty:
+        return ExplodeResult(dataframe, [], set())

-    # Preserve original index
-    if work_df.index.name:
-        original_index_name = work_df.index.name
-        work_df = work_df.reset_index()
-        non_array_columns.append(original_index_name)
-    else:
-        original_index_name = None
-        work_df = work_df.reset_index(names=["_original_index"])
-        non_array_columns.append("_original_index")
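+    # Do the explode in Arrow: build one table, compute per-row list lengths,
+    # and gather values with vectorized kernels instead of pandas explode/merge.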
+    table = pa.Table.from_pandas(work_df)
+    arrays = [table.column(col).combine_chunks() for col in array_columns]
+    lengths = []
+    for arr in arrays:
+        row_lengths = pc.list_value_length(arr)
+        # Treat null lists as length 1 so they still yield one (null) row,
+        # matching pandas explode behavior for scalars.
+        row_lengths = pc.if_else(
+            pc.is_null(row_lengths, nan_is_null=True), 1, row_lengths
+        )
+        lengths.append(row_lengths)

-    exploded_dfs = []
-    for col in array_columns:
-        # Explode each array column individually
-        col_series = work_df[col]
-        target_dtype = None
-        if isinstance(col_series.dtype, pd.ArrowDtype):
-            pa_type = col_series.dtype.pyarrow_dtype
-            if pa.types.is_list(pa_type):
-                target_dtype = pd.ArrowDtype(pa_type.value_type)
-                # Use to_list() to avoid pandas attempting to create a 2D numpy
-                # array if the list elements have the same length.
-                col_series = pd.Series(
-                    col_series.to_list(), index=col_series.index, dtype=object
-                )
+    if not lengths:
+        return ExplodeResult(dataframe, [], set())

-        exploded = work_df[non_array_columns].assign(**{col: col_series}).explode(col)
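+    # Each source row expands to max(list lengths) output slots so that all
+    # array columns stay aligned; target_offsets marks each row's first slot.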
+    max_lens = lengths[0] if len(lengths) == 1 else pc.max_element_wise(*lengths)
+    max_lens = max_lens.cast(pa.int64())
+    current_offsets = pc.cumulative_sum(max_lens)
+    target_offsets = pa.concat_arrays([pa.array([0], type=pa.int64()), current_offsets])

-        if target_dtype is not None:
-            # Re-cast to arrow dtype if possible
-            exploded[col] = exploded[col].astype(target_dtype)
+    total_rows = target_offsets[-1].as_py()
+    if total_rows == 0:
+        empty_df = pd.DataFrame(columns=dataframe.columns)
+        if original_index_name:
+            empty_df.index.name = original_index_name
+        return ExplodeResult(empty_df, [], set())

-        # Track position in the array for alignment
-        exploded["_row_num"] = exploded.groupby(non_array_columns).cumcount()
-        exploded_dfs.append(exploded)
+    # parent_indices maps each result row to its original row index.
+    dummy_values = pa.nulls(total_rows, type=pa.null())
+    dummy_list_array = pa.ListArray.from_arrays(target_offsets, dummy_values)
+    parent_indices = pc.list_parent_indices(dummy_list_array)

-    if not exploded_dfs:
-        # This should not be reached if array_columns is not empty
-        return ExplodeResult(dataframe, [], set())
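+    # row_nums is each output row's position within its source row's group
+    # (0 for the first exploded row, >0 for continuation rows).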
+    range_k = pa.array(range(total_rows))
+    starts = target_offsets.take(parent_indices)
+    row_nums = pc.subtract(range_k, starts)

-    # Merge the exploded columns
-    merged_df = exploded_dfs[0]
-    for i in range(1, len(exploded_dfs)):
-        merged_df = pd.merge(
-            merged_df,
-            exploded_dfs[i],
-            on=non_array_columns + ["_row_num"],
-            how="outer",
-        )
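+    # Scalar columns: repeat each original value across its exploded rows.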
+    new_columns = {}
+    for col_name in non_array_columns:
+        new_columns[col_name] = table.column(col_name).take(parent_indices)

-    # Restore original column order and sort
-    merged_df = merged_df.sort_values(non_array_columns + ["_row_num"]).reset_index(
-        drop=True
-    )
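+    # Array columns: gather each list's child values into the padded slots;
+    # slots beyond a list's actual length become null.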
+    for col_name, arr in zip(array_columns, arrays):
+        actual_lens_scattered = pc.list_value_length(arr).take(parent_indices)
+        valid_mask = pc.less(row_nums, actual_lens_scattered)
+        starts_scattered = arr.offsets.take(parent_indices)
+
+        # safe_mask guards the gather: invalid slots read index 0 and are
+        # nulled out below, so take() never goes out of bounds.
+        safe_mask = pc.fill_null(valid_mask, False)
+        candidate_indices = pc.add(starts_scattered, row_nums)
+        safe_indices = pc.if_else(safe_mask, candidate_indices, 0)
+
+        if len(arr.values) == 0:
+            final_values = pa.nulls(total_rows, type=arr.type.value_type)
+        else:
+            taken_values = arr.values.take(safe_indices)
+            final_values = pc.if_else(safe_mask, taken_values, None)
+
+        new_columns[col_name] = final_values
+
+    result_df = pa.Table.from_pydict(new_columns).to_pandas()

-    # Generate row labels and continuation mask efficiently
     grouping_col_name = (
         "_original_index" if original_index_name is None else original_index_name
     )
-    row_labels = merged_df[grouping_col_name].astype(str).tolist()
-    continuation_rows = set(merged_df.index[merged_df["_row_num"] > 0])
+    row_labels = result_df[grouping_col_name].astype(str).tolist()

-    # Restore original columns
-    result_df = merged_df[original_cols]
+    # Continuation rows are those at position > 0 within their exploded group.
+    continuation_mask = pc.greater(row_nums, 0).to_numpy(zero_copy_only=False)
+    continuation_rows = set(np.flatnonzero(continuation_mask))
+
+    result_df = result_df[dataframe.columns.tolist()]

     if original_index_name:
         result_df = result_df.set_index(original_index_name)

     return ExplodeResult(result_df, row_labels, continuation_rows)


+def _prepare_explosion_dataframe(
+    dataframe: pd.DataFrame, array_columns: list[str]
+) -> tuple[pd.DataFrame, list[str], str | None]:
+    """Prepare the DataFrame for explosion by ensuring grouping columns exist."""
+    work_df = dataframe
+    non_array_columns = work_df.columns.drop(array_columns).tolist()
+
+    if not non_array_columns:
+        work_df = work_df.copy()  # Avoid modifying the input
+        # Add a temporary column to allow grouping if all columns are arrays.
+        non_array_columns = ["_temp_grouping_col"]
+        work_df["_temp_grouping_col"] = range(len(work_df))
+
+    original_index_name = None
+    if work_df.index.name:
+        original_index_name = work_df.index.name
+        work_df = work_df.reset_index()
+        non_array_columns.append(original_index_name)
+    else:
+        work_df = work_df.reset_index(names=["_original_index"])
+        non_array_columns.append("_original_index")
+
+    return work_df, non_array_columns, original_index_name
+
+
 def _flatten_struct_columns(
     dataframe: pd.DataFrame,
     struct_columns: tuple[str, ...],
@@ -466,8 +492,6 @@ def _flatten_struct_columns(
         if isinstance(col_data.dtype, pd.ArrowDtype):
             pa_type = cast(pd.ArrowDtype, col_data.dtype).pyarrow_dtype

-            # Use PyArrow to flatten the struct column without row iteration
-            # combine_chunks() ensures we have a single array if it was chunked
             arrow_array = pa.array(col_data)
             flattened_fields = arrow_array.flatten()

@@ -478,7 +502,6 @@ def _flatten_struct_columns(
                 current_nested_cols.add(new_col_name)
                 current_clear_cols.append(new_col_name)

-                # Create a new Series from the flattened array
                 new_cols_to_add[new_col_name] = pd.Series(
                     flattened_fields[field_idx],
                     dtype=pd.ArrowDtype(field.type),