Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions Sources/SparkConnect/ProtoUtils.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//
import Foundation

/// Utility functions like `org.apache.spark.sql.connect.common.ProtoUtils`.
public enum ProtoUtils {

private static let SPARK_JOB_TAGS_SEP = "," // SparkContext.SPARK_JOB_TAGS_SEP

/// Validate if a tag for ExecutePlanRequest.tags is valid. Throw IllegalArgumentException if not.
/// - Parameter tag: A tag string.
public static func throwIfInvalidTag(_ tag: String) throws {
// Same format rules apply to Spark Connect execution tags as to SparkContext job tags,
// because the Spark Connect job tag is also used as part of SparkContext job tag.
// See SparkContext.throwIfInvalidTag and ExecuteHolderSessionTag
if tag.isEmpty {
throw SparkConnectError.InvalidArgumentException
}
if tag.contains(SPARK_JOB_TAGS_SEP) {
throw SparkConnectError.InvalidArgumentException
}
}
}
29 changes: 29 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public actor SparkConnectClient {
let port: Int
let userContext: UserContext
var sessionID: String? = nil
var tags = Set<String>()

/// Create a client to use GRPCClient.
/// - Parameters:
Expand Down Expand Up @@ -229,6 +230,7 @@ public actor SparkConnectClient {
request.userContext = userContext
request.sessionID = self.sessionID!
request.operationID = UUID().uuidString
request.tags = Array(tags)
request.plan = plan
return request
}
Expand Down Expand Up @@ -409,4 +411,31 @@ public actor SparkConnectClient {
}
return result
}

/// Add a tag to be assigned to all the operations started by this thread in this session.
/// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
public func addTag(tag: String) throws {
try ProtoUtils.throwIfInvalidTag(tag)
tags.insert(tag)
}

/// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
/// Noop if such a tag was not added earlier.
/// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
public func removeTag(tag: String) throws {
try ProtoUtils.throwIfInvalidTag(tag)
tags.remove(tag)
}

/// Get the operation tags that are currently set to be assigned to all the operations started by
/// this thread in this session.
/// - Returns: A set of string.
public func getTags() -> Set<String> {
return tags
}

/// Clear the current thread's operation tags.
public func clearTags() {
tags.removeAll()
}
}
1 change: 1 addition & 0 deletions Sources/SparkConnect/SparkConnectError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/// A enum for ``SparkConnect`` package errors
public enum SparkConnectError: Error {
case UnsupportedOperationException
case InvalidArgumentException
case InvalidSessionIDException
case InvalidTypeException
}
25 changes: 25 additions & 0 deletions Sources/SparkConnect/SparkSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,31 @@ public actor SparkSession {
return ret
}

/// Add a tag to be assigned to all the operations started by this thread in this session.
/// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
public func addTag(_ tag: String) async throws {
try await client.addTag(tag: tag)
}

/// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
/// Noop if such a tag was not added earlier.
/// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
public func removeTag(_ tag: String) async throws {
try await client.removeTag(tag: tag)
}

/// Get the operation tags that are currently set to be assigned to all the operations started by
/// this thread in this session.
/// - Returns: A set of string.
public func getTags() async -> Set<String> {
return await client.getTags()
}

/// Clear the current thread's operation tags.
public func clearTags() async {
await client.clearTags()
}

/// This is defined as the return type of `SparkSession.sparkContext` method.
/// This is an empty `Struct` type because `sparkContext` method is designed to throw
/// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
Expand Down
17 changes: 17 additions & 0 deletions Tests/SparkConnectTests/SparkConnectClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,21 @@ struct SparkConnectClientTests {
let _ = try await client.connect(UUID().uuidString)
await client.stop()
}

@Test
func tags() async throws {
let client = SparkConnectClient(remote: "sc://localhost", user: "test")
let sessionID = UUID().uuidString
let _ = try await client.connect(sessionID)
let plan = await client.getPlanRange(0, 1, 1)

#expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty)
try await client.addTag(tag: "tag1")

#expect(await client.getExecutePlanRequest(sessionID, plan).tags == ["tag1"])
await client.clearTags()

#expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty)
await client.stop()
}
}
26 changes: 26 additions & 0 deletions Tests/SparkConnectTests/SparkSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,30 @@ struct SparkSessionTests {
#endif
await spark.stop()
}

@Test
func tag() async throws {
let spark = try await SparkSession.builder.getOrCreate()
try await spark.addTag("tag1")
#expect(await spark.getTags() == Set(["tag1"]))
try await spark.addTag("tag2")
#expect(await spark.getTags() == Set(["tag1", "tag2"]))
try await spark.removeTag("tag1")
#expect(await spark.getTags() == Set(["tag2"]))
await spark.clearTags()
#expect(await spark.getTags().isEmpty)
await spark.stop()
}

@Test
func invalidTags() async throws {
let spark = try await SparkSession.builder.getOrCreate()
await #expect(throws: SparkConnectError.InvalidArgumentException) {
try await spark.addTag("")
}
await #expect(throws: SparkConnectError.InvalidArgumentException) {
try await spark.addTag(",")
}
await spark.stop()
}
}
Loading