Commit 129cb59
[SPARK-51863] Support join and crossJoin in DataFrame
### What changes were proposed in this pull request?

This PR aims to support `join` and `crossJoin` in `DataFrame`.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass the CIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #80 from dongjoon-hyun/SPARK-51863.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent: 7fe4882

5 files changed: +148, -0 lines changed


Sources/SparkConnect/DataFrame.swift

Lines changed: 74 additions & 0 deletions
```diff
@@ -521,6 +521,80 @@ public actor DataFrame: Sendable {
     }
   }
 
+  /// Join with another `DataFrame`.
+  /// Behaves as an INNER JOIN and requires a subsequent join predicate.
+  /// - Parameter right: Right side of the join operation.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame) async -> DataFrame {
+    let right = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(self.plan.root, right, JoinType.inner)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Equi-join with another `DataFrame` using the given column. A cross join with a predicate is
+  /// specified as an inner join. If you would explicitly like to perform a cross join, use the
+  /// `crossJoin` method.
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - usingColumn: Name of the column to join on. This column must exist on both sides.
+  ///   - joinType: Type of join to perform. Default `inner`.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame {
+    await join(right, [usingColumn], joinType)
+  }
+
+  /// Equi-join with another `DataFrame` using the given columns.
+  /// - Parameters:
+  ///   - other: Right side of the join operation.
+  ///   - usingColumns: Names of the columns to join on. These columns must exist on both sides.
+  ///   - joinType: A join type name. Default `inner`.
+  /// - Returns: A `DataFrame`.
+  public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType: String = "inner") async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(
+      self.plan.root,
+      right,
+      joinType.toJoinType,
+      usingColumns: usingColumns
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Inner equi-join with another `DataFrame` using the given join expression.
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinExprs: A join expression string.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame, joinExprs: String) async -> DataFrame {
+    return await join(right, joinExprs: joinExprs, joinType: "inner")
+  }
+
+  /// Equi-join with another `DataFrame` using the given join expression and join type.
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinExprs: A join expression string.
+  ///   - joinType: A join type name.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame, joinExprs: String, joinType: String) async -> DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(
+      self.plan.root,
+      rightPlan,
+      joinType.toJoinType,
+      joinCondition: joinExprs
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Explicit cartesian join with another `DataFrame`.
+  /// - Parameter right: Right side of the join operation.
+  /// - Returns: A `DataFrame`. Note that cartesian joins are very expensive without an extra filter that can be pushed down.
+  public func crossJoin(_ right: DataFrame) async -> DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(self.plan.root, rightPlan, JoinType.cross)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
   /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
   /// This is equivalent to `EXCEPT DISTINCT` in SQL.
   /// - Parameter other: A `DataFrame` to exclude.
```
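For context, a minimal usage sketch of the new API (not part of the commit; it assumes a running async context, a reachable Spark Connect server, and borrows the two-table setup from the tests below):

```swift
import SparkConnect

// Illustrative only: `SparkSession.builder.getOrCreate()` is assumed to
// connect to a running Spark Connect server, as in the test suite.
let spark = try await SparkSession.builder.getOrCreate()
let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")

// Equi-join on the shared column `b`; the join type string defaults to "inner".
let inner = try await df1.join(df2, "b").collect()

// Same join with an explicit join type name (normalized via `toJoinType`).
let outer = try await df1.join(df2, "b", "full_outer").collect()

// Join on an expression string instead of USING columns.
let byExpr = try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "left").collect()

// Explicit cartesian product.
let cross = try await df1.crossJoin(df2).collect()
await spark.stop()
```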

Sources/SparkConnect/Extension.swift

Lines changed: 13 additions & 0 deletions
```diff
@@ -80,6 +80,19 @@ extension String {
       default: SaveMode.errorIfExists
     }
   }
+
+  var toJoinType: JoinType {
+    return switch self.lowercased() {
+    case "inner": JoinType.inner
+    case "cross": JoinType.cross
+    case "outer", "full", "fullouter", "full_outer": JoinType.fullOuter
+    case "left", "leftouter", "left_outer": JoinType.leftOuter
+    case "right", "rightouter", "right_outer": JoinType.rightOuter
+    case "semi", "leftsemi", "left_semi": JoinType.leftSemi
+    case "anti", "leftanti", "left_anti": JoinType.leftAnti
+    default: JoinType.inner
+    }
+  }
 }
 
 extension [String: String] {
```
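A quick sketch of the mapping's behavior (illustrative, not from the commit): matching is case-insensitive because of the `lowercased()` call, and an unrecognized name falls through to `inner` rather than throwing.

```swift
assert("FULL_OUTER".toJoinType == JoinType.fullOuter)  // case-insensitive match
assert("leftsemi".toJoinType == JoinType.leftSemi)     // underscore-free alias
assert("bogus".toJoinType == JoinType.inner)           // silent fallback to the default case
```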

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 22 additions & 0 deletions
```diff
@@ -557,6 +557,28 @@ public actor SparkConnectClient {
     }
   }
 
+  static func getJoin(
+    _ left: Relation, _ right: Relation, _ joinType: JoinType,
+    joinCondition: String? = nil, usingColumns: [String]? = nil
+  ) -> Plan {
+    var join = Join()
+    join.left = left
+    join.right = right
+    join.joinType = joinType
+    if let joinCondition {
+      join.joinCondition.expressionString = joinCondition.toExpressionString
+    }
+    if let usingColumns {
+      join.usingColumns = usingColumns
+    }
+    // join.joinDataType = Join.JoinDataType()
+    var relation = Relation()
+    relation.join = join
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   static func getSetOperation(
     _ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
     byName: Bool = false, allowMissingColumns: Bool = false
```
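For reference, the `DataFrame` methods above reduce to calls like the following sketch, where `leftRel` and `rightRel` stand for the `Relation` roots pulled from each side's plan (the names are illustrative, not from the commit):

```swift
// USING-style equi-join: only `usingColumns` is set on the Join message.
let usingPlan = SparkConnectClient.getJoin(leftRel, rightRel, "left".toJoinType, usingColumns: ["b"])

// Expression join: the condition string is carried as an ExpressionString payload.
let exprPlan = SparkConnectClient.getJoin(leftRel, rightRel, JoinType.inner, joinCondition: "T.b = S.b")
```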

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
```diff
@@ -29,6 +29,8 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
+typealias Join = Spark_Connect_Join
+typealias JoinType = Spark_Connect_Join.JoinType
 typealias KeyValue = Spark_Connect_KeyValue
 typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
```

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 37 additions & 0 deletions
```diff
@@ -414,6 +414,43 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func join() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
+    let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
+    let expectedCross = [
+      Row("a", "1", "c", "2"),
+      Row("a", "1", "d", "3"),
+      Row("b", "2", "c", "2"),
+      Row("b", "2", "d", "3"),
+    ]
+    #expect(try await df1.join(df2).collect() == expectedCross)
+    #expect(try await df1.crossJoin(df2).collect() == expectedCross)
+
+    #expect(try await df1.join(df2, "b").collect() == [Row("2", "b", "c")])
+    #expect(try await df1.join(df2, ["b"]).collect() == [Row("2", "b", "c")])
+
+    #expect(try await df1.join(df2, "b", "left").collect() == [Row("1", "a", nil), Row("2", "b", "c")])
+    #expect(try await df1.join(df2, "b", "right").collect() == [Row("2", "b", "c"), Row("3", nil, "d")])
+    #expect(try await df1.join(df2, "b", "semi").collect() == [Row("2", "b")])
+    #expect(try await df1.join(df2, "b", "anti").collect() == [Row("1", "a")])
+
+    let expectedOuter = [
+      Row("1", "a", nil),
+      Row("2", "b", "c"),
+      Row("3", nil, "d"),
+    ]
+    #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
+    #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
+    #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter)
+
+    let expected = [Row("b", "2", "c", "2")]
+    #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected)
+    #expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
+    await spark.stop()
+  }
+
   @Test
   func except() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
```
