61 changes: 59 additions & 2 deletions Sources/SparkConnect/DataFrame.swift
@@ -538,7 +538,7 @@ public actor DataFrame: Sendable {
/// - right: Right side of the join operation.
/// - usingColumn: Name of the column to join on. This column must exist on both sides.
/// - joinType: Type of join to perform. Default `inner`.
/// - Returns: <#description#>
/// - Returns: A `DataFrame`.
public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame {
await join(right, [usingColumn], joinType)
}
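
For orientation, a minimal usage sketch of the `usingColumn` overload documented above (`left`, `right`, and the shared `id` column are illustrative and not part of this diff):

```swift
// Hypothetical DataFrames `left` and `right`, both containing an `id` column.
let joined = await left.join(right, "id")              // inner join on `id`
let explicit = await left.join(right, "id", "inner")   // joinType defaults to "inner"
```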
@@ -588,7 +588,7 @@ public actor DataFrame: Sendable {

/// Explicit cartesian join with another `DataFrame`.
/// - Parameter right: Right side of the join operation.
/// - Returns: Cartesian joins are very expensive without an extra filter that can be pushed down.
/// - Returns: A `DataFrame`.
public func crossJoin(_ right: DataFrame) async -> DataFrame {
let rightPlan = await (right.getPlan() as! Plan).root
let plan = SparkConnectClient.getJoin(self.plan.root, rightPlan, JoinType.cross)
@@ -676,6 +676,63 @@ public actor DataFrame: Sendable {
return DataFrame(spark: self.spark, plan: plan)
}

private func buildRepartition(numPartitions: Int32, shuffle: Bool) -> DataFrame {
let plan = SparkConnectClient.getRepartition(self.plan.root, numPartitions, shuffle)
return DataFrame(spark: self.spark, plan: plan)
}

private func buildRepartitionByExpression(numPartitions: Int32?, partitionExprs: [String]) -> DataFrame {
let plan = SparkConnectClient.getRepartitionByExpression(self.plan.root, partitionExprs, numPartitions)
return DataFrame(spark: self.spark, plan: plan)
}

/// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions.
/// - Parameter numPartitions: The number of partitions.
/// - Returns: A `DataFrame`.
public func repartition(_ numPartitions: Int32) -> DataFrame {
return buildRepartition(numPartitions: numPartitions, shuffle: true)
}

/// Returns a new ``DataFrame`` partitioned by the given partitioning expressions, using
/// `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is hash
/// partitioned.
/// - Parameter partitionExprs: The partition expression strings.
/// - Returns: A `DataFrame`.
public func repartition(_ partitionExprs: String...) -> DataFrame {
return buildRepartitionByExpression(numPartitions: nil, partitionExprs: partitionExprs)
}

/// Returns a new ``DataFrame`` partitioned by the given partitioning expressions into
/// `numPartitions`. The resulting Dataset is hash partitioned.
/// - Parameters:
/// - numPartitions: The number of partitions.
/// - partitionExprs: The partition expression strings.
/// - Returns: A `DataFrame`.
public func repartition(_ numPartitions: Int32, _ partitionExprs: String...) -> DataFrame {
return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs)
}

/// Returns a new ``DataFrame`` partitioned by the given partitioning expressions. If
/// `numPartitions` is `nil`, `spark.sql.shuffle.partitions` is used as the number of
/// partitions. The resulting Dataset is hash partitioned.
/// - Parameters:
/// - numPartitions: The number of partitions, or `nil` to use `spark.sql.shuffle.partitions`.
/// - partitionExprs: The partition expression strings.
/// - Returns: A `DataFrame`.
public func repartitionByExpression(_ numPartitions: Int32?, _ partitionExprs: String...) -> DataFrame {
return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs)
}

/// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions, when fewer partitions
/// are requested. If a larger number of partitions is requested, it will stay at the current
/// number of partitions. Similar to coalesce defined on an `RDD`, this operation results in a
/// narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a
/// shuffle; instead, each of the 100 new partitions will claim 10 of the current partitions.
/// - Parameter numPartitions: The number of partitions.
/// - Returns: A `DataFrame`.
public func coalesce(_ numPartitions: Int32) -> DataFrame {
return buildRepartition(numPartitions: numPartitions, shuffle: false)
}

/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
public var write: DataFrameWriter {
get {
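For reference, a minimal usage sketch of the new partitioning APIs added above, mirroring the tests at the bottom of this diff (the ORC output paths are illustrative):

```swift
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(2025)

// Exactly 5 partitions, with a shuffle.
try await df.repartition(5).write.mode("overwrite").orc("/tmp/repartition-5")

// Hash-partitioned by the `id` column; the partition count comes from
// spark.sql.shuffle.partitions because no numPartitions is given.
try await df.repartition("id").write.mode("overwrite").orc("/tmp/repartition-id")

// At most 3 partitions, without a shuffle (narrow dependency).
try await df.coalesce(3).write.mode("overwrite").orc("/tmp/coalesce-3")

await spark.stop()
```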
32 changes: 32 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -628,6 +628,38 @@ public actor SparkConnectClient {
})
}

static func getRepartition(_ child: Relation, _ numPartitions: Int32, _ shuffle: Bool = false) -> Plan {
var repartition = Repartition()
repartition.input = child
repartition.numPartitions = numPartitions
repartition.shuffle = shuffle
var relation = Relation()
relation.repartition = repartition
var plan = Plan()
plan.opType = .root(relation)
return plan
}

static func getRepartitionByExpression(
_ child: Relation, _ partitionExprs: [String], _ numPartitions: Int32? = nil
) -> Plan {
var repartitionByExpression = RepartitionByExpression()
repartitionByExpression.input = child
repartitionByExpression.partitionExprs = partitionExprs.map {
var expr = Spark_Connect_Expression()
expr.expressionString = $0.toExpressionString
return expr
}
if let numPartitions {
repartitionByExpression.numPartitions = numPartitions
}
var relation = Relation()
relation.repartitionByExpression = repartitionByExpression
var plan = Plan()
plan.opType = .root(relation)
return plan
}

private enum URIParams {
static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
static let PARAM_SESSION_ID = "session_id"
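For reference, a rough sketch of what these two helpers produce (they are internal, so this is only callable from inside the module or a `@testable` test target; the empty child `Relation` is a placeholder for a real plan root):

```swift
// Placeholder child relation; in practice this is the current DataFrame's plan root.
let child = Relation()

// Shuffle the child into exactly 3 partitions.
let plan = SparkConnectClient.getRepartition(child, 3, true)
// plan.opType is .root(relation), where relation.repartition.numPartitions == 3
// and relation.repartition.shuffle == true.

// Hash-partition by the `id` expression; numPartitions is left unset, so the
// server falls back to spark.sql.shuffle.partitions.
let planByExpr = SparkConnectClient.getRepartitionByExpression(child, ["id"])
```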
2 changes: 2 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -42,6 +42,8 @@ typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
typealias Read = Spark_Connect_Read
typealias Relation = Spark_Connect_Relation
typealias Repartition = Spark_Connect_Repartition
typealias RepartitionByExpression = Spark_Connect_RepartitionByExpression
typealias Sample = Spark_Connect_Sample
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
typealias SetOperation = Spark_Connect_SetOperation
47 changes: 47 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -17,6 +17,7 @@
// under the License.
//

import Foundation
import Testing

import SparkConnect
@@ -530,6 +531,52 @@ struct DataFrameTests {
#expect(try await df3.unionByName(df3).count() == 4)
await spark.stop()
}

@Test
func repartition() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tmpDir = "/tmp/" + UUID().uuidString
let df = try await spark.range(2025)
for n in [1, 3, 5] as [Int32] {
try await df.repartition(n).write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
}
try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
await spark.stop()
}

@Test
func repartitionByExpression() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tmpDir = "/tmp/" + UUID().uuidString
let df = try await spark.range(2025)
for n in [1, 3, 5] as [Int32] {
try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
}
try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
await spark.stop()
}

@Test
func coalesce() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let tmpDir = "/tmp/" + UUID().uuidString
let df = try await spark.range(2025)
for n in [1, 2, 3] as [Int32] {
try await df.coalesce(n).write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
}
try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir)
#expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
await spark.stop()
}
#endif

@Test