Skip to content

Commit 1dde04c

Browse files
committed
[SPARK-51839] Support except(All)?/intersect(All)?/union(All)?/unionByName in DataFrame
### What changes were proposed in this pull request? This PR aims to support seven `Set`-related `DataFrame` APIs. - `except` - `exceptAll` - `intersect` - `intersectAll` - `union` - `unionAll` - `unionByName` ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No, this is a new addition to the unreleased version. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #68 from dongjoon-hyun/SPARK-51839. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d1cd6d7 commit 1dde04c

File tree

4 files changed

+181
-0
lines changed

4 files changed

+181
-0
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,87 @@ public actor DataFrame: Sendable {
522522
}
523523
}
524524

525+
/// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
/// This is equivalent to `EXCEPT DISTINCT` in SQL.
/// - Parameter other: A `DataFrame` to exclude.
/// - Returns: A `DataFrame` with the distinct rows of `self` that do not appear in `other`.
public func except(_ other: DataFrame) async -> DataFrame {
  // Resolve the other side's logical plan root before building the set operation.
  let otherRoot = await (other.getPlan() as! Plan).root
  return DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getSetOperation(self.plan.root, otherRoot, SetOpType.except))
}
534+
535+
/// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`
/// while preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL.
/// - Parameter other: A `DataFrame` to exclude.
/// - Returns: A `DataFrame` with the rows of `self` minus the rows of `other`, duplicates kept.
public func exceptAll(_ other: DataFrame) async -> DataFrame {
  // Resolve the other side's logical plan root before building the set operation.
  let otherRoot = await (other.getPlan() as! Plan).root
  return DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getSetOperation(self.plan.root, otherRoot, SetOpType.except, isAll: true))
}
544+
545+
/// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame`.
/// This is equivalent to `INTERSECT` in SQL.
/// - Parameter other: A `DataFrame` to intersect with.
/// - Returns: A `DataFrame` with the distinct rows common to both inputs.
public func intersect(_ other: DataFrame) async -> DataFrame {
  // Resolve the other side's logical plan root before building the set operation.
  let otherRoot = await (other.getPlan() as! Plan).root
  return DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getSetOperation(self.plan.root, otherRoot, SetOpType.intersect))
}
554+
555+
/// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame`
/// while preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL.
/// - Parameter other: A `DataFrame` to intersect with.
/// - Returns: A `DataFrame` with the rows common to both inputs, duplicates kept.
public func intersectAll(_ other: DataFrame) async -> DataFrame {
  // Resolve the other side's logical plan root before building the set operation.
  let otherRoot = await (other.getPlan() as! Plan).root
  return DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getSetOperation(self.plan.root, otherRoot, SetOpType.intersect, isAll: true))
}
564+
565+
/// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
/// This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
/// deduplication of elements), use this function followed by a `distinct`.
/// As is standard in SQL, this function resolves columns by position (not by name).
/// - Parameter other: A `DataFrame` to union with.
/// - Returns: A `DataFrame` with all rows from both inputs, duplicates kept.
public func union(_ other: DataFrame) async -> DataFrame {
  // Resolve the other side's logical plan root before building the set operation.
  let otherRoot = await (other.getPlan() as! Plan).root
  // UNION ALL semantics: isAll is always true here; dedup is the caller's choice via distinct.
  return DataFrame(
    spark: self.spark,
    plan: SparkConnectClient.getSetOperation(self.plan.root, otherRoot, SetOpType.union, isAll: true))
}
576+
577+
/// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
/// This is an alias of `union`.
/// - Parameter other: A `DataFrame` to union with.
/// - Returns: A `DataFrame` with all rows from both inputs, duplicates kept.
public func unionAll(_ other: DataFrame) async -> DataFrame {
  // Pure alias: delegate directly so the two entry points can never diverge.
  await union(other)
}
584+
585+
/// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
/// The difference between this function and `union` is that this function resolves columns by
/// name (not by position).
/// When the parameter `allowMissingColumns` is `true`, the set of column names in this and the other
/// `DataFrame` can differ; missing columns will be filled with null. Further, the missing columns
/// of this `DataFrame` will be added at the end in the schema of the union result.
/// - Parameters:
///   - other: A `DataFrame` to union with.
///   - allowMissingColumns: Whether column sets may differ, filling gaps with null. Defaults to `false`.
/// - Returns: A `DataFrame` with all rows from both inputs, columns matched by name.
public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool = false) async -> DataFrame {
  // Resolve the other side's logical plan root before building the set operation.
  let otherRoot = await (other.getPlan() as! Plan).root
  let unionPlan = SparkConnectClient.getSetOperation(
    self.plan.root, otherRoot, SetOpType.union,
    isAll: true, byName: true, allowMissingColumns: allowMissingColumns)
  return DataFrame(spark: self.spark, plan: unionPlan)
}
605+
525606
/// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
526607
public var write: DataFrameWriter {
527608
get {

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,4 +538,22 @@ public actor SparkConnectClient {
538538
return response.jsonToDdl.ddlString
539539
}
540540
}
541+
542+
/// Builds a `Plan` whose root relation is a `SetOperation` combining two input relations.
/// - Parameters:
///   - left: The left input relation.
///   - right: The right input relation.
///   - opType: The set operation type (e.g. `except`, `intersect`, `union`).
///   - isAll: Whether duplicates are preserved (`ALL` semantics). Defaults to `false`.
///   - byName: Whether columns are resolved by name instead of by position. Defaults to `false`.
///   - allowMissingColumns: Whether column sets may differ when resolving by name. Defaults to `false`.
/// - Returns: A `Plan` wrapping the constructed set operation relation.
static func getSetOperation(
  _ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
  byName: Bool = false, allowMissingColumns: Bool = false
) -> Plan {
  // Assemble the proto message field by field; assignment order is not significant.
  var operation = SetOperation()
  operation.leftInput = left
  operation.rightInput = right
  operation.setOpType = opType
  operation.isAll = isAll
  operation.byName = byName
  operation.allowMissingColumns = allowMissingColumns

  var root = Relation()
  root.setOp = operation

  var result = Plan()
  result.opType = .root(root)
  return result
}
541559
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ typealias Read = Spark_Connect_Read
4242
// Short aliases for the generated Spark Connect protobuf types used across this module.
typealias Relation = Spark_Connect_Relation
typealias Sample = Spark_Connect_Sample
typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
// Set-operation relation (EXCEPT / INTERSECT / UNION) and its operator-kind enum.
typealias SetOperation = Spark_Connect_SetOperation
typealias SetOpType = SetOperation.SetOpType
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StructType = Spark_Connect_DataType.Struct

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,86 @@ struct DataFrameTests {
376376
#expect(try await df.unpersist().count() == 30)
377377
await spark.stop()
378378
}
379+
380+
@Test
func except() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let source = try await spark.range(1, 3)
  // Shrinking the excluded range leaves progressively more rows behind.
  #expect(try await source.except(spark.range(1, 5)).collect() == [])
  #expect(try await source.except(spark.range(2, 5)).collect() == [Row("1")])
  #expect(try await source.except(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
  // EXCEPT DISTINCT: duplicate inputs collapse, then every row is excluded.
  #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(source).count() == 0)
  await spark.stop()
}
390+
391+
@Test
func exceptAll() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let source = try await spark.range(1, 3)
  // Shrinking the excluded range leaves progressively more rows behind.
  #expect(try await source.exceptAll(spark.range(1, 5)).collect() == [])
  #expect(try await source.exceptAll(spark.range(2, 5)).collect() == [Row("1")])
  #expect(try await source.exceptAll(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
  // EXCEPT ALL: one duplicate survives because only one matching row is removed.
  #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(source).count() == 1)
  await spark.stop()
}
401+
402+
@Test
func intersect() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let source = try await spark.range(1, 3)
  // Shifting the other range shrinks the overlap step by step.
  #expect(try await source.intersect(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
  #expect(try await source.intersect(spark.range(2, 5)).collect() == [Row("2")])
  #expect(try await source.intersect(spark.range(3, 5)).collect() == [])
  // INTERSECT is distinct: duplicates collapse to a single row.
  let duplicated = try await spark.sql("SELECT * FROM VALUES 1, 1")
  #expect(try await duplicated.intersect(duplicated).count() == 1)
  await spark.stop()
}
413+
414+
@Test
func intersectAll() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let source = try await spark.range(1, 3)
  // Shifting the other range shrinks the overlap step by step.
  #expect(try await source.intersectAll(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
  #expect(try await source.intersectAll(spark.range(2, 5)).collect() == [Row("2")])
  #expect(try await source.intersectAll(spark.range(3, 5)).collect() == [])
  // INTERSECT ALL keeps duplicates: both copies match themselves.
  let duplicated = try await spark.sql("SELECT * FROM VALUES 1, 1")
  #expect(try await duplicated.intersectAll(duplicated).count() == 2)
  await spark.stop()
}
425+
426+
@Test
func union() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let source = try await spark.range(1, 2)
  // UNION ALL semantics: overlapping rows are not deduplicated.
  #expect(try await source.union(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
  #expect(try await source.union(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
  // Self-union of two duplicate rows yields four rows.
  let duplicated = try await spark.sql("SELECT * FROM VALUES 1, 1")
  #expect(try await duplicated.union(duplicated).count() == 4)
  await spark.stop()
}
436+
437+
@Test
func unionAll() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let source = try await spark.range(1, 2)
  // unionAll is an alias of union, so the same UNION ALL expectations hold.
  #expect(try await source.unionAll(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
  #expect(try await source.unionAll(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
  // Self-union of two duplicate rows yields four rows.
  let duplicated = try await spark.sql("SELECT * FROM VALUES 1, 1")
  #expect(try await duplicated.unionAll(duplicated).count() == 4)
  await spark.stop()
}
447+
448+
@Test
func unionByName() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let left = try await spark.sql("SELECT 1 a, 2 b")
  let right = try await spark.sql("SELECT 4 b, 3 a")
  // By-name resolution realigns (b, a) to (a, b); positional union does not.
  #expect(try await left.unionByName(right).collect() == [Row("1", "2"), Row("3", "4")])
  #expect(try await left.union(right).collect() == [Row("1", "2"), Row("4", "3")])
  // Duplicates are preserved (UNION ALL semantics under the hood).
  let duplicated = try await spark.sql("SELECT * FROM VALUES 1, 1")
  #expect(try await duplicated.unionByName(duplicated).count() == 4)
  await spark.stop()
}
379459
#endif
380460

381461
@Test

0 commit comments

Comments
 (0)