Skip to content

Commit 2c2a5f1

Browse files
committed
[SPARK-52164] Support MergeIntoWriter
### What changes were proposed in this pull request? This PR aims to `MergeIntoWriter` actor. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #154 from dongjoon-hyun/SPARK-52164. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 7073438 commit 2c2a5f1

File tree

5 files changed

+377
-1
lines changed

5 files changed

+377
-1
lines changed

Sources/SparkConnect/DataFrame.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,15 @@ public actor DataFrame: Sendable {
13951395
return DataFrameWriterV2(table, self)
13961396
}
13971397

1398+
/// Merges a set of updates, insertions, and deletions based on a source table into a target table.
1399+
/// - Parameters:
1400+
/// - table: A target table name.
1401+
/// - condition: A condition expression.
1402+
/// - Returns: A ``MergeIntoWriter`` instance.
1403+
public func mergeInto(_ table: String, _ condition: String) async -> MergeIntoWriter {
1404+
return MergeIntoWriter(table, self, condition)
1405+
}
1406+
13981407
/// Returns a ``DataStreamWriter`` that can be used to write streaming data.
13991408
public var writeStream: DataStreamWriter {
14001409
get {
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
20+
/// A struct for defining actions to be taken when matching rows in a ``DataFrame``
21+
/// during a merge operation.
22+
public struct WhenMatched: Sendable {
23+
let mergeIntoWriter: MergeIntoWriter
24+
let condition: String?
25+
26+
init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) {
27+
self.mergeIntoWriter = mergeIntoWriter
28+
self.condition = condition
29+
}
30+
31+
/// Specifies an action to update all matched rows in the ``DataFrame``.
32+
/// - Returns: The ``MergeIntoWriter`` instance with the update all action configured.
33+
public func updateAll() async -> MergeIntoWriter {
34+
await mergeIntoWriter.updateAll(condition, false)
35+
}
36+
37+
/// Specifies an action to update matched rows in the ``DataFrame`` with the provided column
38+
/// assignments.
39+
/// - Parameter map: A dictionary from column names to expressions representing the updates to be applied.
40+
/// - Returns: The ``MergeIntoWriter`` instance with the update action configured.
41+
public func update(map: [String: String]) async -> MergeIntoWriter {
42+
await mergeIntoWriter.update(condition, map, false)
43+
}
44+
45+
/// Specifies an action to delete matched rows from the DataFrame.
46+
/// - Returns: The ``MergeIntoWriter`` instance with the delete action configured.
47+
public func delete() async -> MergeIntoWriter {
48+
await mergeIntoWriter.delete(condition, false)
49+
}
50+
}
51+
52+
/// A struct for defining actions to be taken when no matching rows are found in a ``DataFrame``
53+
/// during a merge operation.
54+
public struct WhenNotMatched: Sendable {
55+
let mergeIntoWriter: MergeIntoWriter
56+
let condition: String?
57+
58+
init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) {
59+
self.mergeIntoWriter = mergeIntoWriter
60+
self.condition = condition
61+
}
62+
63+
/// Specifies an action to insert all non-matched rows into the ``DataFrame``.
64+
/// - Returns: The`` MergeIntoWriter`` instance with the insert all action configured.
65+
public func insertAll() async -> MergeIntoWriter {
66+
await mergeIntoWriter.insertAll(condition)
67+
}
68+
69+
/// Specifies an action to insert non-matched rows into the ``DataFrame``
70+
/// with the provided column assignments.
71+
/// - Parameter map: A dictionary of column names to expressions representing the values to be inserted.
72+
/// - Returns: The ``MergeIntoWriter`` instance with the insert action configured.
73+
public func insert(_ map: [String: String]) async -> MergeIntoWriter {
74+
await mergeIntoWriter.insert(condition, map)
75+
}
76+
}
77+
78+
/// A struct for defining actions to be performed when there is no match by source during a merge
79+
/// operation in a ``MergeIntoWriter``.
80+
public struct WhenNotMatchedBySource: Sendable {
81+
let mergeIntoWriter: MergeIntoWriter
82+
let condition: String?
83+
84+
init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) {
85+
self.mergeIntoWriter = mergeIntoWriter
86+
self.condition = condition
87+
}
88+
89+
/// Specifies an action to update all non-matched rows in the target ``DataFrame``
90+
/// when not matched by the source.
91+
/// - Returns: A ``MergeIntoWriter`` instance.
92+
public func updateAll() async -> MergeIntoWriter {
93+
await mergeIntoWriter.updateAll(condition, true)
94+
}
95+
96+
/// Specifies an action to update non-matched rows in the target ``DataFrame``
97+
/// with the provided column assignments when not matched by the source.
98+
/// - Parameter map: A dictionary from column names to expressions representing the updates to be applied
99+
/// - Returns: A ``MergeIntoWriter`` instance.
100+
public func update(map: [String: String]) async -> MergeIntoWriter {
101+
await mergeIntoWriter.update(condition, map, true)
102+
}
103+
104+
/// Specifies an action to delete non-matched rows from the target ``DataFrame``
105+
/// when not matched by the source.
106+
/// - Returns: A ``MergeIntoWriter`` instance.
107+
public func delete() async -> MergeIntoWriter {
108+
await mergeIntoWriter.delete(condition, true)
109+
}
110+
}
111+
112+
/// `MergeIntoWriter` provides methods to define and execute merge actions based on specified
113+
/// conditions.
114+
public actor MergeIntoWriter {
115+
var schemaEvolution: Bool = false
116+
117+
let table: String
118+
119+
let df: DataFrame
120+
121+
let condition: String
122+
123+
var mergeIntoTableCommand = MergeIntoTableCommand()
124+
125+
init(_ table: String, _ df: DataFrame, _ condition: String) {
126+
self.table = table
127+
self.df = df
128+
self.condition = condition
129+
130+
self.mergeIntoTableCommand.targetTableName = table
131+
self.mergeIntoTableCommand.mergeCondition.expressionString = condition.toExpressionString
132+
}
133+
134+
public var schemaEvolutionEnabled: Bool {
135+
schemaEvolution
136+
}
137+
138+
/// Enable automatic schema evolution for this merge operation.
139+
/// - Returns: ``MergeIntoWriter`` instance
140+
public func withSchemaEvolution() -> MergeIntoWriter {
141+
self.schemaEvolution = true
142+
return self
143+
}
144+
145+
/// Initialize a `WhenMatched` action without any condition.
146+
/// - Returns: A `WhenMatched` instance.
147+
public func whenMatched() -> WhenMatched {
148+
WhenMatched(self)
149+
}
150+
151+
/// Initialize a `WhenMatched` action with a condition.
152+
/// - Parameter condition: The condition to be evaluated for the action.
153+
/// - Returns: A `WhenMatched` instance configured with the specified condition.
154+
public func whenMatched(_ condition: String) -> WhenMatched {
155+
WhenMatched(self, condition)
156+
}
157+
158+
/// Initialize a `WhenNotMatched` action without any condition.
159+
/// - Returns: A `WhenNotMatched` instance.
160+
public func whenNotMatched() -> WhenNotMatched {
161+
WhenNotMatched(self)
162+
}
163+
164+
/// Initialize a `WhenNotMatched` action with a condition.
165+
/// - Parameter condition: The condition to be evaluated for the action.
166+
/// - Returns: A `WhenNotMatched` instance configured with the specified condition.
167+
public func whenNotMatched(_ condition: String) -> WhenNotMatched {
168+
WhenNotMatched(self, condition)
169+
}
170+
171+
/// Initialize a `WhenNotMatchedBySource` action without any condition.
172+
/// - Returns: A `WhenNotMatchedBySource` instance.
173+
public func whenNotMatchedBySource() -> WhenNotMatchedBySource {
174+
WhenNotMatchedBySource(self)
175+
}
176+
177+
/// Initialize a `WhenNotMatchedBySource` action with a condition
178+
/// - Parameter condition: The condition to be evaluated for the action.
179+
/// - Returns: A `WhenNotMatchedBySource` instance configured with the specified condition.
180+
public func whenNotMatchedBySource(_ condition: String) -> WhenNotMatchedBySource {
181+
WhenNotMatchedBySource(self, condition)
182+
}
183+
184+
/// Executes the merge operation.
185+
public func merge() async throws {
186+
if self.mergeIntoTableCommand.matchActions.count == 0
187+
&& self.mergeIntoTableCommand.notMatchedActions.count == 0
188+
&& self.mergeIntoTableCommand.notMatchedBySourceActions.count == 0
189+
{
190+
throw SparkConnectError.InvalidArgumentException
191+
}
192+
self.mergeIntoTableCommand.sourceTablePlan = await (self.df.getPlan() as! Plan).root
193+
self.mergeIntoTableCommand.withSchemaEvolution = self.schemaEvolution
194+
195+
var command = Spark_Connect_Command()
196+
command.mergeIntoTableCommand = self.mergeIntoTableCommand
197+
_ = try await df.spark.client.execute(df.spark.sessionID, command)
198+
}
199+
200+
public func insertAll(_ condition: String?) -> MergeIntoWriter {
201+
let expression = buildMergeAction(ActionType.insertStar, condition)
202+
self.mergeIntoTableCommand.notMatchedActions.append(expression)
203+
return self
204+
}
205+
206+
public func insert(_ condition: String?, _ map: [String: String]) -> MergeIntoWriter {
207+
let expression = buildMergeAction(ActionType.insert, condition, map)
208+
self.mergeIntoTableCommand.notMatchedActions.append(expression)
209+
return self
210+
}
211+
212+
public func updateAll(_ condition: String?, _ notMatchedBySource: Bool) -> MergeIntoWriter {
213+
appendUpdateDeleteAction(buildMergeAction(ActionType.updateStar, condition), notMatchedBySource)
214+
}
215+
216+
public func update(_ condition: String?, _ map: [String: String], _ notMatchedBySource: Bool)
217+
-> MergeIntoWriter
218+
{
219+
appendUpdateDeleteAction(buildMergeAction(ActionType.update, condition), notMatchedBySource)
220+
}
221+
222+
public func delete(_ condition: String?, _ notMatchedBySource: Bool) -> MergeIntoWriter {
223+
appendUpdateDeleteAction(buildMergeAction(ActionType.delete, condition), notMatchedBySource)
224+
}
225+
226+
private func appendUpdateDeleteAction(
227+
_ action: Spark_Connect_Expression,
228+
_ notMatchedBySource: Bool
229+
) -> MergeIntoWriter {
230+
if notMatchedBySource {
231+
self.mergeIntoTableCommand.notMatchedBySourceActions.append(action)
232+
} else {
233+
self.mergeIntoTableCommand.matchActions.append(action)
234+
}
235+
return self
236+
}
237+
238+
private func buildMergeAction(
239+
_ actionType: ActionType,
240+
_ condition: String?,
241+
_ assignments: [String: String] = [:]
242+
) -> Spark_Connect_Expression {
243+
var mergeAction = Spark_Connect_MergeAction()
244+
mergeAction.actionType = actionType
245+
if let condition {
246+
var expression = Spark_Connect_Expression()
247+
expression.expressionString = condition.toExpressionString
248+
mergeAction.condition = expression
249+
}
250+
mergeAction.assignments = assignments.map { key, value in
251+
var keyExpr = Spark_Connect_Expression()
252+
var valueExpr = Spark_Connect_Expression()
253+
254+
keyExpr.expressionString = key.toExpressionString
255+
valueExpr.expressionString = value.toExpressionString
256+
257+
var assignment = MergeAction.Assignment()
258+
assignment.key = keyExpr
259+
assignment.value = valueExpr
260+
return assignment
261+
}
262+
var expression = Spark_Connect_Expression()
263+
expression.mergeAction = mergeAction
264+
return expression
265+
}
266+
}

Sources/SparkConnect/TypeAliases.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// specific language governing permissions and limitations
1717
// under the License.
1818

19+
typealias ActionType = Spark_Connect_MergeAction.ActionType
1920
typealias Aggregate = Spark_Connect_Aggregate
2021
typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
2122
typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
@@ -38,6 +39,8 @@ typealias KeyValue = Spark_Connect_KeyValue
3839
typealias LateralJoin = Spark_Connect_LateralJoin
3940
typealias Limit = Spark_Connect_Limit
4041
typealias MapType = Spark_Connect_DataType.Map
42+
typealias MergeAction = Spark_Connect_MergeAction
43+
typealias MergeIntoTableCommand = Spark_Connect_MergeIntoTableCommand
4144
typealias NamedTable = Spark_Connect_Read.NamedTable
4245
typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze
4346
typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType

Tests/SparkConnectTests/IcebergTests.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,31 @@ struct IcebergTests {
8383

8484
try await spark.table(t1).writeTo(t2).overwrite("id = 1")
8585
#expect(try await spark.table(t2).count() == 3)
86-
})
8786

87+
try await spark.sql(
88+
"""
89+
MERGE INTO \(t2) t
90+
USING (SELECT *
91+
FROM VALUES
92+
(1, 'delete', null),
93+
(2, 'update', 'updated'),
94+
(4, null, 'new') AS T(id, op, data)) s
95+
ON t.id = s.id
96+
WHEN MATCHED AND s.op = 'delete' THEN DELETE
97+
WHEN MATCHED AND s.op = 'update' THEN UPDATE SET t.data = s.data
98+
WHEN NOT MATCHED THEN INSERT *
99+
WHEN NOT MATCHED BY SOURCE THEN UPDATE SET data = 'invalid'
100+
"""
101+
).count()
102+
#if !os(Linux)
103+
let expected = [
104+
Row(2, "updated"),
105+
Row(3, "invalid"),
106+
Row(4, "new"),
107+
]
108+
#expect(try await spark.table(t2).collect() == expected)
109+
#endif
110+
})
88111
await spark.stop()
89112
}
90113

0 commit comments

Comments
 (0)