diff --git a/Sources/SparkConnect/ProtoUtils.swift b/Sources/SparkConnect/ProtoUtils.swift
new file mode 100644
index 0000000..738213f
--- /dev/null
+++ b/Sources/SparkConnect/ProtoUtils.swift
@@ -0,0 +1,39 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+import Foundation
+
+/// Utility functions like `org.apache.spark.sql.connect.common.ProtoUtils`.
+public enum ProtoUtils {
+
+  private static let SPARK_JOB_TAGS_SEP = ","  // SparkContext.SPARK_JOB_TAGS_SEP
+
+  /// Validate a tag for ExecutePlanRequest.tags. Throws `SparkConnectError.InvalidArgumentException` if it is empty or contains ','.
+  /// - Parameter tag: A tag string.
+  public static func throwIfInvalidTag(_ tag: String) throws {
+    // Same format rules apply to Spark Connect execution tags as to SparkContext job tags,
+    // because the Spark Connect job tag is also used as part of SparkContext job tag.
+    // See SparkContext.throwIfInvalidTag and ExecuteHolderSessionTag
+    if tag.isEmpty {
+      throw SparkConnectError.InvalidArgumentException
+    }
+    if tag.contains(SPARK_JOB_TAGS_SEP) {
+      throw SparkConnectError.InvalidArgumentException
+    }
+  }
+}
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 4e14077..6001ee8 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -30,6 +30,7 @@ public actor SparkConnectClient {
   let port: Int
   let userContext: UserContext
   var sessionID: String? = nil
+  var tags = Set<String>()
 
   /// Create a client to use GRPCClient.
   /// - Parameters:
@@ -229,6 +230,7 @@ public actor SparkConnectClient {
     request.userContext = userContext
     request.sessionID = self.sessionID!
     request.operationID = UUID().uuidString
+    request.tags = Array(tags)
     request.plan = plan
     return request
   }
@@ -409,4 +411,31 @@ public actor SparkConnectClient {
     }
     return result
   }
+
+  /// Add a tag to be assigned to all the operations started by this thread in this session.
+  /// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+  public func addTag(tag: String) throws {
+    try ProtoUtils.throwIfInvalidTag(tag)
+    tags.insert(tag)
+  }
+
+  /// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
+  /// Noop if such a tag was not added earlier.
+  /// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+  public func removeTag(tag: String) throws {
+    try ProtoUtils.throwIfInvalidTag(tag)
+    tags.remove(tag)
+  }
+
+  /// Get the operation tags that are currently set to be assigned to all the operations started by
+  /// this thread in this session.
+  /// - Returns: The set of tag strings currently registered.
+  public func getTags() -> Set<String> {
+    return tags
+  }
+
+  /// Clear the current thread's operation tags.
+  public func clearTags() {
+    tags.removeAll()
+  }
 }
diff --git a/Sources/SparkConnect/SparkConnectError.swift b/Sources/SparkConnect/SparkConnectError.swift
index df293e2..4434b6d 100644
--- a/Sources/SparkConnect/SparkConnectError.swift
+++ b/Sources/SparkConnect/SparkConnectError.swift
@@ -20,6 +20,7 @@
 /// A enum for ``SparkConnect`` package errors
 public enum SparkConnectError: Error {
   case UnsupportedOperationException
+  case InvalidArgumentException
   case InvalidSessionIDException
   case InvalidTypeException
 }
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index 3b07c27..1b943a4 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -158,6 +158,31 @@ public actor SparkSession {
     return ret
   }
 
+  /// Add a tag to be assigned to all the operations started by this thread in this session.
+  /// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+  public func addTag(_ tag: String) async throws {
+    try await client.addTag(tag: tag)
+  }
+
+  /// Remove a tag previously added to be assigned to all the operations started by this thread in this session.
+  /// Noop if such a tag was not added earlier.
+  /// - Parameter tag: The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+  public func removeTag(_ tag: String) async throws {
+    try await client.removeTag(tag: tag)
+  }
+
+  /// Get the operation tags that are currently set to be assigned to all the operations started by
+  /// this thread in this session.
+  /// - Returns: The set of tag strings currently registered.
+  public func getTags() async -> Set<String> {
+    return await client.getTags()
+  }
+
+  /// Clear the current thread's operation tags.
+  public func clearTags() async {
+    await client.clearTags()
+  }
+
   /// This is defined as the return type of `SparkSession.sparkContext` method.
   /// This is an empty `Struct` type because `sparkContext` method is designed to throw
   /// `UNSUPPORTED_CONNECT_FEATURE.SESSION_SPARK_CONTEXT`.
diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift
index f50ae5d..399e497 100644
--- a/Tests/SparkConnectTests/SparkConnectClientTests.swift
+++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift
@@ -46,4 +46,21 @@ struct SparkConnectClientTests {
     let _ = try await client.connect(UUID().uuidString)
     await client.stop()
   }
+
+  @Test
+  func tags() async throws {
+    let client = SparkConnectClient(remote: "sc://localhost", user: "test")
+    let sessionID = UUID().uuidString
+    let _ = try await client.connect(sessionID)
+    let plan = await client.getPlanRange(0, 1, 1)
+
+    #expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty)
+    try await client.addTag(tag: "tag1")
+
+    #expect(await client.getExecutePlanRequest(sessionID, plan).tags == ["tag1"])
+    await client.clearTags()
+
+    #expect(await client.getExecutePlanRequest(sessionID, plan).tags.isEmpty)
+    await client.stop()
+  }
 }
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index 24d0537..901a683 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -96,4 +96,30 @@ struct SparkSessionTests {
 #endif
     await spark.stop()
   }
+
+  @Test
+  func tag() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.addTag("tag1")
+    #expect(await spark.getTags() == Set(["tag1"]))
+    try await spark.addTag("tag2")
+    #expect(await spark.getTags() == Set(["tag1", "tag2"]))
+    try await spark.removeTag("tag1")
+    #expect(await spark.getTags() == Set(["tag2"]))
+    await spark.clearTags()
+    #expect(await spark.getTags().isEmpty)
+    await spark.stop()
+  }
+
+  @Test
+  func invalidTags() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    await #expect(throws: SparkConnectError.InvalidArgumentException) {
+      try await spark.addTag("")
+    }
+    await #expect(throws: SparkConnectError.InvalidArgumentException) {
+      try await spark.addTag(",")
+    }
+    await spark.stop()
+  }
 }