From cffaffd48754d98c756eacf605326d1d6aed45b1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 9 Jul 2025 20:18:29 -0700 Subject: [PATCH] [SPARK-52743] Support `startRun` --- Sources/SparkConnect/SparkConnectClient.swift | 23 +++++++++++++++++++ .../SparkConnectClientTests.swift | 17 ++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 023265c..e86e8ba 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -1214,6 +1214,29 @@ public actor SparkConnectClient { } } + @discardableResult + func startRun(_ dataflowGraphID: String) async throws -> Bool { + try await withGPRC { client in + if UUID(uuidString: dataflowGraphID) == nil { + throw SparkConnectError.InvalidArgument + } + + var startRun = Spark_Connect_PipelineCommand.StartRun() + startRun.dataflowGraphID = dataflowGraphID + + var pipelineCommand = Spark_Connect_PipelineCommand() + pipelineCommand.commandType = .startRun(startRun) + + var command = Spark_Connect_Command() + command.commandType = .pipelineCommand(pipelineCommand) + + let responses = try await execute(self.sessionID!, command) + return responses.contains { + $0.responseType == .pipelineCommandResult(Spark_Connect_PipelineCommandResult()) + } + } + } + private enum URIParams { static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size" static let PARAM_SESSION_ID = "session_id" diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index 58702b1..955a9c8 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -107,4 +107,21 @@ struct SparkConnectClientTests { } await client.stop() } + + @Test + func startRun() async throws { + let client = SparkConnectClient(remote: TEST_REMOTE) + let response = try await client.connect(UUID().uuidString) + + try await #require(throws: SparkConnectError.InvalidArgument) { + try await client.startRun("not-a-uuid-format") + } + + if response.sparkVersion.version.starts(with: "4.1") { + let dataflowGraphID = try await client.createDataflowGraph() + #expect(UUID(uuidString: dataflowGraphID) != nil) + #expect(try await client.startRun(dataflowGraphID)) + } + await client.stop() + } }