
Commit 289ccad

[SPARK-51839] Support except(All)?/intersect(All)?/union(All)?/unionByName in DataFrame

1 parent 69a4ac4

File tree

4 files changed: +178 −0 lines
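For context before the per-file diffs: the tests added in this commit show how the new API reads in practice. A minimal usage sketch along those lines, assuming a reachable Spark Connect server:

// Minimal sketch based on the tests below; assumes a running Spark Connect server.
let spark = try await SparkSession.builder.getOrCreate()
let df = try await spark.range(1, 3)  // rows: 1, 2

// EXCEPT DISTINCT: rows in `df` that are absent from the argument.
let minus = try await df.except(spark.range(2, 5)).collect()      // [Row("1")]

// INTERSECT: rows present in both DataFrames.
let common = try await df.intersect(spark.range(2, 5)).collect()  // [Row("2")]

// UNION ALL semantics: resolves columns by position and keeps duplicates;
// per the doc comment, follow with ``distinct`` for a SQL-style set union.
let all = try await df.union(spark.range(2, 3)).collect()         // [Row("1"), Row("2"), Row("2")]

await spark.stop()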

Sources/SparkConnect/DataFrame.swift

Lines changed: 80 additions & 0 deletions
@@ -499,6 +499,86 @@ public actor DataFrame: Sendable {
     }
   }
 
+  /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
+  /// This is equivalent to `EXCEPT DISTINCT` in SQL.
+  /// - Parameter other: A `DataFrame` to exclude.
+  /// - Returns: A `DataFrame`.
+  public func except(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`,
+  /// while preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL.
+  /// - Parameter other: A `DataFrame` to exclude.
+  /// - Returns: A `DataFrame`.
+  public func exceptAll(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except, isAll: true)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame`.
+  /// This is equivalent to `INTERSECT` in SQL.
+  /// - Parameter other: A `DataFrame` to intersect with.
+  /// - Returns: A `DataFrame`.
+  public func intersect(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame`,
+  /// while preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL.
+  /// - Parameter other: A `DataFrame` to intersect with.
+  /// - Returns: A `DataFrame`.
+  public func intersectAll(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect, isAll: true)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
+  /// This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
+  /// deduplication of elements), use this function followed by a ``distinct``.
+  /// Also, as is standard in SQL, this function resolves columns by position (not by name).
+  /// - Parameter other: A `DataFrame` to union with.
+  /// - Returns: A `DataFrame`.
+  public func union(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.union, isAll: true)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
+  /// This is an alias of `union`.
+  /// - Parameter other: A `DataFrame` to union with.
+  /// - Returns: A `DataFrame`.
+  public func unionAll(_ other: DataFrame) async -> DataFrame {
+    return await union(other)
+  }
+
+  /// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
+  /// The difference between this function and ``union`` is that this function resolves columns by
+  /// name (not by position).
+  /// When the parameter `allowMissingColumns` is `true`, the set of column names in this and the other
+  /// `DataFrame` can differ; missing columns will be filled with null. Further, the missing columns
+  /// of this `DataFrame` will be added at the end in the schema of the union result.
+  /// - Parameter other: A `DataFrame` to union with.
+  /// - Returns: A `DataFrame`.
+  public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool = false) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(
+      self.plan.root,
+      right,
+      SetOpType.union,
+      byName: true,
+      allowMissingColumns: allowMissingColumns
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
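The tests in this commit exercise `unionByName` only with matching column sets; the `allowMissingColumns` path is specified by the doc comment above. A hedged sketch of that behavior (the column names and literals here are illustrative, not from the commit):

// Illustrative only: df2 lacks column `b`, so with allowMissingColumns: true
// the doc comment implies `b` is filled with null for df2's rows.
let spark = try await SparkSession.builder.getOrCreate()
let df1 = try await spark.sql("SELECT 1 a, 2 b")
let df2 = try await spark.sql("SELECT 3 a")
let merged = await df1.unionByName(df2, true)  // expected rows: (1, 2), (3, NULL)
await spark.stop()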

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 18 additions & 0 deletions
@@ -527,4 +527,22 @@ public actor SparkConnectClient {
       return response.jsonToDdl.ddlString
     }
   }
+
+  static func getSetOperation(
+    _ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
+    byName: Bool = false, allowMissingColumns: Bool = false
+  ) -> Plan {
+    var setOp = SetOperation()
+    setOp.leftInput = left
+    setOp.rightInput = right
+    setOp.setOpType = opType
+    setOp.isAll = isAll
+    setOp.allowMissingColumns = allowMissingColumns
+    setOp.byName = byName
+    var relation = Relation()
+    relation.setOp = setOp
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
 }
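For orientation, every new `DataFrame` method above reduces to one flag combination on this helper. A sketch of the mapping (the wrapper function is hypothetical; `left` and `right` stand for the two plans' root relations):

// Hypothetical wrapper enumerating the flag combinations used by the new API.
func sketchSetOpPlans(_ left: Relation, _ right: Relation) -> [Plan] {
  [
    SparkConnectClient.getSetOperation(left, right, SetOpType.except),                  // except
    SparkConnectClient.getSetOperation(left, right, SetOpType.except, isAll: true),     // exceptAll
    SparkConnectClient.getSetOperation(left, right, SetOpType.intersect),               // intersect
    SparkConnectClient.getSetOperation(left, right, SetOpType.intersect, isAll: true),  // intersectAll
    SparkConnectClient.getSetOperation(left, right, SetOpType.union, isAll: true),      // union / unionAll
    SparkConnectClient.getSetOperation(left, right, SetOpType.union, byName: true),     // unionByName
  ]
}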

Sources/SparkConnect/TypeAliases.swift

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
 typealias Sample = Spark_Connect_Sample
 typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
+typealias SetOperation = Spark_Connect_SetOperation
+typealias SetOpType = SetOperation.SetOpType
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
 typealias StructType = Spark_Connect_DataType.Struct

Tests/SparkConnectTests/DataFrameTests.swift

Lines changed: 78 additions & 0 deletions
@@ -369,6 +369,84 @@ struct DataFrameTests {
     #expect(try await df.unpersist().count() == 30)
     await spark.stop()
   }
+
+  @Test
+  func except() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.except(spark.range(1, 5)).collect() == [])
+    #expect(try await df.except(spark.range(2, 5)).collect() == [Row("1")])
+    #expect(try await df.except(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0)
+    await spark.stop()
+  }
+
+  @Test
+  func exceptAll() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
+    #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row("1")])
+    #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1)
+    await spark.stop()
+  }
+
+  @Test
+  func intersect() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row("2")])
+    #expect(try await df.intersect(spark.range(3, 5)).collect() == [])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.intersect(df2).count() == 1)
+    await spark.stop()
+  }
+
+  @Test
+  func intersectAll() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row("2")])
+    #expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.intersectAll(df2).count() == 2)
+    await spark.stop()
+  }
+
+  @Test
+  func union() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 2)
+    #expect(try await df.union(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
+    #expect(try await df.union(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.union(df2).count() == 4)
+    await spark.stop()
+  }
+
+  @Test
+  func unionAll() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 2)
+    #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
+    #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.unionAll(df2).count() == 4)
+    await spark.stop()
+  }
+
+  @Test
+  func unionByName() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df1 = try await spark.sql("SELECT 1 a, 2 b")
+    let df2 = try await spark.sql("SELECT 4 b, 3 a")
+    #expect(try await df1.unionByName(df2).collect() == [Row("1", "2"), Row("3", "4")])
+    #expect(try await df1.union(df2).collect() == [Row("1", "2"), Row("4", "3")])
+    await spark.stop()
+  }
 #endif

 @Test
