Skip to content

Commit 19717f2

Browse files
committed
[SPARK-52756] Support defineFlow
1 parent 52e217f commit 19717f2

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

Sources/SparkConnect/SparkConnectClient.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,37 @@ public actor SparkConnectClient {
12721272
}
12731273
}
12741274

1275+
@discardableResult
1276+
func defineFlow(
1277+
_ dataflowGraphID: String,
1278+
_ flowName: String,
1279+
_ targetDatasetName: String,
1280+
_ relation: Relation
1281+
) async throws -> Bool {
1282+
try await withGPRC { client in
1283+
if UUID(uuidString: dataflowGraphID) == nil {
1284+
throw SparkConnectError.InvalidArgument
1285+
}
1286+
1287+
var defineFlow = Spark_Connect_PipelineCommand.DefineFlow()
1288+
defineFlow.dataflowGraphID = dataflowGraphID
1289+
defineFlow.flowName = flowName
1290+
defineFlow.targetDatasetName = targetDatasetName
1291+
defineFlow.plan = relation
1292+
1293+
var pipelineCommand = Spark_Connect_PipelineCommand()
1294+
pipelineCommand.commandType = .defineFlow(defineFlow)
1295+
1296+
var command = Spark_Connect_Command()
1297+
command.commandType = .pipelineCommand(pipelineCommand)
1298+
1299+
let responses = try await execute(self.sessionID!, command)
1300+
return responses.contains {
1301+
$0.responseType == .pipelineCommandResult(Spark_Connect_PipelineCommandResult())
1302+
}
1303+
}
1304+
}
1305+
12751306
private enum URIParams {
12761307
static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
12771308
static let PARAM_SESSION_ID = "session_id"

Tests/SparkConnectTests/SparkConnectClientTests.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,22 @@ struct SparkConnectClientTests {
146146
}
147147
await client.stop()
148148
}
149+
150+
@Test
151+
func defineFlow() async throws {
152+
let client = SparkConnectClient(remote: TEST_REMOTE)
153+
let response = try await client.connect(UUID().uuidString)
154+
155+
try await #require(throws: SparkConnectError.InvalidArgument) {
156+
try await client.defineFlow("not-a-uuid-format", "f1", "ds1", Relation())
157+
}
158+
159+
if response.sparkVersion.version.starts(with: "4.1") {
160+
let dataflowGraphID = try await client.createDataflowGraph()
161+
#expect(UUID(uuidString: dataflowGraphID) != nil)
162+
let relation = await client.getLocalRelation().root
163+
#expect(try await client.defineFlow(dataflowGraphID, "f1", "ds1", relation))
164+
}
165+
await client.stop()
166+
}
149167
}

0 commit comments

Comments
 (0)