Skip to content

Commit cffaffd

Browse files
committed
[SPARK-52743] Support startRun
1 parent 0dca569 commit cffaffd

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,29 @@ public actor SparkConnectClient {
12141214
}
12151215
}
12161216

1217+
@discardableResult
1218+
func startRun(_ dataflowGraphID: String) async throws -> Bool {
1219+
try await withGPRC { client in
1220+
if UUID(uuidString: dataflowGraphID) == nil {
1221+
throw SparkConnectError.InvalidArgument
1222+
}
1223+
1224+
var startRun = Spark_Connect_PipelineCommand.StartRun()
1225+
startRun.dataflowGraphID = dataflowGraphID
1226+
1227+
var pipelineCommand = Spark_Connect_PipelineCommand()
1228+
pipelineCommand.commandType = .startRun(startRun)
1229+
1230+
var command = Spark_Connect_Command()
1231+
command.commandType = .pipelineCommand(pipelineCommand)
1232+
1233+
let responses = try await execute(self.sessionID!, command)
1234+
return responses.contains {
1235+
$0.responseType == .pipelineCommandResult(Spark_Connect_PipelineCommandResult())
1236+
}
1237+
}
1238+
}
1239+
12171240
private enum URIParams {
12181241
static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
12191242
static let PARAM_SESSION_ID = "session_id"

Tests/SparkConnectTests/SparkConnectClientTests.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,21 @@ struct SparkConnectClientTests {
107107
}
108108
await client.stop()
109109
}
110+
111+
@Test
112+
func startRun() async throws {
113+
let client = SparkConnectClient(remote: TEST_REMOTE)
114+
let response = try await client.connect(UUID().uuidString)
115+
116+
try await #require(throws: SparkConnectError.InvalidArgument) {
117+
try await client.startRun("not-a-uuid-format")
118+
}
119+
120+
if response.sparkVersion.version.starts(with: "4.1") {
121+
let dataflowGraphID = try await client.createDataflowGraph()
122+
#expect(UUID(uuidString: dataflowGraphID) != nil)
123+
#expect(try await client.startRun(dataflowGraphID))
124+
}
125+
await client.stop()
126+
}
110127
}

0 commit comments

Comments
 (0)