
Commit 36ed5ee

shujingyang-db authored and zhengruifeng committed
[SPARK-53429][PYTHON] Support Direct Passthrough Partitioning in the PySpark Dataframe API
### What changes were proposed in this pull request?
This PR implements the `repartitionById` method for PySpark DataFrames.

### Why are the changes needed?
To support Direct Passthrough Partitioning in the PySpark DataFrame API.

### Does this PR introduce _any_ user-facing change?
Yes.

### How was this patch tested?
New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

Closes #52295 from shujingyang-db/direct-passthrough-pyspark-api.

Lead-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Shujing Yang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 71c67b0 commit 36ed5ee
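
For quick orientation, here is a minimal usage sketch of the new `repartitionById` API described above, assuming an active SparkSession named `spark`; it mirrors the docstring example added in this commit:

```python
from pyspark.sql import functions as sf

# Compute an explicit target partition ID per row, then route each row
# directly to that partition (IDs outside [0, numPartitions) wrap modulo
# numPartitions; NULL IDs go to partition 0).
df = spark.range(10).withColumn("partition_id", (sf.col("id") % 3).cast("int"))
repartitioned = df.repartitionById(3, "partition_id")
repartitioned.select("id", "partition_id", sf.spark_partition_id()).orderBy("id").show()
```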

5 files changed: +255, -2 lines changed

python/pyspark/sql/classic/dataframe.py

Lines changed: 24 additions & 0 deletions
@@ -569,6 +569,30 @@ def repartitionByRange( # type: ignore[misc]
                 },
             )
 
+    def repartitionById(
+        self, numPartitions: int, partitionIdCol: "ColumnOrName"
+    ) -> ParentDataFrame:
+        if not isinstance(numPartitions, (int, bool)):
+            raise PySparkTypeError(
+                errorClass="NOT_INT",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_type": type(numPartitions).__name__,
+                },
+            )
+        if numPartitions <= 0:
+            raise PySparkValueError(
+                errorClass="VALUE_NOT_POSITIVE",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_value": str(numPartitions),
+                },
+            )
+        return DataFrame(
+            self._jdf.repartitionById(numPartitions, _to_java_column(partitionIdCol)),
+            self.sparkSession,
+        )
+
     def distinct(self) -> ParentDataFrame:
         return DataFrame(self._jdf.distinct(), self.sparkSession)

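As the classic implementation above shows, `numPartitions` is validated on the Python side before the call reaches the JVM. A small sketch of the resulting caller-visible errors, assuming an active SparkSession named `spark` (error classes taken from the diff and its tests):

```python
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql import functions as sf

df = spark.range(5)

try:
    df.repartitionById("5", sf.col("id").cast("int"))  # not an int
except PySparkTypeError as e:
    print(e)  # error class NOT_INT

try:
    df.repartitionById(0, sf.col("id").cast("int"))  # not positive
except PySparkValueError as e:
    print(e)  # error class VALUE_NOT_POSITIVE
```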
python/pyspark/sql/connect/dataframe.py

Lines changed: 33 additions & 0 deletions
@@ -82,6 +82,7 @@
 from pyspark.sql.column import Column
 from pyspark.sql.connect.expressions import (
     ColumnReference,
+    DirectShufflePartitionID,
     SubqueryExpression,
     UnresolvedRegex,
     UnresolvedStar,
@@ -443,6 +444,38 @@ def repartitionByRange( # type: ignore[misc]
         res._cached_schema = self._cached_schema
         return res
 
+    def repartitionById(
+        self, numPartitions: int, partitionIdCol: "ColumnOrName"
+    ) -> ParentDataFrame:
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
+            raise PySparkTypeError(
+                errorClass="NOT_INT",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_type": type(numPartitions).__name__,
+                },
+            )
+        if numPartitions <= 0:
+            raise PySparkValueError(
+                errorClass="VALUE_NOT_POSITIVE",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_value": str(numPartitions),
+                },
+            )
+
+        partition_connect_col = cast(ConnectColumn, F._to_col(partitionIdCol))
+        direct_partition_expr = DirectShufflePartitionID(partition_connect_col._expr)
+        direct_partition_col = ConnectColumn(direct_partition_expr)
+        res = DataFrame(
+            plan.RepartitionByExpression(self._plan, numPartitions, [direct_partition_col]),
+            self._session,
+        )
+        res._cached_schema = self._cached_schema
+        return res
+
     def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
         if subset is not None and not isinstance(subset, (list, tuple)):
             raise PySparkTypeError(

python/pyspark/sql/connect/expressions.py

Lines changed: 21 additions & 0 deletions
@@ -1343,3 +1343,24 @@ def __repr__(self) -> str:
             repr_parts.append(f"values={self._in_subquery_values}")
 
         return f"SubqueryExpression({', '.join(repr_parts)})"
+
+
+class DirectShufflePartitionID(Expression):
+    """
+    Expression that takes a partition ID value and passes it through directly for use in
+    shuffle partitioning. This is used with RepartitionByExpression to allow users to
+    directly specify target partition IDs.
+    """
+
+    def __init__(self, child: Expression):
+        super().__init__()
+        assert child is not None and isinstance(child, Expression)
+        self._child = child
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = self._create_proto_expression()
+        expr.direct_shuffle_partition_id.child.CopyFrom(self._child.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"DirectShufflePartitionID(child={self._child})"

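Because the Connect DataFrame method above wraps the partition-id column in `DirectShufflePartitionID` and hands it to the same `RepartitionByExpression` plan node, the user-facing call is identical under Spark Connect. A minimal sketch, assuming a reachable Spark Connect server (the address below is a placeholder):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

# Placeholder address; point this at a real Spark Connect endpoint.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(100).withColumn("pid", (sf.col("id") % 4).cast("int"))
# Same API as the classic path; the client builds DirectShufflePartitionID internally.
df.repartitionById(4, "pid").select("id", "pid", sf.spark_partition_id()).show()
```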
python/pyspark/sql/dataframe.py

Lines changed: 61 additions & 0 deletions
@@ -1887,6 +1887,67 @@ def repartitionByRange(
         """
         ...
 
+    @dispatch_df_method
+    def repartitionById(self, numPartitions: int, partitionIdCol: "ColumnOrName") -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` partitioned by the given partition ID expression.
+        Each row's target partition is determined directly by the value of the partition ID column.
+
+        .. versionadded:: 4.1.0
+
+        .. versionchanged:: 4.1.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        numPartitions : int
+            target number of partitions
+        partitionIdCol : str or :class:`Column`
+            column expression that evaluates to the target partition ID for each row.
+            Must be an integer type. Values are taken modulo numPartitions to determine
+            the final partition. Null values are sent to partition 0.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Repartitioned DataFrame.
+
+        Notes
+        -----
+        The partition ID expression must evaluate to an integer type.
+        Partition IDs are taken modulo numPartitions, so values outside the range [0, numPartitions)
+        are automatically mapped to valid partition IDs. If the partition ID expression evaluates to
+        a NULL value, the row is sent to partition 0.
+
+        This method provides direct control over partition placement, similar to RDD's
+        partitionBy with custom partitioners, but at the DataFrame level.
+
+        Examples
+        --------
+        Partition rows based on a computed partition ID:
+
+        >>> from pyspark.sql import functions as sf
+        >>> from pyspark.sql.functions import col
+        >>> df = spark.range(10).withColumn("partition_id", (col("id") % 3).cast("int"))
+        >>> repartitioned = df.repartitionById(3, "partition_id")
+        >>> repartitioned.select("id", "partition_id", sf.spark_partition_id()).orderBy("id").show()
+        +---+------------+--------------------+
+        | id|partition_id|SPARK_PARTITION_ID()|
+        +---+------------+--------------------+
+        |  0|           0|                   0|
+        |  1|           1|                   1|
+        |  2|           2|                   2|
+        |  3|           0|                   0|
+        |  4|           1|                   1|
+        |  5|           2|                   2|
+        |  6|           0|                   0|
+        |  7|           1|                   1|
+        |  8|           2|                   2|
+        |  9|           0|                   0|
+        +---+------------+--------------------+
+        """
+        ...
+
     @dispatch_df_method
     def distinct(self) -> "DataFrame":
         """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`.

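The Notes section above calls out two edge cases: out-of-range IDs wrap modulo `numPartitions`, and NULL IDs land in partition 0. A short sketch of both, assuming an active SparkSession named `spark`:

```python
from pyspark.sql import functions as sf

df = spark.range(6)

# Out-of-range IDs wrap: with 3 partitions, id 5 lands in partition 5 % 3 == 2.
df.repartitionById(3, sf.col("id").cast("int")) \
    .select("id", sf.spark_partition_id().alias("pid")).show()

# NULL partition IDs are routed to partition 0.
null_ids = sf.when(sf.col("id") < 3, sf.col("id")).otherwise(sf.lit(None)).cast("int")
df.repartitionById(3, null_ids) \
    .select("id", sf.spark_partition_id().alias("pid")).show()
```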
python/pyspark/sql/tests/test_repartition.py

Lines changed: 116 additions & 2 deletions
@@ -17,15 +17,15 @@
 
 import unittest
 
-from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.functions import spark_partition_id, col, lit, when
 from pyspark.sql.types import (
     StringType,
     IntegerType,
     DoubleType,
     StructType,
     StructField,
 )
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -84,6 +84,120 @@ def test_repartition_by_range(self):
             messageParameters={"arg_name": "numPartitions", "arg_type": "list"},
         )
 
+    def test_repartition_by_id(self):
+        # Test basic partition ID passthrough behavior
+        numPartitions = 10
+        df = self.spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
+        repartitioned = df.repartitionById(numPartitions, col("expected_p_id").cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        # All rows should be in their expected partitions
+        self.assertEqual(result.filter(col("expected_p_id") != col("actual_p_id")).count(), 0)
+
+    def test_repartition_by_id_negative_values(self):
+        df = self.spark.range(10).toDF("id")
+        repartitioned = df.repartitionById(10, (col("id") - 5).cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+        for row in result:
+            actualPartitionId = row["actual_p_id"]
+            id_val = row["id"]
+            expectedPartitionId = int((id_val - 5) % 10)
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition {expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_null_values(self):
+        # Test that null partition ids go to partition 0
+        df = self.spark.range(10).toDF("id")
+        partitionExpr = when(col("id") < 5, col("id")).otherwise(lit(None)).cast("int")
+        repartitioned = df.repartitionById(10, partitionExpr)
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+        nullRows = [row for row in result if row["id"] >= 5]
+        self.assertTrue(len(nullRows) > 0, "Should have rows with null partition expression")
+        for row in nullRows:
+            self.assertEqual(
+                row["actual_p_id"],
+                0,
+                f"Row with null partition id should go to partition 0, "
+                f"but went to partition {row['actual_p_id']}",
+            )
+
+        nonNullRows = [row for row in result if row["id"] < 5]
+        for row in nonNullRows:
+            id_val = row["id"]
+            actualPartitionId = row["actual_p_id"]
+            expectedPartitionId = id_val % 10
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition {expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_error_non_int_type(self):
+        # Test error for non-integer partition column type
+        df = self.spark.range(5).withColumn("s", lit("a"))
+        with self.assertRaises(Exception):  # Should raise analysis exception
+            df.repartitionById(5, col("s")).collect()
+
+    def test_repartition_by_id_error_invalid_num_partitions(self):
+        df = self.spark.range(5)
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.repartitionById("5", col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="NOT_INT",
+            messageParameters={"arg_name": "numPartitions", "arg_type": "str"},
+        )
+
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(0, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "0"},
+        )
+
+        # Test negative numPartitions
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(-1, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "-1"},
+        )
+
+    def test_repartition_by_id_out_of_range(self):
+        numPartitions = 10
+        df = self.spark.range(20).toDF("id")
+        repartitioned = df.repartitionById(numPartitions, col("id").cast("int"))
+        result = repartitioned.collect()
+
+        self.assertEqual(len(result), 20)
+        # Skip RDD partition count check for Connect mode since RDD is not available
+        try:
+            self.assertEqual(repartitioned.rdd.getNumPartitions(), numPartitions)
+        except Exception:
+            # Connect mode doesn't support RDD operations, so we skip this check
+            pass
+
+    def test_repartition_by_id_string_column_name(self):
+        numPartitions = 5
+        df = self.spark.range(25).withColumn(
+            "partition_id", (col("id") % numPartitions).cast("int")
+        )
+        repartitioned = df.repartitionById(numPartitions, "partition_id")
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        mismatches = result.filter(col("partition_id") != col("actual_p_id")).count()
+        self.assertEqual(mismatches, 0)
+
 
 class DataFrameRepartitionTests(
     DataFrameRepartitionTestsMixin,

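One way to run just this test class locally, sketched under the assumption that a PySpark development environment is importable (module and class names taken from the diff):

```python
# Minimal runner sketch: importing the test class into __main__ lets unittest discover it.
import unittest

from pyspark.sql.tests.test_repartition import DataFrameRepartitionTests  # noqa: F401

if __name__ == "__main__":
    unittest.main()
```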