74 changes: 74 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
@@ -521,6 +521,80 @@ public actor DataFrame: Sendable {
}
}

/// Join with another `DataFrame`.
/// Behaves as an INNER JOIN and requires a subsequent join predicate.
/// - Parameter right: Right side of the join operation.
/// - Returns: A `DataFrame`.
public func join(_ right: DataFrame) async -> DataFrame {
let right = await (right.getPlan() as! Plan).root
let plan = SparkConnectClient.getJoin(self.plan.root, right, JoinType.inner)
return DataFrame(spark: self.spark, plan: plan)
}
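A minimal usage sketch (assuming a `spark` session, as in the tests later in this diff):

// Hypothetical usage: without a predicate, the result is effectively a cross join.
let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
let pairs = await df1.join(df2)  // 2 x 2 = 4 rows until a predicate narrows it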

/// Equi-join with another `DataFrame` using the given column. A cross join with a predicate is
/// specified as an inner join. If you would explicitly like to perform a cross join use the
/// `crossJoin` method.
/// - Parameters:
/// - right: Right side of the join operation.
/// - usingColumn: Name of the column to join on. This column must exist on both sides.
/// - joinType: Type of join to perform. Default `inner`.
/// - Returns: A `DataFrame`.
public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame {
await join(right, [usingColumn], joinType)
}

/// Equi-join with another `DataFrame` using the given columns. A cross join with a predicate is
/// specified as an inner join. If you would explicitly like to perform a cross join use the
/// `crossJoin` method.
/// - Parameters:
///   - other: Right side of the join operation.
///   - usingColumns: Names of the columns to join on. These columns must exist on both sides.
///   - joinType: Type of join to perform. Default `inner`.
/// - Returns: A `DataFrame`.
public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType: String = "inner") async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getJoin(
self.plan.root,
right,
joinType.toJoinType,
usingColumns: usingColumns
)
return DataFrame(spark: self.spark, plan: plan)
}
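Usage sketch, mirroring the test case later in this diff (reusing the hypothetical `df1`/`df2` above):

let inner = await df1.join(df2, "b")           // inner equi-join on the shared column b
let left = await df1.join(df2, ["b"], "left")  // left outer join on the same key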

/// Inner join with another `DataFrame` using the given join expression.
/// - Parameters:
///   - right: Right side of the join operation.
///   - joinExprs: A join expression string.
/// - Returns: A `DataFrame`.
public func join(_ right: DataFrame, joinExprs: String) async -> DataFrame {
return await join(right, joinExprs: joinExprs, joinType: "inner")
}

/// Join with another `DataFrame` using the given join expression and join type.
/// - Parameters:
///   - right: Right side of the join operation.
///   - joinExprs: A join expression string.
///   - joinType: A join type name.
/// - Returns: A `DataFrame`.
public func join(_ right: DataFrame, joinExprs: String, joinType: String) async -> DataFrame {
let rightPlan = await (right.getPlan() as! Plan).root
let plan = SparkConnectClient.getJoin(
self.plan.root,
rightPlan,
joinType.toJoinType,
joinCondition: joinExprs
)
return DataFrame(spark: self.spark, plan: plan)
}
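Usage sketch (the `T`/`S` aliases are assumed to come from the SQL that defined `df1`/`df2`):

let joined = await df1.join(df2, joinExprs: "T.b = S.b", joinType: "left")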

/// Explicit cartesian join with another `DataFrame`. Note that cartesian joins are very
/// expensive without an extra filter that can be pushed down.
/// - Parameter right: Right side of the join operation.
/// - Returns: A `DataFrame` containing the cartesian product of the two `DataFrame`s.
public func crossJoin(_ right: DataFrame) async -> DataFrame {
let rightPlan = await (right.getPlan() as! Plan).root
let plan = SparkConnectClient.getJoin(self.plan.root, rightPlan, JoinType.cross)
return DataFrame(spark: self.spark, plan: plan)
}
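Usage sketch:

let cartesian = await df1.crossJoin(df2)  // 2 rows x 2 rows = 4 rows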

/// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
/// This is equivalent to `EXCEPT DISTINCT` in SQL.
/// - Parameter other: A `DataFrame` to exclude.
13 changes: 13 additions & 0 deletions Sources/SparkConnect/Extension.swift
@@ -80,6 +80,19 @@ extension String {
default: SaveMode.errorIfExists
}
}

var toJoinType: JoinType {
return switch self.lowercased() {
case "inner": JoinType.inner
case "cross": JoinType.cross
case "outer", "full", "fullouter", "full_outer": JoinType.fullOuter
case "left", "leftouter", "left_outer": JoinType.leftOuter
case "right", "rightouter", "right_outer": JoinType.rightOuter
case "semi", "leftsemi", "left_semi": JoinType.leftSemi
case "anti", "leftanti", "left_anti": JoinType.leftAnti
default: JoinType.inner
}
}
}
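The mapping is case-insensitive, and unrecognized names silently fall back to `inner`. A hedged sketch of the expected behavior (SwiftProtobuf-generated enums are `Equatable`):

assert("FULL_OUTER".toJoinType == .fullOuter)
assert("leftsemi".toJoinType == .leftSemi)
assert("bogus".toJoinType == .inner)  // fallback, not an error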

extension [String: String] {
22 changes: 22 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -557,6 +557,28 @@ public actor SparkConnectClient {
}
}

static func getJoin(
_ left: Relation, _ right: Relation, _ joinType: JoinType,
joinCondition: String? = nil, usingColumns: [String]? = nil
) -> Plan {
var join = Join()
join.left = left
join.right = right
join.joinType = joinType
if let joinCondition {
join.joinCondition.expressionString = joinCondition.toExpressionString
}
if let usingColumns {
join.usingColumns = usingColumns
}
// join.joinDataType = Join.JoinDataType()
var relation = Relation()
relation.join = join
var plan = Plan()
plan.opType = .root(relation)
return plan
}
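All `join` overloads in DataFrame.swift funnel into this helper. A hedged sketch of a call site (hypothetical `leftRel`/`rightRel` already extracted from two `DataFrame` plans):

let plan = SparkConnectClient.getJoin(leftRel, rightRel, "left_outer".toJoinType, usingColumns: ["b"])
// plan.opType is .root(relation), where relation.join carries both sides.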

static func getSetOperation(
_ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
byName: Bool = false, allowMissingColumns: Bool = false
2 changes: 2 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -29,6 +29,8 @@ typealias ExecutePlanResponse = Spark_Connect_ExecutePlanResponse
typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
typealias ExpressionString = Spark_Connect_Expression.ExpressionString
typealias Filter = Spark_Connect_Filter
typealias Join = Spark_Connect_Join
typealias JoinType = Spark_Connect_Join.JoinType
typealias KeyValue = Spark_Connect_KeyValue
typealias Limit = Spark_Connect_Limit
typealias MapType = Spark_Connect_DataType.Map
37 changes: 37 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -414,6 +414,43 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func join() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)")
let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)")
let expectedCross = [
Row("a", "1", "c", "2"),
Row("a", "1", "d", "3"),
Row("b", "2", "c", "2"),
Row("b", "2", "d", "3"),
]
#expect(try await df1.join(df2).collect() == expectedCross)
#expect(try await df1.crossJoin(df2).collect() == expectedCross)

#expect(try await df1.join(df2, "b").collect() == [Row("2", "b", "c")])
#expect(try await df1.join(df2, ["b"]).collect() == [Row("2", "b", "c")])

#expect(try await df1.join(df2, "b", "left").collect() == [Row("1", "a", nil), Row("2", "b", "c")])
#expect(try await df1.join(df2, "b", "right").collect() == [Row("2", "b", "c"), Row("3", nil, "d")])
#expect(try await df1.join(df2, "b", "semi").collect() == [Row("2", "b")])
#expect(try await df1.join(df2, "b", "anti").collect() == [Row("1", "a")])

let expectedOuter = [
Row("1", "a", nil),
Row("2", "b", "c"),
Row("3", nil, "d"),
]
#expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter)
#expect(try await df1.join(df2, "b", "full").collect() == expectedOuter)
#expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter)

let expected = [Row("b", "2", "c", "2")]
#expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected)
#expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected)
await spark.stop()
}

@Test
func except() async throws {
let spark = try await SparkSession.builder.getOrCreate()