From 289ccad20ebbeacd024863c005489a0f3426b2b9 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Fri, 18 Apr 2025 14:45:21 +0900
Subject: [PATCH 1/2] [SPARK-51839] Support `except(All)?/intersect(All)?/union(All)?/unionByName` in `DataFrame`

---
 Sources/SparkConnect/DataFrame.swift          | 80 +++++++++++++++++++
 Sources/SparkConnect/SparkConnectClient.swift | 18 +++++
 Sources/SparkConnect/TypeAliases.swift        |  2 +
 Tests/SparkConnectTests/DataFrameTests.swift  | 78 ++++++++++++++++++
 4 files changed, 178 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 80e5692..70fa01a 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -499,6 +499,86 @@ public actor DataFrame: Sendable {
     }
   }
 
+  /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame`.
+  /// This is equivalent to `EXCEPT DISTINCT` in SQL.
+  /// - Parameter other: A `DataFrame` to exclude.
+  /// - Returns: A `DataFrame`.
+  public func except(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing rows in this `DataFrame` but not in another `DataFrame` while
+  /// preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL.
+  /// - Parameter other: A `DataFrame` to exclude.
+  /// - Returns: A `DataFrame`.
+  public func exceptAll(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except, isAll: true)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame`.
+  /// This is equivalent to `INTERSECT` in SQL.
+  /// - Parameter other: A `DataFrame` to intersect with.
+  /// - Returns: A `DataFrame`.
+  public func intersect(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing rows only in both this `DataFrame` and another `DataFrame` while
+  /// preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL.
+  /// - Parameter other: A `DataFrame` to intersect with.
+  /// - Returns: A `DataFrame`.
+  public func intersectAll(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect, isAll: true)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
+  /// This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
+  /// deduplication of elements), use this function followed by a `distinct`.
+  /// Also as standard in SQL, this function resolves columns by position (not by name).
+  /// - Parameter other: A `DataFrame` to union with.
+  /// - Returns: A `DataFrame`.
+  public func union(_ other: DataFrame) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.union, isAll: true)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
+  /// This is an alias of `union`.
+  /// - Parameter other: A `DataFrame` to union with.
+  /// - Returns: A `DataFrame`.
+  public func unionAll(_ other: DataFrame) async -> DataFrame {
+    return await union(other)
+  }
+
+  /// Returns a new `DataFrame` containing the union of rows in this `DataFrame` and another `DataFrame`.
+  /// The difference between this function and `union` is that this function resolves columns by
+  /// name (not by position).
+  /// When the parameter `allowMissingColumns` is `true`, the set of column names in this and other
+  /// `DataFrame` can differ; missing columns will be filled with null. Further, the missing columns
+  /// of this `DataFrame` will be added at the end in the schema of the union result.
+  /// - Parameter other: A `DataFrame` to union with.
+  /// - Returns: A `DataFrame`.
+  public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool = false) async -> DataFrame {
+    let right = await (other.getPlan() as! Plan).root
+    let plan = SparkConnectClient.getSetOperation(
+      self.plan.root,
+      right,
+      SetOpType.union,
+      byName: true,
+      allowMissingColumns: allowMissingColumns
+    )
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
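
For context, the new APIs are used like this — a minimal sketch, assuming a Spark Connect server reachable through the SPARK_REMOTE environment variable; the inputs and expected counts follow the tests later in this patch:

    let spark = try await SparkSession.builder.getOrCreate()
    let df1 = try await spark.range(1, 4)  // rows 1, 2, 3
    let df2 = try await spark.range(3, 6)  // rows 3, 4, 5
    // EXCEPT DISTINCT: rows of df1 absent from df2.
    let minus = try await df1.except(df2).count()      // 2
    // INTERSECT: rows present in both inputs.
    let common = try await df1.intersect(df2).count()  // 1
    // UNION ALL, resolving columns by position.
    let both = try await df1.union(df2).count()        // 6
    await spark.stop()
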
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 904e76e..329f8da 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -527,4 +527,22 @@ public actor SparkConnectClient {
       return response.jsonToDdl.ddlString
     }
   }
+
+  static func getSetOperation(
+    _ left: Relation, _ right: Relation, _ opType: SetOpType, isAll: Bool = false,
+    byName: Bool = false, allowMissingColumns: Bool = false
+  ) -> Plan {
+    var setOp = SetOperation()
+    setOp.leftInput = left
+    setOp.rightInput = right
+    setOp.setOpType = opType
+    setOp.isAll = isAll
+    setOp.allowMissingColumns = allowMissingColumns
+    setOp.byName = byName
+    var relation = Relation()
+    relation.setOp = setOp
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
 }
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 0934a05..2ae0636 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -42,6 +42,8 @@ typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
 typealias Sample = Spark_Connect_Sample
 typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
+typealias SetOperation = Spark_Connect_SetOperation
+typealias SetOpType = SetOperation.SetOpType
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
 typealias StructType = Spark_Connect_DataType.Struct
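
Since all seven public APIs funnel through the single `getSetOperation` helper, the flag combinations are worth spelling out. A sketch — `left` and `right` stand for already-built `Relation` values, and the `unionByName` line reflects the follow-up fix at the end of this series:

    // except          -> SetOpType.except,    isAll: false  (EXCEPT DISTINCT)
    // exceptAll       -> SetOpType.except,    isAll: true   (EXCEPT ALL)
    // intersect       -> SetOpType.intersect, isAll: false  (INTERSECT)
    // intersectAll    -> SetOpType.intersect, isAll: true   (INTERSECT ALL)
    // union/unionAll  -> SetOpType.union,     isAll: true   (UNION ALL)
    // unionByName     -> SetOpType.union,     isAll: true, byName: true
    let plan = SparkConnectClient.getSetOperation(left, right, SetOpType.except)
    // plan.opType is .root(relation), where relation.setOp carries both
    // inputs plus the flags above.
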
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index afd182c..9dc5d40 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -369,6 +369,84 @@ struct DataFrameTests {
     #expect(try await df.unpersist().count() == 30)
     await spark.stop()
   }
+
+  @Test
+  func except() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.except(spark.range(1, 5)).collect() == [])
+    #expect(try await df.except(spark.range(2, 5)).collect() == [Row("1")])
+    #expect(try await df.except(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0)
+    await spark.stop()
+  }
+
+  @Test
+  func exceptAll() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.exceptAll(spark.range(1, 5)).collect() == [])
+    #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row("1")])
+    #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1)
+    await spark.stop()
+  }
+
+  @Test
+  func intersect() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row("2")])
+    #expect(try await df.intersect(spark.range(3, 5)).collect() == [])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.intersect(df2).count() == 1)
+    await spark.stop()
+  }
+
+  @Test
+  func intersectAll() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 3)
+    #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row("1"), Row("2")])
+    #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row("2")])
+    #expect(try await df.intersectAll(spark.range(3, 5)).collect() == [])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.intersectAll(df2).count() == 2)
+    await spark.stop()
+  }
+
+  @Test
+  func union() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 2)
+    #expect(try await df.union(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
+    #expect(try await df.union(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.union(df2).count() == 4)
+    await spark.stop()
+  }
+
+  @Test
+  func unionAll() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(1, 2)
+    #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row("1"), Row("1"), Row("2")])
+    #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row("1"), Row("2")])
+    let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df2.unionAll(df2).count() == 4)
+    await spark.stop()
+  }
+
+  @Test
+  func unionByName() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df1 = try await spark.sql("SELECT 1 a, 2 b")
+    let df2 = try await spark.sql("SELECT 4 b, 3 a")
+    #expect(try await df1.unionByName(df2).collect() == [Row("1", "2"), Row("3", "4")])
+    #expect(try await df1.union(df2).collect() == [Row("1", "2"), Row("4", "3")])
+    await spark.stop()
+  }
 #endif
 
   @Test
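
The `SELECT * FROM VALUES 1, 1` cases above pin down the DISTINCT-versus-ALL split. Spelled out with the same `df = spark.range(1, 3)` as in the tests — a restatement of the expected semantics, not an extra test:

    let dup = try await spark.sql("SELECT * FROM VALUES 1, 1")  // rows [1, 1]
    #expect(try await dup.except(df).count() == 0)         // deduplicates first: {1} minus {1, 2}
    #expect(try await dup.exceptAll(df).count() == 1)      // removes one occurrence: [1, 1] minus [1, 2]
    #expect(try await dup.intersectAll(dup).count() == 2)  // duplicates are preserved
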
From a3317d9539c996fadb81a742bbffeb76135b141f Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Fri, 18 Apr 2025 15:33:24 +0900
Subject: [PATCH 2/2] Address comment

---
 Sources/SparkConnect/DataFrame.swift         | 1 +
 Tests/SparkConnectTests/DataFrameTests.swift | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 70fa01a..4e1f248 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -573,6 +573,7 @@ public actor DataFrame: Sendable {
       self.plan.root,
       right,
       SetOpType.union,
+      isAll: true,
       byName: true,
       allowMissingColumns: allowMissingColumns
     )
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 9dc5d40..62aa056 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -445,6 +445,8 @@ struct DataFrameTests {
     let df2 = try await spark.sql("SELECT 4 b, 3 a")
     #expect(try await df1.unionByName(df2).collect() == [Row("1", "2"), Row("3", "4")])
     #expect(try await df1.union(df2).collect() == [Row("1", "2"), Row("4", "3")])
+    let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1")
+    #expect(try await df3.unionByName(df3).count() == 4)
     await spark.stop()
   }
 #endif
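
One documented behavior that the tests above do not exercise is `allowMissingColumns`. A sketch of what the doc comment promises — the expected rows are inferred from the documentation, not from a test in this series:

    let left = try await spark.sql("SELECT 1 a, 2 b")
    let right = try await spark.sql("SELECT 3 a")
    // Columns are matched by name; `b` is missing on the right and is null-filled.
    let rows = try await left.unionByName(right, true).collect()
    // Expected shape: (a: 1, b: 2) and (a: 3, b: null)
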