diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index c521df4..31655b2 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -521,6 +521,80 @@ public actor DataFrame: Sendable {
     }
   }
 
+  /// Join with another `DataFrame`.
+  /// Behaves as an INNER JOIN and requires a subsequent join predicate.
+  /// - Parameter right: Right side of the join operation.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame) async -> DataFrame {
+    let right = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(self.plan.root, right, JoinType.inner)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Equi-join with another `DataFrame` using the given column. A cross join with a predicate is
+  /// specified as an inner join. If you would explicitly like to perform a cross join, use the
+  /// `crossJoin` method.
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - usingColumn: Name of the column to join on. This column must exist on both sides.
+  ///   - joinType: Type of join to perform. Default `inner`.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame {
+    await join(right, [usingColumn], joinType)
+  }
+
+  /// Equi-join with another `DataFrame` using the given columns.
+  /// - Parameters:
+  ///   - other: Right side of the join operation.
+  ///   - usingColumns: Names of the columns to join on. These columns must exist on both sides.
+  ///   - joinType: Type of join to perform. Default `inner`.
+  /// - Returns: A `DataFrame`.
+  public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType: String = "inner") async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(
+      self.plan.root,
+      right,
+      joinType.toJoinType,
+      usingColumns: usingColumns
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Inner join with another `DataFrame`, using the given join expression.
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinExprs: A join expression string.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame, joinExprs: String) async -> DataFrame {
+    return await join(right, joinExprs: joinExprs, joinType: "inner")
+  }
+
+  /// Join with another `DataFrame`, using the given join expression and join type.
+  /// - Parameters:
+  ///   - right: Right side of the join operation.
+  ///   - joinExprs: A join expression string.
+  ///   - joinType: A join type name.
+  /// - Returns: A `DataFrame`.
+  public func join(_ right: DataFrame, joinExprs: String, joinType: String) async -> DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(
+      self.plan.root,
+      rightPlan,
+      joinType.toJoinType,
+      joinCondition: joinExprs
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Explicit cartesian join with another `DataFrame`.
+  /// - Note: Cartesian joins are very expensive without an extra filter that can be pushed down.
+  /// - Parameter right: Right side of the join operation.
+  /// - Returns: A `DataFrame`.
+  public func crossJoin(_ right: DataFrame) async -> DataFrame {
+    let rightPlan = await (right.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getJoin(self.plan.root, rightPlan, JoinType.cross)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
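+
+  // Usage sketch for the join APIs above (illustrative only: `spark` is an
+  // assumed `SparkSession`, and the inline tables T(id, v) and S(id, w) are
+  // made up for this example):
+  //
+  //   let l = try await spark.sql("SELECT * FROM VALUES (1, 'a') AS T(id, v)")
+  //   let r = try await spark.sql("SELECT * FROM VALUES (1, 'b') AS S(id, w)")
+  //   let inner = try await l.join(r, "id").collect()                     // inner equi-join
+  //   let left = try await l.join(r, "id", "left").collect()              // join type by name
+  //   let expr = try await l.join(r, joinExprs: "T.id = S.id").collect()  // expression join
+  //   let cross = try await l.crossJoin(r).collect()                      // cartesian product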
+
   /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
   /// This is equivalent to `EXCEPT DISTINCT` in SQL.
   /// - Parameter other: A `DataFrame` to exclude.
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index 0594ea2..d41b5b1 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -80,6 +80,19 @@ extension String {
     default: SaveMode.errorIfExists
     }
   }
+
+  var toJoinType: JoinType {
+    return switch self.lowercased() {
+    case "inner": JoinType.inner
+    case "cross": JoinType.cross
+    case "outer", "full", "fullouter", "full_outer": JoinType.fullOuter
+    case "left", "leftouter", "left_outer": JoinType.leftOuter
+    case "right", "rightouter", "right_outer": JoinType.rightOuter
+    case "semi", "leftsemi", "left_semi": JoinType.leftSemi
+    case "anti", "leftanti", "left_anti": JoinType.leftAnti
+    default: JoinType.inner
+    }
+  }
 }
 
 extension [String: String] {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 3b3f6de..f71ad83 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -557,6 +557,28 @@ public actor SparkConnectClient {
     }
   }
 
+  static func getJoin(
+    _ left: Relation, _ right: Relation, _ joinType: JoinType,
+    joinCondition: String? = nil, usingColumns: [String]? = nil
+  ) -> Plan {
+    var join = Join()
+    join.left = left
+    join.right = right
+    join.joinType = joinType
+    if let joinCondition {
+      join.joinCondition.expressionString = joinCondition.toExpressionString
+    }
+    if let usingColumns {
+      join.usingColumns = usingColumns
+    }
+    var relation = Relation()
+    relation.join = join
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   static func getSetOperation(
     _ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
     byName: Bool = false, allowMissingColumns: Bool = false
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 34632ce..1107c52 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -29,6 +29,8 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
 typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
+typealias Join = Spark_Connect_Join
+typealias JoinType = Spark_Connect_Join.JoinType
 typealias KeyValue = Spark_Connect_KeyValue
 typealias Limit = Spark_Connect_Limit
 typealias MapType = Spark_Connect_DataType.Map
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index ba53923..ee220c3 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -414,6 +414,43 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func join() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
+    let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
+    let expectedCross = [
+      Row("a", "1", "c", "2"),
+      Row("a", "1", "d", "3"),
+      Row("b", "2", "c", "2"),
+      Row("b", "2", "d", "3"),
+    ]
+    #expect(try await df1.join(df2).collect() == expectedCross)
+    #expect(try await df1.crossJoin(df2).collect() == expectedCross)
+
+    #expect(try await df1.join(df2, "b").collect() == [Row("2", "b", "c")])
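+    // The using-column joins above and below coalesce the join key into a single
+    // column and move it to the front of the result schema, hence the matched row
+    // is (b, a, c) == ("2", "b", "c") rather than (a, b, c, b).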
+    #expect(try await df1.join(df2, ["b"]).collect() == [Row("2", "b", "c")])
+
+    #expect(try await df1.join(df2, "b", "left").collect() == [Row("1", "a", nil), Row("2", "b", "c")])
+    #expect(try await df1.join(df2, "b", "right").collect() == [Row("2", "b", "c"), Row("3", nil, "d")])
+    #expect(try await df1.join(df2, "b", "semi").collect() == [Row("2", "b")])
+    #expect(try await df1.join(df2, "b", "anti").collect() == [Row("1", "a")])
+
+    let expectedOuter = [
+      Row("1", "a", nil),
+      Row("2", "b", "c"),
+      Row("3", nil, "d"),
+    ]
+    #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
+    #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
+    #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter)
+
+    let expected = [Row("b", "2", "c", "2")]
+    #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected)
+    #expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
+    await spark.stop()
+  }
+
   @Test
   func except() async throws {
     let spark = try await SparkSession.builder.getOrCreate()