Skip to content

Commit b5dcc6e

Browse files
committed
[SPARK-52164] Support MergeIntoWriter
1 parent 7073438 commit b5dcc6e

File tree

5 files changed

+365
-1
lines changed

5 files changed

+365
-1
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,15 @@ public actor DataFrame: Sendable {
13951395
return DataFrameWriterV2(table, self)
13961396
}
13971397

1398+
/// Merges a set of updates, insertions, and deletions based on a source table into a target table.
1399+
/// - Parameters:
1400+
/// - table: A target table name.
1401+
/// - condition: A condition expression.
1402+
/// - Returns: A ``MergeIntoWriter`` instance.
1403+
public func mergeInto(_ table: String, _ condition: String) async -> MergeIntoWriter {
1404+
return await MergeIntoWriter(table, self, condition)
1405+
}
1406+
13981407
/// Returns a ``DataStreamWriter`` that can be used to write streaming data.
13991408
public var writeStream: DataStreamWriter {
14001409
get {
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
/// A struct for defining actions to be taken when matching rows in a ``DataFrame``
21+
/// during a merge operation.
22+
public struct WhenMatched: Sendable {
23+
let mergeIntoWriter: MergeIntoWriter
24+
let condition: String?
25+
26+
init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) {
27+
self.mergeIntoWriter = mergeIntoWriter
28+
self.condition = condition
29+
}
30+
31+
/// Specifies an action to update all matched rows in the ``DataFrame``.
32+
/// - Returns: The ``MergeIntoWriter`` instance with the update all action configured.
33+
public func updateAll() async -> MergeIntoWriter {
34+
await mergeIntoWriter.updateAll(condition, false)
35+
}
36+
37+
/// Specifies an action to update matched rows in the ``DataFrame`` with the provided column
38+
/// assignments.
39+
/// - Parameter map: A dictionary from column names to expressions representing the updates to be applied.
40+
/// - Returns: The ``MergeIntoWriter`` instance with the update action configured.
41+
public func update(map: [String: String]) async -> MergeIntoWriter {
42+
await mergeIntoWriter.update(condition, map, false)
43+
}
44+
45+
/// Specifies an action to delete matched rows from the DataFrame.
46+
/// - Returns: The ``MergeIntoWriter`` instance with the delete action configured.
47+
public func delete() async -> MergeIntoWriter {
48+
await mergeIntoWriter.delete(condition, false)
49+
}
50+
}
51+
52+
/// A struct for defining actions to be taken when no matching rows are found in a ``DataFrame``
53+
/// during a merge operation.
54+
public struct WhenNotMatched: Sendable {
55+
let mergeIntoWriter: MergeIntoWriter
56+
let condition: String?
57+
58+
init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) {
59+
self.mergeIntoWriter = mergeIntoWriter
60+
self.condition = condition
61+
}
62+
63+
/// Specifies an action to insert all non-matched rows into the ``DataFrame``.
64+
/// - Returns: The`` MergeIntoWriter`` instance with the insert all action configured.
65+
public func insertAll() async -> MergeIntoWriter {
66+
await mergeIntoWriter.insertAll(condition)
67+
}
68+
69+
/// Specifies an action to insert non-matched rows into the ``DataFrame``
70+
/// with the provided column assignments.
71+
/// - Parameter map: A dictionary of column names to expressions representing the values to be inserted.
72+
/// - Returns: The ``MergeIntoWriter`` instance with the insert action configured.
73+
public func insert(_ map: [String: String]) async -> MergeIntoWriter {
74+
await mergeIntoWriter.insert(condition, map)
75+
}
76+
}
77+
78+
public struct WhenNotMatchedBySource: Sendable {
79+
let mergeIntoWriter: MergeIntoWriter
80+
let condition: String?
81+
82+
init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) {
83+
self.mergeIntoWriter = mergeIntoWriter
84+
self.condition = condition
85+
}
86+
87+
public func updateAll() async -> MergeIntoWriter {
88+
await mergeIntoWriter.updateAll(condition, true)
89+
}
90+
91+
public func update(map: [String: String]) async -> MergeIntoWriter {
92+
await mergeIntoWriter.update(condition, map, true)
93+
}
94+
95+
public func delete() async -> MergeIntoWriter {
96+
await mergeIntoWriter.delete(condition, true)
97+
}
98+
}
99+
100+
/// `MergeIntoWriter` provides methods to define and execute merge actions based on specified
101+
/// conditions.
102+
public actor MergeIntoWriter {
103+
var schemaEvolution: Bool = false
104+
105+
let table: String
106+
107+
let df: DataFrame
108+
109+
let condition: String
110+
111+
var mergeIntoTableCommand = MergeIntoTableCommand()
112+
113+
init(_ table: String, _ df: DataFrame, _ condition: String) async {
114+
self.table = table
115+
self.df = df
116+
self.condition = condition
117+
118+
mergeIntoTableCommand.targetTableName = table
119+
mergeIntoTableCommand.sourceTablePlan = await (df.getPlan() as! Plan).root
120+
mergeIntoTableCommand.mergeCondition.expressionString = condition.toExpressionString
121+
}
122+
123+
public var schemaEvolutionEnabled: Bool {
124+
schemaEvolution
125+
}
126+
127+
/// Enable automatic schema evolution for this merge operation.
128+
/// - Returns: ``MergeIntoWriter`` instance
129+
public func withSchemaEvolution() -> MergeIntoWriter {
130+
self.schemaEvolution = true
131+
return self
132+
}
133+
134+
/// Initialize a `WhenMatched` action without any condition.
135+
/// - Returns: A `WhenMatched` instance.
136+
public func whenMatched() -> WhenMatched {
137+
WhenMatched(self)
138+
}
139+
140+
/// Initialize a `WhenMatched` action with a condition.
141+
/// - Parameter condition: <#condition description#>
142+
/// - Returns: A `WhenMatched` instance configured with the specified condition.
143+
public func whenMatched(_ condition: String) -> WhenMatched {
144+
WhenMatched(self, condition)
145+
}
146+
147+
/// Initialize a `WhenNotMatched` action without any condition.
148+
/// - Returns: A `WhenNotMatched` instance.
149+
public func whenNotMatched() -> WhenNotMatched {
150+
WhenNotMatched(self)
151+
}
152+
153+
/// Initialize a `WhenNotMatched` action with a condition.
154+
/// - Parameter condition: The condition to be evaluated for the action.
155+
/// - Returns: A `WhenNotMatched` instance configured with the specified condition.
156+
public func whenNotMatched(_ condition: String) -> WhenNotMatched {
157+
WhenNotMatched(self, condition)
158+
}
159+
160+
/// Initialize a `WhenNotMatchedBySource` action without any condition.
161+
/// - Returns: A `WhenNotMatchedBySource` instance.
162+
public func whenNotMatchedBySource() -> WhenNotMatchedBySource {
163+
WhenNotMatchedBySource(self)
164+
}
165+
166+
/// Initialize a `WhenNotMatchedBySource` action with a condition
167+
/// - Parameter condition: The condition to be evaluated for the action.
168+
/// - Returns: A `WhenNotMatchedBySource` instance configured with the specified condition.
169+
public func whenNotMatchedBySource(_ condition: String) -> WhenNotMatchedBySource {
170+
WhenNotMatchedBySource(self, condition)
171+
}
172+
173+
/// Executes the merge operation.
174+
public func merge() async throws {
175+
if self.mergeIntoTableCommand.matchActions.count == 0
176+
&& self.mergeIntoTableCommand.notMatchedActions.count == 0
177+
&& self.mergeIntoTableCommand.notMatchedBySourceActions.count == 0
178+
{
179+
throw SparkConnectError.InvalidArgumentException
180+
}
181+
self.mergeIntoTableCommand.withSchemaEvolution = self.schemaEvolution
182+
183+
var command = Spark_Connect_Command()
184+
command.mergeIntoTableCommand = self.mergeIntoTableCommand
185+
_ = try await df.spark.client.execute(df.spark.sessionID, command)
186+
}
187+
188+
public func insertAll(_ condition: String?) -> MergeIntoWriter {
189+
let expression = buildMergeAction(ActionType.insertStar, condition)
190+
self.mergeIntoTableCommand.notMatchedActions.append(expression)
191+
return self
192+
}
193+
194+
public func insert(_ condition: String?, _ map: [String: String]) -> MergeIntoWriter {
195+
let expression = buildMergeAction(ActionType.insert, condition, map)
196+
self.mergeIntoTableCommand.notMatchedActions.append(expression)
197+
return self
198+
}
199+
200+
public func updateAll(_ condition: String?, _ notMatchedBySource: Bool) -> MergeIntoWriter {
201+
appendUpdateDeleteAction(buildMergeAction(ActionType.updateStar, condition), notMatchedBySource)
202+
}
203+
204+
public func update(_ condition: String?, _ map: [String: String], _ notMatchedBySource: Bool)
205+
-> MergeIntoWriter
206+
{
207+
appendUpdateDeleteAction(buildMergeAction(ActionType.update, condition), notMatchedBySource)
208+
}
209+
210+
public func delete(_ condition: String?, _ notMatchedBySource: Bool) -> MergeIntoWriter {
211+
appendUpdateDeleteAction(buildMergeAction(ActionType.delete, condition), notMatchedBySource)
212+
}
213+
214+
private func appendUpdateDeleteAction(
215+
_ action: Spark_Connect_Expression,
216+
_ notMatchedBySource: Bool
217+
) -> MergeIntoWriter {
218+
if notMatchedBySource {
219+
self.mergeIntoTableCommand.notMatchedBySourceActions.append(action)
220+
} else {
221+
self.mergeIntoTableCommand.matchActions.append(action)
222+
}
223+
return self
224+
}
225+
226+
private func buildMergeAction(
227+
_ actionType: ActionType,
228+
_ condition: String?,
229+
_ assignments: [String: String] = [:]
230+
) -> Spark_Connect_Expression {
231+
var mergeAction = Spark_Connect_MergeAction()
232+
mergeAction.actionType = actionType
233+
if let condition {
234+
var expression = Spark_Connect_Expression()
235+
expression.expressionString = condition.toExpressionString
236+
mergeAction.condition = expression
237+
}
238+
mergeAction.assignments = assignments.map { key, value in
239+
var keyExpr = Spark_Connect_Expression()
240+
var valueExpr = Spark_Connect_Expression()
241+
242+
keyExpr.expressionString = key.toExpressionString
243+
valueExpr.expressionString = value.toExpressionString
244+
245+
var assignment = MergeAction.Assignment()
246+
assignment.key = keyExpr
247+
assignment.value = valueExpr
248+
return assignment
249+
}
250+
var expression = Spark_Connect_Expression()
251+
expression.mergeAction = mergeAction
252+
return expression
253+
}
254+
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// specific language governing permissions and limitations
1717
// under the License.
1818

19+
typealias ActionType = Spark_Connect_MergeAction.ActionType
1920
typealias Aggregate = Spark_Connect_Aggregate
2021
typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
2122
typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
@@ -38,6 +39,8 @@ typealias KeyValue = Spark_Connect_KeyValue
3839
typealias LateralJoin = Spark_Connect_LateralJoin
3940
typealias Limit = Spark_Connect_Limit
4041
typealias MapType = Spark_Connect_DataType.Map
42+
typealias MergeAction = Spark_Connect_MergeAction
43+
typealias MergeIntoTableCommand = Spark_Connect_MergeIntoTableCommand
4144
typealias NamedTable = Spark_Connect_Read.NamedTable
4245
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
4346
typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType

Tests/SparkConnectTests/IcebergTests.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,31 @@ struct IcebergTests {
8383

8484
try await spark.table(t1).writeTo(t2).overwrite("id = 1")
8585
#expect(try await spark.table(t2).count() == 3)
86-
})
8786

87+
try await spark.sql(
88+
"""
89+
MERGE INTO \(t2) t
90+
USING (SELECT *
91+
FROM VALUES
92+
(1, 'delete', null),
93+
(2, 'update', 'updated'),
94+
(4, null, 'new') AS T(id, op, data)) s
95+
ON t.id = s.id
96+
WHEN MATCHED AND s.op = 'delete' THEN DELETE
97+
WHEN MATCHED AND s.op = 'update' THEN UPDATE SET t.data = s.data
98+
WHEN NOT MATCHED THEN INSERT *
99+
WHEN NOT MATCHED BY SOURCE THEN UPDATE SET data = 'invalid'
100+
"""
101+
).count()
102+
#if !os(Linux)
103+
let expected = [
104+
Row(2, "updated"),
105+
Row(3, "invalid"),
106+
Row(4, "new"),
107+
]
108+
#expect(try await spark.table(t2).collect() == expected)
109+
#endif
110+
})
88111
await spark.stop()
89112
}
90113

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
import Foundation
21+
import SparkConnect
22+
import Testing
23+
24+
/// A test suite for `MergeIntoWriter`
25+
/// Since this requires Apache Spark 4 with Iceberg support (SPARK-48794), this suite only tests syntaxes.
26+
@Suite(.serialized)
27+
struct MergeIntoWriterTests {
28+
@Test
29+
func whenMatched() async throws {
30+
let spark = try await SparkSession.builder.getOrCreate()
31+
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
32+
try await SQLHelper.withTable(spark, tableName)({
33+
let mergeInto = try await spark.range(1).mergeInto(tableName, "true")
34+
try await #require(throws: Error.self) {
35+
try await mergeInto.whenMatched().delete().merge()
36+
}
37+
try await #require(throws: Error.self) {
38+
try await mergeInto.whenMatched("true").delete().merge()
39+
}
40+
})
41+
await spark.stop()
42+
}
43+
44+
@Test
45+
func whenNotMatched() async throws {
46+
let spark = try await SparkSession.builder.getOrCreate()
47+
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
48+
try await SQLHelper.withTable(spark, tableName)({
49+
let mergeInto = try await spark.range(1).mergeInto(tableName, "true")
50+
try await #require(throws: Error.self) {
51+
try await mergeInto.whenNotMatched().insertAll().merge()
52+
}
53+
try await #require(throws: Error.self) {
54+
try await mergeInto.whenNotMatched("true").insertAll().merge()
55+
}
56+
})
57+
await spark.stop()
58+
}
59+
60+
@Test
61+
func whenNotMatchedBySource() async throws {
62+
let spark = try await SparkSession.builder.getOrCreate()
63+
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
64+
try await SQLHelper.withTable(spark, tableName)({
65+
let mergeInto = try await spark.range(1).mergeInto(tableName, "true")
66+
try await #require(throws: Error.self) {
67+
try await mergeInto.whenNotMatchedBySource().delete().merge()
68+
}
69+
try await #require(throws: Error.self) {
70+
try await mergeInto.whenNotMatchedBySource("true").delete().merge()
71+
}
72+
})
73+
await spark.stop()
74+
}
75+
}

0 commit comments

Comments
 (0)