Commit 5ad1053

BryanCutler authored and HyukjinKwon committed
[SPARK-28128][PYTHON][SQL] Pandas Grouped UDFs skip empty partitions
## What changes were proposed in this pull request?

When running FlatMapGroupsInPandasExec or AggregateInPandasExec, the shuffle uses the default number of partitions, 200, set by "spark.sql.shuffle.partitions". If the data is small, e.g. in testing, many of the partitions will be empty but are treated just the same. This PR checks that the `mapPartitionsInternal` iterator is non-empty before calling `ArrowPythonRunner` to start computation on the iterator.

## How was this patch tested?

Existing tests. Ran the following benchmark, a simple example where most partitions are empty:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import *

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
    v = pdf.v
    return pdf.assign(v=(v - v.mean()) / v.std())

df.groupby("id").apply(normalize).count()
```

**Before**

```
In [4]: %timeit df.groupby("id").apply(normalize).count()
1.58 s ± 62.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit df.groupby("id").apply(normalize).count()
1.52 s ± 29.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [6]: %timeit df.groupby("id").apply(normalize).count()
1.52 s ± 37.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

**After this change**

```
In [2]: %timeit df.groupby("id").apply(normalize).count()
646 ms ± 89.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [3]: %timeit df.groupby("id").apply(normalize).count()
408 ms ± 84.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [4]: %timeit df.groupby("id").apply(normalize).count()
381 ms ± 29.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

Closes apache#24926 from BryanCutler/pyspark-pandas_udf-map-agg-skip-empty-parts-SPARK-28128.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
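For context, the imbalance this patch targets is easy to observe from a PySpark shell: with the default 200 shuffle partitions, grouping the five-row benchmark DataFrame leaves all but a couple of post-shuffle partitions empty. A minimal sketch, assuming the same `spark` session and `df` as in the benchmark above:

```python
# Default post-shuffle partition count; "200" unless overridden.
print(spark.conf.get("spark.sql.shuffle.partitions"))

# After a groupBy on the 5-row DataFrame, the result is typically spread
# over that many partitions, and nearly all of them hold no rows.
grouped = df.groupby("id").count()
print(grouped.rdd.getNumPartitions())
print(sum(1 for p in grouped.rdd.glom().collect() if not p))  # empty partitions
```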
1 parent: 113f8c8 · commit: 5ad1053

4 files changed: +31 additions, −4 deletions

python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 import unittest
 
 from pyspark.rdd import PythonEvalType
+from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
     udf, pandas_udf, PandasUDFType
 from pyspark.sql.types import *
@@ -461,6 +462,18 @@ def test_register_vectorized_udf_basic(self):
         expected = [1, 5]
         self.assertEqual(actual, expected)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, sum=5), Row(id=2, sum=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda x: x.sum(),
+                       'int', PandasUDFType.GROUPED_AGG)
+
+        result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_agg import *
```

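This test (like the grouped-map test below) forces an empty partition by asking for one more slice than there are rows. A quick way to confirm that guarantee, assuming a live SparkContext `sc`; the snippet is illustrative and not part of the patch:

```python
data = [1, 2, 3]
# One more slice than elements guarantees at least one empty partition.
parts = sc.parallelize(data, numSlices=len(data) + 1).glom().collect()
print(parts)  # e.g. [[], [1], [2], [3]]
assert any(len(p) == 0 for p in parts)
```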
python/pyspark/sql/tests/test_pandas_udf_grouped_map.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -504,6 +504,18 @@ def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
 
         self.assertEquals(result.collect()[0]['sum'], 165)
 
+    def test_grouped_with_empty_partition(self):
+        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
+        expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
+        num_parts = len(data) + 1
+        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))
+
+        f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()),
+                       'id long, x int', PandasUDFType.GROUPED_MAP)
+
+        result = df.groupBy('id').apply(f).collect()
+        self.assertEqual(result, expected)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_grouped_map import *
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala

Lines changed: 3 additions & 2 deletions
```diff
@@ -105,7 +105,8 @@ case class AggregateInPandasExec(
       StructField(s"_$i", dt)
     })
 
-    inputRDD.mapPartitionsInternal { iter =>
+    // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
       val prunedProj = UnsafeProjection.create(allInputs, child.output)
 
       val grouped = if (groupingExpressions.isEmpty) {
@@ -151,6 +152,6 @@ case class AggregateInPandasExec(
         val joinedRow = joined(leftRow, aggOutputRow)
         resultProj(joinedRow)
       }
-    }
+    }}
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala

Lines changed: 3 additions & 2 deletions
```diff
@@ -125,7 +125,8 @@ case class FlatMapGroupsInPandasExec(
     val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
     val dedupSchema = StructType.fromAttributes(dedupAttributes)
 
-    inputRDD.mapPartitionsInternal { iter =>
+    // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
+    inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
       val grouped = if (groupingAttributes.isEmpty) {
         Iterator(iter)
       } else {
@@ -156,6 +157,6 @@ case class FlatMapGroupsInPandasExec(
         flattenedBatch.setNumRows(batch.numRows())
         flattenedBatch.rowIterator.asScala
       }.map(unsafeProj)
-    }
+    }}
   }
 }
```
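The guard in both Scala operators is the same: peek at the partition iterator and return it untouched when empty, so the per-partition Arrow and Python setup cost is paid only where there is data. A minimal PySpark sketch of the same pattern, not the actual `ArrowPythonRunner` wiring; the name `process_partition` and the doubling transform are illustrative:

```python
from itertools import chain

def process_partition(iterator):
    # Peek at the first row; on an empty partition, skip all setup,
    # mirroring the Scala `if (iter.isEmpty) iter else { ... }` guard.
    first = next(iterator, None)
    if first is None:
        return iter([])
    # Expensive per-partition setup (e.g. launching a Python worker) would go here.
    return (row * 2 for row in chain([first], iterator))

# Usage: more slices than rows, so several partitions are empty.
# sc.parallelize([1, 2, 3], numSlices=8).mapPartitions(process_partition).collect()
# -> [2, 4, 6]
```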
