diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 667af40c36bc..b17daa386d45 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1134,7 +1134,8 @@ def __repr__(self):
         return "GroupArrowUDFSerializer"
 
 
-# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF, SQL_WINDOW_AGG_ARROW_UDF,
+# and SQL_GROUPED_AGG_ARROW_ITER_UDF
 class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
     def __init__(
         self,
@@ -1156,59 +1157,8 @@ def __init__(
 
     def load_stream(self, stream):
         """
-        Flatten the struct into Arrow's record batches.
-        """
-        import pyarrow as pa
-
-        dataframes_in_group = None
-
-        while dataframes_in_group is None or dataframes_in_group > 0:
-            dataframes_in_group = read_int(stream)
-
-            if dataframes_in_group == 1:
-                batches = ArrowStreamSerializer.load_stream(self, stream)
-                if hasattr(pa, "concat_batches"):
-                    yield pa.concat_batches(batches)
-                else:
-                    # pyarrow.concat_batches not supported in old versions
-                    yield pa.RecordBatch.from_struct_array(
-                        pa.concat_arrays([b.to_struct_array() for b in batches])
-                    )
-
-            elif dataframes_in_group != 0:
-                raise PySparkValueError(
-                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
-                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
-                )
-
-    def __repr__(self):
-        return "ArrowStreamAggArrowUDFSerializer"
-
-
-# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
-class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
-    def __init__(
-        self,
-        timezone,
-        safecheck,
-        assign_cols_by_name,
-        arrow_cast,
-    ):
-        super().__init__(
-            timezone=timezone,
-            safecheck=safecheck,
-            assign_cols_by_name=False,
-            arrow_cast=True,
-        )
-        self._timezone = timezone
-        self._safecheck = safecheck
-        self._assign_cols_by_name = assign_cols_by_name
-        self._arrow_cast = arrow_cast
-
-    def load_stream(self, stream):
-        """
-        Yield an iterator that produces one list of column arrays per batch.
-        Each group yields Iterator[List[pa.Array]], allowing UDF to process batches one by one
+        Yield an iterator that produces one tuple of column arrays per batch.
+        Each group yields Iterator[Tuple[pa.Array, ...]], allowing UDF to process batches one by one
         without consuming all batches upfront.
         """
         dataframes_in_group = None
@@ -1234,7 +1184,7 @@ def load_stream(self, stream):
                 )
 
     def __repr__(self):
-        return "ArrowStreamAggArrowIterUDFSerializer"
+        return "ArrowStreamAggArrowUDFSerializer"
 
 
 # Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 50e71fb6da9d..7f83f8e2cd6b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -66,7 +66,6 @@
     ArrowStreamArrowUDFSerializer,
     ArrowStreamAggPandasUDFSerializer,
     ArrowStreamAggArrowUDFSerializer,
-    ArrowStreamAggArrowIterUDFSerializer,
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
     ArrowStreamArrowUDTFSerializer,
@@ -2737,12 +2736,9 @@ def read_udfs(pickleSer, infile, eval_type):
             or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
         ):
             ser = GroupArrowUDFSerializer(runner_conf.assign_cols_by_name)
-        elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF:
-            ser = ArrowStreamAggArrowIterUDFSerializer(
-                runner_conf.timezone, True, runner_conf.assign_cols_by_name, True
-            )
         elif eval_type in (
             PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+            PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
             PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
         ):
             ser = ArrowStreamAggArrowUDFSerializer(
@@ -3265,6 +3261,44 @@ def mapper(a):
                 batch_iter = (tuple(batch_columns[o] for o in arg_offsets) for batch_columns in a)
                 return f(batch_iter)
 
+        elif eval_type in (
+            PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+            PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+        ):
+            import pyarrow as pa
+
+            # For SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF,
+            # convert iterator of batch columns to a concatenated RecordBatch
+            def mapper(a):
+                # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch
+                batches = []
+                for batch_columns in a:
+                    # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch
+                    batch = pa.RecordBatch.from_arrays(
+                        batch_columns, names=["_%d" % i for i in range(len(batch_columns))]
+                    )
+                    batches.append(batch)
+
+                # Concatenate all batches into one
+                if hasattr(pa, "concat_batches"):
+                    concatenated_batch = pa.concat_batches(batches)
+                else:
+                    # pyarrow.concat_batches not supported in old versions
+                    concatenated_batch = pa.RecordBatch.from_struct_array(
+                        pa.concat_arrays([b.to_struct_array() for b in batches])
+                    )
+
+                # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array)
+                result = tuple(
+                    f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs
+                )
+                # In the special case of a single UDF this will return a single result rather
+                # than a tuple of results; this is the format that the JVM side expects.
+                if len(result) == 1:
+                    return result[0]
+                else:
+                    return result
+
         else:
 
             def mapper(a):