 
 import unittest
 
-from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.functions import spark_partition_id, col, lit, when
 from pyspark.sql.types import (
     StringType,
     IntegerType,
     DoubleType,
     StructType,
     StructField,
 )
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -84,6 +84,125 @@ def test_repartition_by_range(self):
             messageParameters={"arg_name": "numPartitions", "arg_type": "list"},
         )
 
+    def test_repartition_by_id(self):
+        # Test basic partition ID passthrough behavior
+        numPartitions = 10
+        df = self.spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
+        repartitioned = df.repartitionById(numPartitions, col("expected_p_id").cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        # All rows should be in their expected partitions
+        self.assertEqual(result.filter(col("expected_p_id") != col("actual_p_id")).count(), 0)
+
+    def test_repartition_by_id_negative_values(self):
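+        # Negative ids are expected to wrap to non-negative partitions,
+        # matching the Python % (pmod) semantics used for the expectation below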
+        df = self.spark.range(10).toDF("id")
+        repartitioned = df.repartitionById(10, (col("id") - 5).cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+        for row in result:
+            actualPartitionId = row["actual_p_id"]
+            id_val = row["id"]
+            expectedPartitionId = int((id_val - 5) % 10)
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition {expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_null_values(self):
+        # Test that null partition ids go to partition 0
+        df = self.spark.range(10).toDF("id")
+        partitionExpr = when(col("id") < 5, col("id")).otherwise(lit(None)).cast("int")
+        repartitioned = df.repartitionById(10, partitionExpr)
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id()).collect()
+
+        nullRows = [row for row in result if row["id"] >= 5]
+        self.assertTrue(len(nullRows) > 0, "Should have rows with null partition expression")
+        for row in nullRows:
+            self.assertEqual(
+                row["actual_p_id"],
+                0,
+                "Row with null partition id should go to partition 0, "
+                f"but went to partition {row['actual_p_id']}",
+            )
+
+        nonNullRows = [row for row in result if row["id"] < 5]
+        for row in nonNullRows:
+            id_val = row["id"]
+            actualPartitionId = row["actual_p_id"]
+            expectedPartitionId = id_val % 10
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition {expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_error_non_int_type(self):
+        # Test error for non-integer partition column type
+        df = self.spark.range(5).withColumn("s", lit("a"))
+        with self.assertRaises(Exception):  # Should raise analysis exception
+            df.repartitionById(5, col("s")).collect()
+
+    def test_repartition_by_id_error_invalid_num_partitions(self):
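+        # numPartitions must be a positive int: a wrong type and non-positive values fail fast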
+        df = self.spark.range(5)
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.repartitionById("5", col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="NOT_INT",
+            messageParameters={"arg_name": "numPartitions", "arg_type": "str"},
+        )
+
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(0, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "0"},
+        )
+
+        # Test negative numPartitions
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(-1, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "-1"},
+        )
+
+    def test_repartition_by_id_out_of_range(self):
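+        # ids 10-19 fall outside [0, numPartitions); every row must still survive the shuffle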
+        numPartitions = 10
+        df = self.spark.range(20).toDF("id")
+        repartitioned = df.repartitionById(numPartitions, col("id").cast("int"))
+        result = repartitioned.collect()
+
+        self.assertEqual(len(result), 20)
+        # Connect mode doesn't support RDD operations; only check the partition count when available
+        try:
+            rdd = repartitioned.rdd
+        except Exception:
+            rdd = None
+        if rdd is not None:
+            self.assertEqual(rdd.getNumPartitions(), numPartitions)
+
+    def test_repartition_by_id_string_column_name(self):
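+        # Passing the partition column by name should behave the same as passing a Column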
+        numPartitions = 5
+        df = self.spark.range(25).withColumn(
+            "partition_id", (col("id") % numPartitions).cast("int")
+        )
+        repartitioned = df.repartitionById(numPartitions, "partition_id")
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        mismatches = result.filter(col("partition_id") != col("actual_p_id")).count()
+        self.assertEqual(mismatches, 0)
+
 
 class DataFrameRepartitionTests(
     DataFrameRepartitionTestsMixin,