From 350f76503825aeaa9d178ceace41638c95eeb56e Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Thu, 4 Dec 2025 14:59:30 -0800
Subject: [PATCH 1/4] refactor: consolidate ArrowStreamAggArrowIterUDFSerializer into ArrowStreamAggArrowUDFSerializer

---
 python/pyspark/worker.py | 50 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 45 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 50e71fb6da9d..3b94df71a63f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -66,7 +66,6 @@
     ArrowStreamArrowUDFSerializer,
     ArrowStreamAggPandasUDFSerializer,
     ArrowStreamAggArrowUDFSerializer,
-    ArrowStreamAggArrowIterUDFSerializer,
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
     ArrowStreamArrowUDTFSerializer,
@@ -2737,12 +2736,9 @@ def read_udfs(pickleSer, infile, eval_type):
         or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
     ):
         ser = GroupArrowUDFSerializer(runner_conf.assign_cols_by_name)
-    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF:
-        ser = ArrowStreamAggArrowIterUDFSerializer(
-            runner_conf.timezone, True, runner_conf.assign_cols_by_name, True
-        )
     elif eval_type in (
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
         PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
     ):
         ser = ArrowStreamAggArrowUDFSerializer(
@@ -3265,6 +3261,50 @@ def mapper(a):
             batch_iter = (tuple(batch_columns[o] for o in arg_offsets) for batch_columns in a)
             return f(batch_iter)

+    elif eval_type in (
+        PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+        PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+    ):
+        import pyarrow as pa
+
+        udfs = []
+        for i in range(num_udfs):
+            udfs.append(
+                read_single_udf(
+                    pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler
+                )
+            )
+
+        # For SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF,
+        # convert iterator of batch columns to a concatenated RecordBatch
+        def mapper(a):
+            # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch
+            batches = []
+            for batch_columns in a:
+                # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch
+                batch = pa.RecordBatch.from_arrays(batch_columns)
+                batches.append(batch)
+
+            # Concatenate all batches into one
+            if hasattr(pa, "concat_batches"):
+                concatenated_batch = pa.concat_batches(batches)
+            else:
+                # pyarrow.concat_batches not supported in old versions
+                concatenated_batch = pa.RecordBatch.from_struct_array(
+                    pa.concat_arrays([b.to_struct_array() for b in batches])
+                )
+
+            # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array)
+            result = tuple(
+                f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs
+            )
+            # In the special case of a single UDF this will return a single result rather
+            # than a tuple of results; this is the format that the JVM side expects.
+            if len(result) == 1:
+                return result[0]
+            else:
+                return result
+
     else:

         def mapper(a):

From 294e98f7798b591832f57dfc3b9c396bf38c523f Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Thu, 4 Dec 2025 15:02:42 -0800
Subject: [PATCH 2/4] refactor: remove ArrowStreamAggArrowIterUDFSerializer

---
 python/pyspark/sql/pandas/serializers.py | 59 ++----------------------
 1 file changed, 4 insertions(+), 55 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 667af40c36bc..28ff04d9d188 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1134,7 +1134,7 @@ def __repr__(self):
         return "GroupArrowUDFSerializer"


-# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF, and SQL_GROUPED_AGG_ARROW_ITER_UDF
 class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
     def __init__(
         self,
@@ -1156,59 +1156,8 @@ def __init__(

     def load_stream(self, stream):
         """
-        Flatten the struct into Arrow's record batches.
-        """
-        import pyarrow as pa
-
-        dataframes_in_group = None
-
-        while dataframes_in_group is None or dataframes_in_group > 0:
-            dataframes_in_group = read_int(stream)
-
-            if dataframes_in_group == 1:
-                batches = ArrowStreamSerializer.load_stream(self, stream)
-                if hasattr(pa, "concat_batches"):
-                    yield pa.concat_batches(batches)
-                else:
-                    # pyarrow.concat_batches not supported in old versions
-                    yield pa.RecordBatch.from_struct_array(
-                        pa.concat_arrays([b.to_struct_array() for b in batches])
-                    )
-
-            elif dataframes_in_group != 0:
-                raise PySparkValueError(
-                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
-                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
-                )
-
-    def __repr__(self):
-        return "ArrowStreamAggArrowUDFSerializer"
-
-
-# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
-class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
-    def __init__(
-        self,
-        timezone,
-        safecheck,
-        assign_cols_by_name,
-        arrow_cast,
-    ):
-        super().__init__(
-            timezone=timezone,
-            safecheck=safecheck,
-            assign_cols_by_name=False,
-            arrow_cast=True,
-        )
-        self._timezone = timezone
-        self._safecheck = safecheck
-        self._assign_cols_by_name = assign_cols_by_name
-        self._arrow_cast = arrow_cast
-
-    def load_stream(self, stream):
-        """
-        Yield an iterator that produces one list of column arrays per batch.
-        Each group yields Iterator[List[pa.Array]], allowing UDF to process batches one by one
+        Yield an iterator that produces one tuple of column arrays per batch.
+        Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one
         without consuming all batches upfront.
""" dataframes_in_group = None @@ -1234,7 +1183,7 @@ def load_stream(self, stream): ) def __repr__(self): - return "ArrowStreamAggArrowIterUDFSerializer" + return "ArrowStreamAggArrowUDFSerializer" # Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF From 72b730afd787318d262ade811819b74998fea2fd Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:54:49 -0800 Subject: [PATCH 3/4] fix: test and comment --- python/pyspark/sql/pandas/serializers.py | 3 ++- python/pyspark/worker.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 28ff04d9d188..b17daa386d45 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1134,7 +1134,8 @@ def __repr__(self): return "GroupArrowUDFSerializer" -# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF, and SQL_GROUPED_AGG_ARROW_ITER_UDF +# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF, +# and SQL_GROUPED_AGG_ARROW_ITER_UDF class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer): def __init__( self, diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3b94df71a63f..a7bda89575a2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -3282,7 +3282,9 @@ def mapper(a): batches = [] for batch_columns in a: # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch - batch = pa.RecordBatch.from_arrays(batch_columns) + batch = pa.RecordBatch.from_arrays( + batch_columns, names=["_%d" % i for i in range(len(batch_columns))] + ) batches.append(batch) # Concatenate all batches into one From 7b9526fb9dc2f49c50f3d2623e461b5345b0fb40 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 5 Dec 2025 14:26:30 -0800 Subject: [PATCH 4/4] fix: merge --- python/pyspark/worker.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index a7bda89575a2..7f83f8e2cd6b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -3267,14 +3267,6 @@ def mapper(a): ): import pyarrow as pa - udfs = [] - for i in range(num_udfs): - udfs.append( - read_single_udf( - pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler - ) - ) - # For SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF, # convert iterator of batch columns to a concatenated RecordBatch def mapper(a):