Commit 9e12201

Kimahriman authored and zhengruifeng committed
[SPARK-49547][SQL][PYTHON] Add iterator of RecordBatch API to applyInArrow
### What changes were proposed in this pull request?

Add the option for `applyInArrow` to take a function that takes an iterator of `RecordBatch` and returns an iterator of `RecordBatch`. A new eval type, `SQL_GROUPED_MAP_ARROW_ITER_UDF`, is added, and the iterator form is detected via type hints on the function.

### Why are the changes needed?

Having a single Table as input and a single Table as output requires collecting all inputs and outputs in memory as a single batch. This can require excessive memory for edge cases with large groups. Inputs and outputs already get serialized as record batches, so this simply exposes that lazy iterator directly instead of forcing materialization into a table.

### Does this PR introduce _any_ user-facing change?

Yes, a new function signature supported by `applyInArrow`. Example:

```python
from typing import Iterator, Tuple

import pyarrow as pa
import pyarrow.compute as pc

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)

def sum_func(
    key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
) -> Iterator[pa.RecordBatch]:
    total = 0
    for batch in batches:
        total += pc.sum(batch.column("v")).as_py()
    yield pa.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})

df.groupby("id").applyInArrow(sum_func, schema="id long, v double").show()
```

```
+---+----+
| id|   v|
+---+----+
|  1| 3.0|
|  2|18.0|
+---+----+
```

### How was this patch tested?

Updated existing UTs to test both the Table signatures and the RecordBatch signatures.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52440 from Kimahriman/apply-in-arrow-iter-eval.

Authored-by: Adam Binford <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
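For contrast, a minimal sketch of the pre-existing `pyarrow.Table` form on the same data (assumes an active SparkSession named `spark`); this is the form that materializes each whole group in memory, which the new iterator API avoids:

```python
from typing import Tuple

import pyarrow as pa
import pyarrow.compute as pc

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)

def sum_table(key: Tuple[pa.Scalar, ...], table: pa.Table) -> pa.Table:
    # The entire group arrives as a single pyarrow.Table, so large groups are fully materialized.
    return pa.Table.from_pydict(
        {"id": [key[0].as_py()], "v": [pc.sum(table.column("v")).as_py()]}
    )

df.groupby("id").applyInArrow(sum_table, schema="id long, v double").show()
```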
1 parent e56ab2f commit 9e12201

File tree

14 files changed: +479 -125 lines changed


core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ private[spark] object PythonEvalType {
   val SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF = 212
   val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
   val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
+  val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215

   // Arrow UDFs
   val SQL_SCALAR_ARROW_UDF = 250
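As a cross-check, the new Scala constant has to agree with its Python-side counterpart so the JVM planner and the Python worker pick the same serialization path. A hedged sketch, assuming a build that contains this patch and the mirrored Python constant:

```python
# The JVM and the Python worker negotiate the eval type by this numeric value.
from pyspark.util import PythonEvalType

assert PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF == 215
```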

python/pyspark/sql/connect/group.py

Lines changed: 12 additions & 2 deletions
@@ -35,6 +35,7 @@
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
 from pyspark.sql.pandas.functions import _validate_vectorized_udf  # type: ignore[attr-defined]
+from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
 from pyspark.sql.types import NumericType, StructType

 import pyspark.sql.connect.plan as plan
@@ -472,13 +473,22 @@ def applyInArrow(
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame

-        _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
+        try:
+            # Try to infer the eval type from type hints
+            eval_type = infer_group_arrow_eval_type_from_func(func)
+        except Exception:
+            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
+        _validate_vectorized_udf(func, eval_type)
         if isinstance(schema, str):
             schema = cast(StructType, self._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
-            evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+            evalType=eval_type,
         )

         res = DataFrame(
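A minimal sketch of the inference-plus-fallback flow used above (the helper and eval-type names come from this patch; the sample functions are illustrative):

```python
from typing import Iterator, Optional

import pyarrow as pa

from pyspark.util import PythonEvalType
from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func


def no_hints(table):
    # No annotations: inference returns None and the caller falls back to the Table eval type.
    return table


def with_hints(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    # Fully annotated with Iterator[RecordBatch]: inference picks the new iterator eval type.
    yield from batches


for f in (no_hints, with_hints):
    eval_type: Optional[int] = infer_group_arrow_eval_type_from_func(f)
    if eval_type is None:
        eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
    print(f.__name__, eval_type)
```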

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 11 additions & 1 deletion
@@ -20,6 +20,7 @@ from typing import (
     Any,
     Callable,
     Iterable,
+    Iterator,
     NewType,
     Tuple,
     Type,
@@ -59,6 +60,7 @@ PandasGroupedMapUDFTransformWithStateType = Literal[211]
 PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212]
 GroupedMapUDFTransformWithStateType = Literal[213]
 GroupedMapUDFTransformWithStateInitStateType = Literal[214]
+ArrowGroupedMapIterUDFType = Literal[215]

 # Arrow UDFs
 ArrowScalarUDFType = Literal[250]
@@ -430,10 +432,18 @@ PandasCogroupedMapFunction = Union[
     Callable[[Any, DataFrameLike, DataFrameLike], DataFrameLike],
 ]

-ArrowGroupedMapFunction = Union[
+ArrowGroupedMapTableFunction = Union[
     Callable[[pyarrow.Table], pyarrow.Table],
     Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
 ]
+ArrowGroupedMapIterFunction = Union[
+    Callable[[Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]],
+    Callable[
+        [Tuple[pyarrow.Scalar, ...], Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]
+    ],
+]
+ArrowGroupedMapFunction = Union[ArrowGroupedMapTableFunction, ArrowGroupedMapIterFunction]
+
 ArrowCogroupedMapFunction = Union[
     Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
     Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table],
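For illustration, the two call shapes covered by the new `ArrowGroupedMapIterFunction` alias look like this (function names are examples, not part of the patch); both are accepted by `applyInArrow`:

```python
from typing import Iterator, Tuple

import pyarrow as pa


def batches_only(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    # Pass-through: one RecordBatch in, one RecordBatch out.
    for batch in batches:
        yield batch


def keyed_batches(
    key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
) -> Iterator[pa.RecordBatch]:
    # The grouping key tuple arrives as the first argument, the batches as the second.
    for batch in batches:
        yield batch
```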

python/pyspark/sql/pandas/functions.py

Lines changed: 2 additions & 0 deletions
@@ -700,6 +700,7 @@ def vectorized_udf(
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
@@ -779,6 +780,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
     ]:
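A hedged sketch of what these two allow-list additions enable: the internal `pandas_udf` construction that `applyInArrow` relies on can now carry the new eval type without being rejected (assumes a PySpark build containing this patch; construction details beyond the allow-lists may differ):

```python
from typing import Iterator

import pyarrow as pa

from pyspark.sql.functions import pandas_udf
from pyspark.util import PythonEvalType


def passthrough(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    # Emits each incoming RecordBatch unchanged.
    yield from batches


# Mirrors the internal call in GroupedData.applyInArrow; before this patch the
# functionType below would have been rejected by the allow-lists shown above.
udf = pandas_udf(
    passthrough,
    returnType="id long, v double",
    functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
)
```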

python/pyspark/sql/pandas/group_ops.py

Lines changed: 72 additions & 17 deletions
@@ -22,6 +22,7 @@
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
 from pyspark.sql.streaming.state import GroupStateTimeout
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor
 from pyspark.sql.types import StructType
@@ -703,27 +704,33 @@ def applyInArrow(
         Maps each group of the current :class:`DataFrame` using an Arrow udf and returns the result
         as a `DataFrame`.

-        The function should take a `pyarrow.Table` and return another
-        `pyarrow.Table`. Alternatively, the user can pass a function that takes
-        a tuple of `pyarrow.Scalar` grouping key(s) and a `pyarrow.Table`.
-        For each group, all columns are passed together as a `pyarrow.Table`
-        to the user-function and the returned `pyarrow.Table` are combined as a
-        :class:`DataFrame`.
+        The function can take one of two forms: It can take a `pyarrow.Table` and return a
+        `pyarrow.Table`, or it can take an iterator of `pyarrow.RecordBatch` and yield
+        `pyarrow.RecordBatch`. Alternatively each form can take a tuple of `pyarrow.Scalar`
+        as the first argument in addition to the input type above. For each group, all columns
+        are passed together in the `pyarrow.Table` or `pyarrow.RecordBatch`, and the returned
+        `pyarrow.Table` or iterator of `pyarrow.RecordBatch` are combined as a :class:`DataFrame`.

         The `schema` should be a :class:`StructType` describing the schema of the returned
-        `pyarrow.Table`. The column labels of the returned `pyarrow.Table` must either match
-        the field names in the defined schema if specified as strings, or match the
-        field data types by position if not strings, e.g. integer indices.
-        The length of the returned `pyarrow.Table` can be arbitrary.
+        `pyarrow.Table` or `pyarrow.RecordBatch`. The column labels of the returned `pyarrow.Table`
+        or `pyarrow.RecordBatch` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings, e.g.
+        integer indices. The length of the returned `pyarrow.Table` or iterator of
+        `pyarrow.RecordBatch` can be arbitrary.

         .. versionadded:: 4.0.0

+        .. versionchanged:: 4.1.0
+            Added support for an iterator of `pyarrow.RecordBatch` API.
+
         Parameters
         ----------
         func : function
-            a Python native function that takes a `pyarrow.Table` and outputs a
-            `pyarrow.Table`, or that takes one tuple (grouping keys) and a
-            `pyarrow.Table` and outputs a `pyarrow.Table`.
+            a Python native function that either takes a `pyarrow.Table` and outputs a
+            `pyarrow.Table` or takes an iterator of `pyarrow.RecordBatch` and yields
+            `pyarrow.RecordBatch`. Additionally, each form can take a tuple of grouping keys
+            as the first argument, with the `pyarrow.Table` or iterator of `pyarrow.RecordBatch`
+            as the second argument.
         schema : :class:`pyspark.sql.types.DataType` or str
             the return type of the `func` in PySpark. The value can be either a
             :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
@@ -752,6 +759,28 @@ def applyInArrow(
         |  2| 1.1094003924504583|
         +---+-------------------+

+        The function can also take and return an iterator of `pyarrow.RecordBatch` using type
+        hints.
+
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def sum_func(
+        ...     batches: Iterator[pyarrow.RecordBatch]
+        ... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+        ...     total = 0
+        ...     for batch in batches:
+        ...         total += pc.sum(batch.column("v")).as_py()
+        ...     yield pyarrow.RecordBatch.from_pydict({"v": [total]})
+        >>> df.groupby("id").applyInArrow(
+        ...     sum_func, schema="v double").show()  # doctest: +SKIP
+        +----+
+        |   v|
+        +----+
+        | 3.0|
+        |18.0|
+        +----+
+
         Alternatively, the user can pass a function that takes two arguments.
         In this case, the grouping key(s) will be passed as the first argument and the data will
         be passed as the second argument. The grouping key(s) will be passed as a tuple of Arrow
@@ -796,11 +825,28 @@ def applyInArrow(
         |  2|          2| 3.0|
         +---+-----------+----+

+        >>> def sum_func(
+        ...     key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
+        ... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+        ...     total = 0
+        ...     for batch in batches:
+        ...         total += pc.sum(batch.column("v")).as_py()
+        ...     yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
+        >>> df.groupby("id").applyInArrow(
+        ...     sum_func, schema="id long, v double").show()  # doctest: +SKIP
+        +---+----+
+        | id|   v|
+        +---+----+
+        |  1| 3.0|
+        |  2|18.0|
+        +---+----+
+
         Notes
         -----
-        This function requires a full shuffle. All the data of a group will be loaded
-        into memory, so the user should be aware of the potential OOM risk if data is skewed
-        and certain groups are too large to fit in memory.
+        This function requires a full shuffle. If using the `pyarrow.Table` API, all data of a
+        group will be loaded into memory, so the user should be aware of the potential OOM risk
+        if data is skewed and certain groups are too large to fit in memory, and can use the
+        iterator of `pyarrow.RecordBatch` API to mitigate this.

         This API is unstable, and for developers.

@@ -813,9 +859,18 @@ def applyInArrow(

         assert isinstance(self, GroupedData)

+        try:
+            # Try to infer the eval type from type hints
+            eval_type = infer_group_arrow_eval_type_from_func(func)
+        except Exception:
+            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
         # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
-            func, returnType=schema, functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+            func, returnType=schema, functionType=eval_type
         )  # type: ignore[call-overload]
         df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
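One consequence of the inference above, shown as a minimal sketch (assumes an active SparkSession named `spark`): a function without type hints keeps the existing `pyarrow.Table` behavior, so nothing changes for current callers:

```python
import pyarrow as pa
import pyarrow.compute as pc

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))


def mean_table(table):
    # No type hints, so the existing Table eval type is assumed: the whole group is one Table.
    return pa.Table.from_pydict({"v": [pc.mean(table.column("v")).as_py()]})


df.groupby("id").applyInArrow(mean_table, schema="v double").show()
```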

python/pyspark/sql/pandas/serializers.py

Lines changed: 11 additions & 8 deletions
@@ -21,7 +21,7 @@

 from decimal import Decimal
 from itertools import groupby
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Iterator, Optional

 import pyspark
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -1116,19 +1116,22 @@ def load_stream(self, stream):
         """
         import pyarrow as pa

+        def process_group(batches: "Iterator[pa.RecordBatch]"):
+            for batch in batches:
+                struct = batch.column(0)
+                yield pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
+
         dataframes_in_group = None

         while dataframes_in_group is None or dataframes_in_group > 0:
             dataframes_in_group = read_int(stream)

             if dataframes_in_group == 1:
-                structs = [
-                    batch.column(0) for batch in ArrowStreamSerializer.load_stream(self, stream)
-                ]
-                yield [
-                    pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
-                    for struct in structs
-                ]
+                batch_iter = process_group(ArrowStreamSerializer.load_stream(self, stream))
+                yield batch_iter
+                # Make sure the batches are fully iterated before getting the next group
+                for _ in batch_iter:
+                    pass

             elif dataframes_in_group != 0:
                 raise PySparkValueError(
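The draining loop added here matters because each group is now exposed lazily: if the UDF stops early, the remaining batches still have to be read so the stream stays aligned for the next group. A standalone sketch of the same pattern, with plain Python lists standing in for the Arrow stream:

```python
from typing import Iterable, Iterator, List


def load_groups(groups: Iterable[List[int]]) -> Iterator[Iterator[int]]:
    for group in groups:
        it = iter(group)
        yield it
        # Drain whatever the consumer left unread, mirroring `for _ in batch_iter: pass` above,
        # so the next group starts from the right position.
        for _ in it:
            pass


for g in load_groups([[1, 2, 3], [4, 5]]):
    print(next(g))  # the consumer reads only the first element; the loader drains the rest
```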

python/pyspark/sql/pandas/typehints.py

Lines changed: 91 additions & 0 deletions
@@ -29,6 +29,9 @@
     ArrowScalarUDFType,
     ArrowScalarIterUDFType,
     ArrowGroupedAggUDFType,
+    ArrowGroupedMapIterUDFType,
+    ArrowGroupedMapUDFType,
+    ArrowGroupedMapFunction,
 )


@@ -303,6 +306,94 @@ def infer_eval_type_for_udf(  # type: ignore[no-untyped-def]
     return None


+def infer_group_arrow_eval_type(
+    sig: Signature,
+    type_hints: Dict[str, Any],
+) -> Optional[Union["ArrowGroupedMapUDFType", "ArrowGroupedMapIterUDFType"]]:
+    from pyspark.sql.pandas.functions import PythonEvalType
+
+    require_minimum_pyarrow_version()
+
+    import pyarrow as pa
+
+    annotations = {}
+    for param in sig.parameters.values():
+        if param.annotation is not param.empty:
+            annotations[param.name] = type_hints.get(param.name, param.annotation)
+
+    # Check if all arguments have type hints
+    parameters_sig = [
+        annotations[parameter] for parameter in sig.parameters if parameter in annotations
+    ]
+    if len(parameters_sig) != len(sig.parameters):
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "all parameters", "sig": str(sig)},
+        )
+
+    # Check if the return has a type hint
+    return_annotation = type_hints.get("return", sig.return_annotation)
+    if sig.empty is return_annotation:
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "the return type", "sig": str(sig)},
+        )
+
+    # Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
+    is_iterator_batch = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda t: t == pa.RecordBatch,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
+        )
+    )
+    # Tuple[pa.Scalar, ...], Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
+    is_iterator_batch_with_keys = (
+        len(parameters_sig) == 2
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[1],
+            parameter_check_func=lambda t: t == pa.RecordBatch,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
+        )
+    )
+
+    if is_iterator_batch or is_iterator_batch_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+
+    # pa.Table -> pa.Table
+    is_table = (
+        len(parameters_sig) == 1 and parameters_sig[0] == pa.Table and return_annotation == pa.Table
+    )
+    # Tuple[pa.Scalar, ...], pa.Table -> pa.Table
+    is_table_with_keys = (
+        len(parameters_sig) == 2 and parameters_sig[1] == pa.Table and return_annotation == pa.Table
+    )
+    if is_table or is_table_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
+    return None
+
+
+def infer_group_arrow_eval_type_from_func(
+    f: "ArrowGroupedMapFunction",
+) -> Optional[Union["ArrowGroupedMapUDFType", "ArrowGroupedMapIterUDFType"]]:
+    argspec = getfullargspec(f)
+    if len(argspec.annotations) > 0:
+        try:
+            type_hints = get_type_hints(f)
+        except NameError:
+            type_hints = {}
+
+        return infer_group_arrow_eval_type(signature(f), type_hints)
+    else:
+        return None
+
+
 def check_tuple_annotation(
     annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
 ) -> bool:
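A hedged sketch of why the callers above wrap this helper in `try`/`except`: a partially annotated function raises instead of returning `None` (the error class comes from the patch; the sample function is illustrative):

```python
from typing import Iterator

import pyarrow as pa

from pyspark.errors import PySparkValueError
from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func


def partial(key, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    # `key` has no annotation, so not all parameters are hinted.
    yield from batches


try:
    infer_group_arrow_eval_type_from_func(partial)
except PySparkValueError as e:
    # TYPE_HINT_SHOULD_BE_SPECIFIED: the callers catch this and fall back to the Table eval type.
    print("inference failed:", e)
```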
