81 changes: 81 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
@@ -499,6 +499,87 @@ public actor DataFrame: Sendable {
}
}

/// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
/// This is equivalent to `EXCEPT DISTINCT` in SQL.
/// - Parameter other: A `DataFrame` to exclude.
/// - Returns: A `DataFrame`.
public func except(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except)
return DataFrame(spark: self.spark, plan: plan)
}

/// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame` while
/// preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL.
/// - Parameter other: A `DataFrame` to exclude.
/// - Returns: A `DataFrame`.
public func exceptAll(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except, isAll: true)
return DataFrame(spark: self.spark, plan: plan)
}
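
// A usage sketch (illustrative, mirroring the tests in this PR): `except` applies
// EXCEPT DISTINCT semantics while `exceptAll` preserves duplicate rows.
//   let df = try await spark.sql("SELECT * FROM VALUES 1, 1, 2")  // rows: 1, 1, 2
//   try await df.except(spark.range(1, 2)).count()     // 1 -> only the distinct row 2 remains
//   try await df.exceptAll(spark.range(1, 2)).count()  // 2 -> one occurrence of 1 is removed, leaving 1 and 2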

/// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame`.
/// This is equivalent to `INTERSECT` in SQL.
/// - Parameter other: A `DataFrame` to intersect with.
/// - Returns: A `DataFrame`.
public func intersect(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect)
return DataFrame(spark: self.spark, plan: plan)
}

/// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame` while
/// preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL.
/// - Parameter other: A `DataFrame` to intersect with.
/// - Returns: A `DataFrame`.
public func intersectAll(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect, isAll: true)
return DataFrame(spark: self.spark, plan: plan)
}
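
// A usage sketch (illustrative): INTERSECT deduplicates, while INTERSECT ALL keeps
// each row with the minimum of its multiplicities on the two sides.
//   let left = try await spark.sql("SELECT * FROM VALUES 1, 1, 2")
//   let right = try await spark.sql("SELECT * FROM VALUES 1, 1")
//   try await left.intersect(right).count()     // 1 -> the distinct row 1
//   try await left.intersectAll(right).count()  // 2 -> row 1 appears min(2, 2) = 2 times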

/// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
/// This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
/// deduplication of elements), use this function followed by `distinct`.
/// As is standard in SQL, this function resolves columns by position (not by name).
/// - Parameter other: A `DataFrame` to union with.
/// - Returns: A `DataFrame`.
public func union(_ other: DataFrame) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.union, isAll: true)
return DataFrame(spark: self.spark, plan: plan)
}
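
// A usage sketch (illustrative; assumes the `distinct` API referenced in the doc
// comment above): `union` keeps duplicates, i.e. UNION ALL semantics.
//   let df = try await spark.range(1, 3)       // rows: 1, 2
//   try await df.union(df).count()             // 4 (UNION ALL)
//   try await df.union(df).distinct().count()  // 2 (SQL-style UNION)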

/// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
/// This is an alias of `union`.
/// - Parameter other: A `DataFrame` to union with.
/// - Returns: A `DataFrame`.
public func unionAll(_ other: DataFrame) async -> DataFrame {
return await union(other)
}

/// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
/// The difference between this function and `union` is that this function resolves columns by
/// name (not by position).
/// When the parameter `allowMissingColumns` is `true`, the set of column names in this and the other
/// `DataFrame` can differ; missing columns will be filled with null. Further, the missing columns
/// of this `DataFrame` will be added at the end of the schema of the union result.
/// - Parameters:
///   - other: A `DataFrame` to union with.
///   - allowMissingColumns: Whether the sets of column names may differ; `false` by default.
/// - Returns: A `DataFrame`.
public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool = false) async -> DataFrame {
let right = await (other.getPlan() as! Plan).root
let plan = SparkConnectClient.getSetOperation(
self.plan.root,
right,
SetOpType.union,
isAll: true,
byName: true,
allowMissingColumns: allowMissingColumns
@viirya (Member) commented on Apr 18, 2025:
Do we need to set `isAll: true` for `unionByName` like `union`?

)
return DataFrame(spark: self.spark, plan: plan)
}
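
// A usage sketch (illustrative): with `allowMissingColumns` set to `true`, the column
// sets may differ; the result schema is the union of both, and missing values are null.
//   let df1 = try await spark.sql("SELECT 1 a")
//   let df2 = try await spark.sql("SELECT 2 b")
//   try await df1.unionByName(df2, true).count()  // 2 rows, with schema (a, b)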

/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
public var write: DataFrameWriter {
get {
18 changes: 18 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -527,4 +527,22 @@ public actor SparkConnectClient {
return response.jsonToDdl.ddlString
}
}

static func getSetOperation(
_ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
byName: Bool = false, allowMissingColumns: Bool = false
) -> Plan {
var setOp = SetOperation()
setOp.leftInput = left
setOp.rightInput = right
setOp.setOpType = opType
setOp.isAll = isAll
setOp.allowMissingColumns = allowMissingColumns
setOp.byName = byName
var relation = Relation()
relation.setOp = setOp
var plan = Plan()
plan.opType = .root(relation)
return plan
}
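
// For example, `DataFrame.exceptAll` builds its plan via
//   getSetOperation(left, right, SetOpType.except, isAll: true)
// producing a Plan whose root Relation wraps the populated SetOperation message.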
}
2 changes: 2 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -42,6 +42,8 @@ typealias Read = Spark_Connect_Read
typealias Relation = Spark_Connect_Relation
typealias Sample = Spark_Connect_Sample
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
typealias SetOperation = Spark_Connect_SetOperation
typealias SetOpType = SetOperation.SetOpType
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StructType = Spark_Connect_DataType.Struct
80 changes: 80 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -369,6 +369,86 @@ struct DataFrameTests {
#expect(try await df.unpersist().count() == 30)
await spark.stop()
}

@Test
func except() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
#expect(try await df.except(spark.range(1, 5)).collect() == [])
#expect(try await df.except(spark.range(2, 5)).collect() == [Row("1")])
#expect(try await df.except(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
#expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0)
await spark.stop()
}

@Test
func exceptAll() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
#expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
#expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row("1")])
#expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
#expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1)
await spark.stop()
}

@Test
func intersect() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
#expect(try await df.intersect(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
#expect(try await df.intersect(spark.range(2, 5)).collect() == [Row("2")])
#expect(try await df.intersect(spark.range(3, 5)).collect() == [])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.intersect(df2).count() == 1)
await spark.stop()
}

@Test
func intersectAll() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)
#expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
#expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row("2")])
#expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.intersectAll(df2).count() == 2)
await spark.stop()
}

@Test
func union() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 2)
#expect(try await df.union(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
#expect(try await df.union(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.union(df2).count() == 4)
await spark.stop()
}

@Test
func unionAll() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 2)
#expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
#expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df2.unionAll(df2).count() == 4)
await spark.stop()
}

@Test
func unionByName() async throws {
let spark = try await SparkSession.builder.getOrCreate()
let df1 = try await spark.sql("SELECT 1 a, 2 b")
let df2 = try await spark.sql("SELECT 4 b, 3 a")
#expect(try await df1.unionByName(df2).collect() == [Row("1", "2"), Row("3", "4")])
#expect(try await df1.union(df2).collect() == [Row("1", "2"), Row("4", "3")])
let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1")
#expect(try await df3.unionByName(df3).count() == 4)
await spark.stop()
}
#endif

@Test