Commit 9d68da5

[SPARK-51875] Support repartition(ByExpression)? and coalesce
### What changes were proposed in this pull request?

This PR aims to support `repartition`, `repartitionByExpression`, and `coalesce`. Note that `repartitionByRange` is not a part of this PR's scope.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #86 from dongjoon-hyun/SPARK-51875.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d709159 commit 9d68da5
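For context, here is a minimal usage sketch of the three APIs this commit adds. The session, `range`, and writer calls mirror the tests in this diff; the partition counts and output path are illustrative assumptions, and a reachable Spark Connect server is assumed.

```swift
import SparkConnect

// Minimal sketch of the new APIs (run inside an async context).
// Partition counts and the output path are illustrative; a reachable
// Spark Connect server is assumed.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(2025)

// Full shuffle into exactly 5 partitions.
let shuffled = await df.repartition(5)

// Hash-partition by `id`, using spark.sql.shuffle.partitions as the count.
let byColumn = await df.repartition("id")

// Hash-partition by `id` into exactly 3 partitions.
let byColumnAndCount = await df.repartitionByExpression(3, "id")

// Narrow (no-shuffle) reduction to at most 2 partitions.
let narrowed = await df.coalesce(2)

try await narrowed.write.mode("overwrite").orc("/tmp/coalesce-demo")
_ = (shuffled, byColumn, byColumnAndCount)
await spark.stop()
```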

File tree

4 files changed: +140, -2 lines

Sources/SparkConnect/DataFrame.swift

Lines changed: 59 additions & 2 deletions
@@ -538,7 +538,7 @@ public actor DataFrame: Sendable {
   /// - right: Right side of the join operation.
   /// - usingColumn: Name of the column to join on. This column must exist on both sides.
   /// - joinType: Type of join to perform. Default `inner`.
-  /// - Returns: <#description#>
+  /// - Returns: A `DataFrame`.
   public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame {
     await join(right, [usingColumn], joinType)
   }
@@ -588,7 +588,7 @@ public actor DataFrame: Sendable {
 
   /// Explicit cartesian join with another `DataFrame`.
   /// - Parameter right: Right side of the join operation.
-  /// - Returns: Cartesian joins are very expensive without an extra filter that can be pushed down.
+  /// - Returns: A `DataFrame`.
   public func crossJoin(_ right: DataFrame) async -> DataFrame {
     let rightPlan = await (right.getPlan() as! Plan).root
     let plan = SparkConnectClient.getJoin(self.plan.root, rightPlan, JoinType.cross)
@@ -676,6 +676,63 @@ public actor DataFrame: Sendable {
     return DataFrame(spark: self.spark, plan: plan)
   }
 
+  private func buildRepartition(numPartitions: Int32, shuffle: Bool) -> DataFrame {
+    let plan = SparkConnectClient.getRepartition(self.plan.root, numPartitions, shuffle)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  private func buildRepartitionByExpression(numPartitions: Int32?, partitionExprs: [String]) -> DataFrame {
+    let plan = SparkConnectClient.getRepartitionByExpression(self.plan.root, partitionExprs, numPartitions)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions.
+  /// - Parameter numPartitions: The number of partitions.
+  /// - Returns: A `DataFrame`.
+  public func repartition(_ numPartitions: Int32) -> DataFrame {
+    return buildRepartition(numPartitions: numPartitions, shuffle: true)
+  }
+
+  /// Returns a new ``DataFrame`` partitioned by the given partitioning expressions, using
+  /// `spark.sql.shuffle.partitions` as the number of partitions. The resulting ``DataFrame`` is
+  /// hash partitioned.
+  /// - Parameter partitionExprs: The partition expression strings.
+  /// - Returns: A `DataFrame`.
+  public func repartition(_ partitionExprs: String...) -> DataFrame {
+    return buildRepartitionByExpression(numPartitions: nil, partitionExprs: partitionExprs)
+  }
+
+  /// Returns a new ``DataFrame`` partitioned into `numPartitions` partitions by the given
+  /// partitioning expressions. The resulting ``DataFrame`` is hash partitioned.
+  /// - Parameters:
+  ///   - numPartitions: The number of partitions.
+  ///   - partitionExprs: The partition expression strings.
+  /// - Returns: A `DataFrame`.
+  public func repartition(_ numPartitions: Int32, _ partitionExprs: String...) -> DataFrame {
+    return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs)
+  }
+
+  /// Returns a new ``DataFrame`` partitioned by the given partitioning expressions, using
+  /// `numPartitions` if it is given and `spark.sql.shuffle.partitions` otherwise. The resulting
+  /// ``DataFrame`` is hash partitioned.
+  /// - Parameters:
+  ///   - numPartitions: The number of partitions, or `nil` to use the default.
+  ///   - partitionExprs: The partition expression strings.
+  /// - Returns: A `DataFrame`.
+  public func repartitionByExpression(_ numPartitions: Int32?, _ partitionExprs: String...) -> DataFrame {
+    return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs)
+  }
+
+  /// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions, when fewer
+  /// partitions are requested. If a larger number of partitions is requested, it will stay at the
+  /// current number of partitions. Similar to `coalesce` defined on an `RDD`, this operation
+  /// results in a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there
+  /// will not be a shuffle; instead, each of the 100 new partitions will claim 10 of the current
+  /// partitions.
+  /// - Parameter numPartitions: The number of partitions.
+  /// - Returns: A `DataFrame`.
+  public func coalesce(_ numPartitions: Int32) -> DataFrame {
+    return buildRepartition(numPartitions: numPartitions, shuffle: false)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
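Note that `repartition` and `coalesce` both funnel into `buildRepartition` and differ only in the `shuffle` flag, which is what makes `coalesce` a narrow, no-shuffle operation. A hedged sketch of the practical difference (the initial partition count is an illustrative assumption):

```swift
// repartition(n) always shuffles to reach exactly n partitions;
// coalesce(n) only ever reduces, without a shuffle.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(2025)  // suppose this starts with 8 partitions (illustrative)

let widened = await df.repartition(100)  // exactly 100 partitions, full shuffle
let unchanged = await df.coalesce(100)   // stays at the current count: no shuffle
let reduced = await df.coalesce(2)       // 2 partitions, narrow dependency

_ = (widened, unchanged, reduced)
await spark.stop()
```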

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 32 additions & 0 deletions
@@ -628,6 +628,38 @@ public actor SparkConnectClient {
     })
   }
 
+  static func getRepartition(_ child: Relation, _ numPartitions: Int32, _ shuffle: Bool = false) -> Plan {
+    var repartition = Repartition()
+    repartition.input = child
+    repartition.numPartitions = numPartitions
+    repartition.shuffle = shuffle
+    var relation = Relation()
+    relation.repartition = repartition
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
+  static func getRepartitionByExpression(
+    _ child: Relation, _ partitionExprs: [String], _ numPartitions: Int32? = nil
+  ) -> Plan {
+    var repartitionByExpression = RepartitionByExpression()
+    repartitionByExpression.input = child
+    repartitionByExpression.partitionExprs = partitionExprs.map {
+      var expr = Spark_Connect_Expression()
+      expr.expressionString = $0.toExpressionString
+      return expr
+    }
+    if let numPartitions {
+      repartitionByExpression.numPartitions = numPartitions
+    }
+    var relation = Relation()
+    relation.repartitionByExpression = repartitionByExpression
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   private enum URIParams {
     static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
     static let PARAM_SESSION_ID = "session_id"
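Both builders follow the client's recurring three-step shape: populate the generated protobuf message, wrap it in a `Relation`, and install that relation as the plan root. A sketch of the pattern in isolation (`child` is a placeholder for the current DataFrame's root relation; the function name is illustrative):

```swift
// Illustrative restatement of the plan-builder pattern above.
// `shuffle` distinguishes repartition (true) from coalesce (false).
static func sketchRepartitionPlan(_ child: Relation, _ n: Int32, _ shuffle: Bool) -> Plan {
  var repartition = Repartition()    // 1. populate the protobuf message
  repartition.input = child
  repartition.numPartitions = n
  repartition.shuffle = shuffle
  var relation = Relation()          // 2. wrap it in a Relation
  relation.repartition = repartition
  var plan = Plan()                  // 3. install the relation as the plan root
  plan.opType = .root(relation)
  return plan
}
```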

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
 typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
+typealias Repartition = Spark_Connect_Repartition
+typealias RepartitionByExpression = Spark_Connect_RepartitionByExpression
 typealias Sample = Spark_Connect_Sample
 typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
 typealias SetOperation = Spark_Connect_SetOperation

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 47 additions & 0 deletions
@@ -17,6 +17,7 @@
 // under the License.
 //
 
+import Foundation
 import Testing
 
 import SparkConnect
@@ -530,6 +531,52 @@ struct DataFrameTests {
     #expect(try await df3.unionByName(df3).count() == 4)
     await spark.stop()
   }
+
+  @Test
+  func repartition() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let df = try await spark.range(2025)
+    for n in [1, 3, 5] as [Int32] {
+      try await df.repartition(n).write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+    }
+    try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    await spark.stop()
+  }
+
+  @Test
+  func repartitionByExpression() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let df = try await spark.range(2025)
+    for n in [1, 3, 5] as [Int32] {
+      try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+      try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+    }
+    try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    await spark.stop()
+  }
+
+  @Test
+  func coalesce() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let df = try await spark.range(2025)
+    for n in [1, 2, 3] as [Int32] {
+      try await df.coalesce(n).write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+    }
+    try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    await spark.stop()
+  }
 #endif
 
   @Test
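One detail worth spelling out: the `< 10` expectations hold because a one-row DataFrame cannot populate ten partitions, and empty partitions produce no output files. That is also why counting files works as a partition probe. A hedged helper extracting that idiom (the name and temp-path scheme are illustrative):

```swift
import Foundation
import SparkConnect

/// Counts partitions by writing `df` as ORC and counting the files read back.
/// Illustrative only: empty partitions produce no files, so this is a lower
/// bound on the logical partition count.
func writtenPartitionCount(_ spark: SparkSession, _ df: DataFrame) async throws -> Int {
  let tmpDir = "/tmp/" + UUID().uuidString
  try await df.write.mode("overwrite").orc(tmpDir)
  return try await spark.read.orc(tmpDir).inputFiles().count
}
```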
