diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 8ea6250..7f033fd 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -1303,6 +1303,35 @@ public actor SparkConnectClient { } } + @discardableResult + func defineSqlGraphElements( + _ dataflowGraphID: String, + _ sqlFilePath: String, + _ sqlText: String + ) async throws -> Bool { + try await withGPRC { client in + if UUID(uuidString: dataflowGraphID) == nil { + throw SparkConnectError.InvalidArgument + } + + var elements = Spark_Connect_PipelineCommand.DefineSqlGraphElements() + elements.dataflowGraphID = dataflowGraphID + elements.sqlFilePath = sqlFilePath + elements.sqlText = sqlText + + var pipelineCommand = Spark_Connect_PipelineCommand() + pipelineCommand.commandType = .defineSqlGraphElements(elements) + + var command = Spark_Connect_Command() + command.commandType = .pipelineCommand(pipelineCommand) + + let responses = try await execute(self.sessionID!, command) + return responses.contains { + $0.responseType == .pipelineCommandResult(Spark_Connect_PipelineCommandResult()) + } + } + } + private enum URIParams { static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size" static let PARAM_SESSION_ID = "session_id" diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index 72e31ba..ccdad48 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -164,4 +164,22 @@ struct SparkConnectClientTests { } await client.stop() } + + @Test + func defineSqlGraphElements() async throws { + let client = SparkConnectClient(remote: TEST_REMOTE) + let response = try await client.connect(UUID().uuidString) + + try await #require(throws: SparkConnectError.InvalidArgument) { + try await client.defineSqlGraphElements("not-a-uuid-format", "path", "sql") + } + + if response.sparkVersion.version.starts(with: "4.1") { + let dataflowGraphID = try await client.createDataflowGraph() + let sqlText = "CREATE MATERIALIZED VIEW mv1 AS SELECT 1" + #expect(UUID(uuidString: dataflowGraphID) != nil) + #expect(try await client.defineSqlGraphElements(dataflowGraphID, "path", sqlText)) + } + await client.stop() + } }