Skip to content

Commit d5856c6

Browse files
committed
[SPARK-51785] Support addTag/removeTag/getTags/clearTags in SparkSession
### What changes were proposed in this pull request? This PR aims to support the following `SparkSession` APIs. - `addTag` - `removeTag` - `getTags` - `clearTags` Note that `interrupt`-related operations will be supported later as an independent PR. ### Why are the changes needed? For feature parity. ### Does this PR introduce _any_ user-facing change? No. This is a new addition. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #54 from dongjoon-hyun/SPARK-51785. Authored-by: Dongjoon Hyun <[email protected]> Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 6139773 commit d5856c6

File tree

6 files changed

+137
-0
lines changed

6 files changed

+137
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//
2+
// Licensed to the Apache Software Foundation (ASF) under one
3+
// or more contributor license agreements. See the NOTICE file
4+
// distributed with this work for additional information
5+
// regarding copyright ownership. The ASF licenses this file
6+
// to you under the Apache License, Version 2.0 (the
7+
// "License"); you may not use this file except in compliance
8+
// with the License. You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing,
13+
// software distributed under the License is distributed on an
14+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
// KIND, either express or implied. See the License for the
16+
// specific language governing permissions and limitations
17+
// under the License.
18+
//
19+
import Foundation
20+
21+
/// Utility functions like `org.apache.spark.sql.connect.common.ProtoUtils`.
22+
public enum ProtoUtils {
23+
24+
private static let SPARK_JOB_TAGS_SEP = "," // SparkContext.SPARK_JOB_TAGS_SEP
25+
26+
/// Validate if a tag for ExecutePlanRequest.tags is valid. Throw IllegalArgumentException if not.
27+
/// - Parameter tag: A tag string.
28+
public static func throwIfInvalidTag(_ tag: String) throws {
29+
// Same format rules apply to Spark Connect execution tags as to SparkContext job tags,
30+
// because the Spark Connect job tag is also used as part of SparkContext job tag.
31+
// See SparkContext.throwIfInvalidTag and ExecuteHolderSessionTag
32+
if tag.isEmpty {
33+
throw SparkConnectError.InvalidArgumentException
34+
}
35+
if tag.contains(SPARK_JOB_TAGS_SEP) {
36+
throw SparkConnectError.InvalidArgumentException
37+
}
38+
}
39+
}

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public actor SparkConnectClient {
3030
let port: Int
3131
let userContext: UserContext
3232
var sessionID: String? = nil
33+
var tags = Set<String>()
3334

3435
/// Create a client to use GRPCClient.
3536
/// - Parameters:
@@ -229,6 +230,7 @@ public actor SparkConnectClient {
229230
request.userContext = userContext
230231
request.sessionID = self.sessionID!
231232
request.operationID = UUID().uuidString
233+
request.tags = Array(tags)
232234
request.plan = plan
233235
return request
234236
}
@@ -409,4 +411,31 @@ public actor SparkConnectClient {
409411
}
410412
return result
411413
}
414+
415+
/// Add a tag to be assigned to all the operations started by this thread in this session.
416+
/// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
417+
public func addTag(tag: String) throws {
418+
try ProtoUtils.throwIfInvalidTag(tag)
419+
tags.insert(tag)
420+
}
421+
422+
/// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
423+
/// Noop if such a tag was not added earlier.
424+
/// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
425+
public func removeTag(tag: String) throws {
426+
try ProtoUtils.throwIfInvalidTag(tag)
427+
tags.remove(tag)
428+
}
429+
430+
/// Get the operation tags that are currently set to be assigned to all the operations started by
431+
/// this thread in this session.
432+
/// - Returns: A set of string.
433+
public func getTags() -> Set<String> {
434+
return tags
435+
}
436+
437+
/// Clear the current thread's operation tags.
438+
public func clearTags() {
439+
tags.removeAll()
440+
}
412441
}

Sources/SparkConnect/SparkConnectError.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
/// A enum for ``SparkConnect`` package errors
2121
public enum SparkConnectError: Error {
2222
case UnsupportedOperationException
23+
case InvalidArgumentException
2324
case InvalidSessionIDException
2425
case InvalidTypeException
2526
}

Sources/SparkConnect/SparkSession.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,31 @@ public actor SparkSession {
158158
return ret
159159
}
160160

161+
/// Add a tag to be assigned to all the operations started by this thread in this session.
162+
/// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
163+
public func addTag(_ tag: String) async throws {
164+
try await client.addTag(tag: tag)
165+
}
166+
167+
/// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
168+
/// Noop if such a tag was not added earlier.
169+
/// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
170+
public func removeTag(_ tag: String) async throws {
171+
try await client.removeTag(tag: tag)
172+
}
173+
174+
/// Get the operation tags that are currently set to be assigned to all the operations started by
175+
/// this thread in this session.
176+
/// - Returns: A set of string.
177+
public func getTags() async -> Set<String> {
178+
return await client.getTags()
179+
}
180+
181+
/// Clear the current thread's operation tags.
182+
public func clearTags() async {
183+
await client.clearTags()
184+
}
185+
161186
/// This is defined as the return type of `SparkSession.sparkContext` method.
162187
/// This is an empty `Struct` type because `sparkContext` method is designed to throw
163188
/// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.

Tests/SparkConnectTests/SparkConnectClientTests.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,21 @@ struct SparkConnectClientTests {
4646
let _ = try await client.connect(UUID().uuidString)
4747
await client.stop()
4848
}
49+
50+
@Test
51+
func tags() async throws {
52+
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
53+
let sessionID = UUID().uuidString
54+
let _ = try await client.connect(sessionID)
55+
let plan = await client.getPlanRange(0, 1, 1)
56+
57+
#expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty)
58+
try await client.addTag(tag: "tag1")
59+
60+
#expect(await client.getExecutePlanRequest(sessionID, plan).tags == ["tag1"])
61+
await client.clearTags()
62+
63+
#expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty)
64+
await client.stop()
65+
}
4966
}

Tests/SparkConnectTests/SparkSessionTests.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,30 @@ struct SparkSessionTests {
9696
#endif
9797
await spark.stop()
9898
}
99+
100+
@Test
101+
func tag() async throws {
102+
let spark = try await SparkSession.builder.getOrCreate()
103+
try await spark.addTag("tag1")
104+
#expect(await spark.getTags() == Set(["tag1"]))
105+
try await spark.addTag("tag2")
106+
#expect(await spark.getTags() == Set(["tag1", "tag2"]))
107+
try await spark.removeTag("tag1")
108+
#expect(await spark.getTags() == Set(["tag2"]))
109+
await spark.clearTags()
110+
#expect(await spark.getTags().isEmpty)
111+
await spark.stop()
112+
}
113+
114+
@Test
115+
func invalidTags() async throws {
116+
let spark = try await SparkSession.builder.getOrCreate()
117+
await #expect(throws: SparkConnectError.InvalidArgumentException) {
118+
try await spark.addTag("")
119+
}
120+
await #expect(throws: SparkConnectError.InvalidArgumentException) {
121+
try await spark.addTag(",")
122+
}
123+
await spark.stop()
124+
}
99125
}

0 commit comments

Comments
 (0)