diff --git a/Examples/converse-stream/.gitignore b/Examples/converse-stream/.gitignore new file mode 100644 index 00000000..0023a534 --- /dev/null +++ b/Examples/converse-stream/.gitignore @@ -0,0 +1,8 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc diff --git a/Examples/converse-stream/Package.swift b/Examples/converse-stream/Package.swift new file mode 100644 index 00000000..517e3b45 --- /dev/null +++ b/Examples/converse-stream/Package.swift @@ -0,0 +1,32 @@ +// swift-tools-version: 6.1 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "ConverseStream", + platforms: [.macOS(.v15), .iOS(.v18), .tvOS(.v18)], + products: [ + .executable(name: "ConverseStream", targets: ["ConverseStream"]) + ], + dependencies: [ + // for production use, uncomment the following line + // .package(url: "https://github.com/build-on-aws/swift-bedrock-library.git", branch: "main"), + + // for local development, use the following line + .package(name: "swift-bedrock-library", path: "../.."), + + .package(url: "https://github.com/apple/swift-log.git", from: "1.5.0"), + ], + targets: [ + // Targets are the basic building blocks of a package, defining a module or a test suite. + // Targets can depend on other targets in this package and products from dependencies. + .executableTarget( + name: "ConverseStream", + dependencies: [ + .product(name: "BedrockService", package: "swift-bedrock-library"), + .product(name: "Logging", package: "swift-log"), + ] + ) + ] +) diff --git a/Examples/converse-stream/Sources/ConverseStream.swift b/Examples/converse-stream/Sources/ConverseStream.swift new file mode 100644 index 00000000..7880901c --- /dev/null +++ b/Examples/converse-stream/Sources/ConverseStream.swift @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Bedrock Library open source project +// +// Copyright (c) 2025 Amazon.com, Inc. or its affiliates +// and the Swift Bedrock Library project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import BedrockService +import Logging + +@main +struct Main { + static func main() async throws { + do { + try await Main.converseStream() + } catch { + print("Error:\n\(error)") + } + } + static func converseStream() async throws { + var logger = Logger(label: "Converse") + logger.logLevel = .debug + + let bedrock = try await BedrockService( + region: .useast1, + logger: logger + // uncomment if you use SSO with AWS Identity Center + // authentication: .sso + ) + + // select a model that supports the converse modality + // models must be enabled in your AWS account + let model: BedrockModel = .nova_lite + + guard model.hasConverseModality() else { + throw MyError.incorrectModality("\(model.name) does not support converse") + } + + // create a request + let builder = try ConverseRequestBuilder(with: model) + .withPrompt("Tell me about rainbows") + + // send the request + let reply = try await bedrock.converseStream(with: builder) + + // the reply gives access to two streams. + // 1. `stream` is a high-level stream that provides elements of the conversation : + // - messageStart: this is the start of a message, it contains the role (assistant or user) + // - text: this is a delta of the text content, it contains the partial text + // - reasoning: this is a delta of the reasoning content, it contains the partial reasoning text + // - toolUse: this is a buffered tool use response, it contains the tool name and id, and the input parameters + // - message complete: this includes the complete Message, ready to be added to the history and used for future requests + // - metaData: this is the metadata about the response, it contains statitics about the response, such as the number of tokens used and the latency + // + // 2. `sdkStream` is the low-level stream provided by the AWS SDK. Use it when you need low level access to the stream, + // such as when you want to handle the stream in a custom way or when you need to access the raw data. + // see : https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html#conversation-inference-call-response-converse-stream + for try await element in reply.stream { + // process the stream elements + switch element { + case .messageStart(let role): + logger.info("Message started with role: \(role)") + case .text(_, let text): + print(text, terminator: "") + case .reasoning(let index, let reasoning): + logger.info("Reasoning delta: \(reasoning)", metadata: ["index": "\(index)"]) + case .toolUse(let index, let toolUse): + logger.info( + "Tool use: \(toolUse.name) with id: \(toolUse.id) and input: \(toolUse.input)", + metadata: ["index": "\(index)"] + ) + case .messageComplete(_): + print("\n") + case .metaData(let metaData): + logger.info("Metadata: \(metaData)") + } + } + } + + enum MyError: Error { + case incorrectModality(String) + } +} diff --git a/Examples/converse/Package.swift b/Examples/converse/Package.swift index d1226877..a09889a7 100644 --- a/Examples/converse/Package.swift +++ b/Examples/converse/Package.swift @@ -10,7 +10,12 @@ let package = Package( .executable(name: "Converse", targets: ["Converse"]) ], dependencies: [ - .package(url: "https://github.com/build-on-aws/swift-bedrock-library.git", branch: "main"), + // for production use, uncomment the following line + // .package(url: "https://github.com/build-on-aws/swift-bedrock-library.git", branch: "main"), + + // for local development, use the following line + .package(name: "swift-bedrock-library", path: "../.."), + .package(url: "https://github.com/apple/swift-log.git", from: "1.5.0"), ], targets: [ diff --git a/Examples/converse/Sources/Converse.swift b/Examples/converse/Sources/Converse.swift index ebd1a460..519bab37 100644 --- a/Examples/converse/Sources/Converse.swift +++ b/Examples/converse/Sources/Converse.swift @@ -33,7 +33,7 @@ struct Main { region: .useast1, logger: logger // uncomment if you use SSO with AWS Identity Center - // authentication: .sso + // authentication: .sso ) // select a model that supports the converse modality diff --git a/README.md b/README.md index 5265afc9..3d320b94 100644 --- a/README.md +++ b/README.md @@ -125,38 +125,45 @@ The stream will contain `ConverseStreamElement` object that can either be `conte To create the next builder, with the same model and inference parameters, use the full message from the `.messageComplete`. ```swift -let model: BedrockModel = .nova_lite + let model: BedrockModel = .nova_lite -guard model.hasConverseModality() else { - throw MyError.incorrectModality("\(model.name) does not support converse") -} -guard model.hasConverseModality(.reasoning) else { - throw MyError.incorrectModality("\(model.name) does not support reasoning") -} + guard model.hasConverseModality() else { + throw MyError.incorrectModality("\(model.name) does not support converse") + } -var builder = try ConverseRequestBuilder(from: builder, with: reply) - .withPrompt("Tell me more about the birds in Paris") + // create a request + let builder = try ConverseRequestBuilder(with: model) + .withPrompt("Tell me about rainbows") -let stream = try await bedrock.converseStream(with: builder) + // send the request + let reply = try await bedrock.converseStream(with: builder) -for try await element in stream { - switch element { - case .contentSegment(let contentSegment): - switch contentSegment { - case .text(_, let text): - print(text, terminator: "") - default: - break - } - case .contentBlockComplete: - print("\n\n") - case .messageComplete(let message): - assistantMessage = message - } -} + // consume the stream of elements + for try await element in reply.stream { -builder = try ConverseRequestBuilder(from: builder, with: assistantMessage) - .withPrompt("And what about the rats?") + switch element { + case .messageStart(let role): + logger.info("Message started with role: \(role)") + + case .text(_, let text): + print(text, terminator: "") + + case .reasoning(let index, let reasoning): + logger.info("Reasoning delta: \(reasoning)", metadata: ["index": "\(index)"]) + + case .toolUse(let index, let toolUse): + logger.info( + "Tool use: \(toolUse.name) with id: \(toolUse.id) and input: \(toolUse.input)", + metadata: ["index": "\(index)"] + ) + + case .messageComplete(_): + print("\n") + + case .metaData(let metaData): + logger.info("Metadata: \(metaData)") + } + } ``` ### Vision diff --git a/Sources/Converse/BedrockService+ConverseStreaming.swift b/Sources/Converse/BedrockService+ConverseStreaming.swift index 5b33c860..eb83c8b0 100644 --- a/Sources/Converse/BedrockService+ConverseStreaming.swift +++ b/Sources/Converse/BedrockService+ConverseStreaming.swift @@ -34,7 +34,8 @@ extension BedrockService { /// BedrockLibraryError.invalidPrompt if the prompt is empty or too long /// BedrockLibraryError.invalidModality for invalid modality from the selected model /// BedrockLibraryError.invalidSDKResponse if the response body is missing - /// - Returns: A stream of ConverseResponseStreaming objects + /// - Returns: A ConverseReplyStream object that gives access to the high-level stream of ConverseStreamElements objects + /// or the low-level stream provided by the AWS SDK. public func converseStream( with model: BedrockModel, conversation: [Message], @@ -46,7 +47,7 @@ extension BedrockService { tools: [Tool]? = nil, enableReasoning: Bool? = false, maxReasoningTokens: Int? = nil - ) async throws -> AsyncThrowingStream { + ) async throws -> ConverseReplyStream { do { guard model.hasConverseStreamingModality() else { throw BedrockLibraryError.invalidModality( @@ -118,18 +119,18 @@ extension BedrockService { // - message metadata // see https://github.com/awslabs/aws-sdk-swift/blob/2697fb44f607b9c43ad0ce5ca79867d8d6c545c2/Sources/Services/AWSBedrockRuntime/Sources/AWSBedrockRuntime/Models.swift#L3478 // it will be the responsibility of the user to handle the stream and re-assemble the messages and content - // TODO: should we expose the SDK ConverseStreamOutput from the SDK ? or wrap it (what's the added value) ? - let reply = ConverseReplyStream(sdkStream) + let reply = try ConverseReplyStream(sdkStream) // this time, a different stream is created from the previous one, this one has the following elements - // - content segment: this contains a ContentSegment, an enum which can be a .text(Int, String), - // the integer is the id for the content block that the content segment is a part of, - // the String is the part of text that is send from the model. - // - content block complete: this includes the id of the completed content block and the complete content block itself + // - messageStart: this is the start of a message, it contains the role (assistant or user) + // - text: this is a delta of the text content, it contains the partial text + // - reasoning: this is a delta of the reasoning content, it contains the partial reasoning text + // - toolUse: this is a buffered tool use response, it contains the tool name and id, and the input parameters // - message complete: this includes the complete Message, ready to be added to the history and used for future requests + // - metaData: this is the metadata about the response, it contains statitics about the response, such as the number of tokens used and the latency - return reply.stream + return reply } catch { try handleCommonError(error, context: "invoking converse stream") @@ -143,7 +144,7 @@ extension BedrockService { /// - Returns: A stream of ConverseResponseStreaming objects public func converseStream( with builder: ConverseRequestBuilder - ) async throws -> AsyncThrowingStream { + ) async throws -> ConverseReplyStream { logger.trace("Conversing and streaming") do { var history = builder.history diff --git a/Sources/Converse/Message.swift b/Sources/Converse/Message.swift index e056a633..1d4aef8e 100644 --- a/Sources/Converse/Message.swift +++ b/Sources/Converse/Message.swift @@ -17,14 +17,27 @@ import Foundation public struct Message: Codable, CustomStringConvertible, Sendable { + + // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_MessageStopEvent.html + // end_turn | tool_use | max_tokens | stop_sequence | guardrail_intervened | content_filtered + public enum StopReason: Codable, Sendable { + case endTurn + case toolUse + case maxTokens + case stopSequence + case guardrailIntervened + case contentFiltered + } public let role: Role public let content: [Content] + public let stopReason: StopReason? // MARK - initializers - public init(from role: Role, content: [Content]) { + public init(from role: Role, content: [Content], stopReason: StopReason? = nil) { self.role = role self.content = content + self.stopReason = stopReason } /// convenience initializer for message with only a user prompt @@ -111,6 +124,15 @@ public struct Message: Codable, CustomStringConvertible, Sendable { public func hasTextContent() -> Bool { content.contains { $0.isText() } } + public func textContent() -> String? { + let content = content.first(where: { $0.isText() }) + if case .text(let text) = content { + return text + } else { + return nil + } + } + public func hasImageContent() -> Bool { content.contains { $0.isImage() } } @@ -137,4 +159,24 @@ public struct Message: Codable, CustomStringConvertible, Sendable { role: role.getSDKConversationRole() ) } + + public static func stopReason(fromSDK sdkStopReason: BedrockRuntimeClientTypes.StopReason?) -> StopReason? { + switch sdkStopReason { + case .endTurn: + return .endTurn + case .toolUse: + return .toolUse + case .maxTokens: + return .maxTokens + case .stopSequence: + return .stopSequence + case .guardrailIntervened: + return .guardrailIntervened + case .contentFiltered: + return .contentFiltered + default: + return nil + } + } + } diff --git a/Sources/Converse/Role.swift b/Sources/Converse/Role.swift index d1601607..dd5bc8f1 100644 --- a/Sources/Converse/Role.swift +++ b/Sources/Converse/Role.swift @@ -16,7 +16,7 @@ @preconcurrency import AWSBedrockRuntime import Foundation -public struct Role: Codable, Sendable, Equatable { +public struct Role: Codable, Sendable, Equatable, CustomStringConvertible { private enum RoleType: Codable, Sendable, Equatable { case user case assistant @@ -71,9 +71,19 @@ public struct Role: Codable, Sendable, Equatable { } } /// Returns the type of the role as a string. + public var description: String { + switch self.type { + case .user: return "user" + case .assistant: return "assistant" + } + } + + // Equatable public static func == (lhs: Role, rhs: Role) -> Bool { lhs.type == rhs.type } + + // convenience static properties for common roles private init(_ type: RoleType) { self.type = type } diff --git a/Sources/Converse/Streaming/Content+getFromSegements.swift b/Sources/Converse/Streaming/Content+getFromSegements.swift deleted file mode 100644 index c752b984..00000000 --- a/Sources/Converse/Streaming/Content+getFromSegements.swift +++ /dev/null @@ -1,94 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Bedrock Library open source project -// -// Copyright (c) 2025 Amazon.com, Inc. or its affiliates -// and the Swift Bedrock Library project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Foundation - -extension Content { - static func getFromSegements(with index: Int, from segments: [ContentSegment]) throws -> Content { - var text = "" - var toolUseName = "" - var toolUseId = "" - var toolUseInput = "" - var reasoningText = "" - var reasoningSignature = "" - var encryptedReasoning: Data? = nil - - for segment in segments { - if segment.index == index { - switch segment { - - case .text(_, let textPart): - text += textPart - - case .reasoning(_, let textPart, let signaturePart): - guard text == "" else { - throw BedrockLibraryError.streamingError( - "A reasoning segment was found in a contentBlock that already contained text segments" - ) - } - reasoningText += textPart - reasoningSignature += signaturePart - - case .toolUse(_, let toolUsePart): - guard text == "" else { - throw BedrockLibraryError.streamingError( - "A toolUse segment was found in a contentBlock that already contained text segments" - ) - } - if toolUseName == "" { - toolUseName = toolUsePart.name - } else if toolUseName != toolUsePart.name { - throw BedrockLibraryError.streamingError( - "A toolUse segment was found in a contentBlock that contained multiple tools with different toolUseName" - ) - } - if toolUseId == "" { - toolUseId = toolUsePart.toolUseId - } else if toolUseId != toolUsePart.toolUseId { - throw BedrockLibraryError.streamingError( - "A toolUse segment was found in a contentBlock that contained multiple tools with different toolUseId" - ) - } - toolUseInput += toolUsePart.inputPart - - case .encryptedReasoning(_, let data): - guard text == "" else { - throw BedrockLibraryError.streamingError( - "An encrypted reasoning segment was found in a contentBlock that already contained text segments" - ) - } - guard reasoningText == "", reasoningSignature == "" else { - throw BedrockLibraryError.streamingError( - "An encrypted reasoning segment was found in a contentBlock that already contained reasoning segments" - ) - } - encryptedReasoning = data - break - } - } - } - if text != "" { - return .text(text) - } else if reasoningText != "" { - return .reasoning(Reasoning(reasoningText, signature: reasoningSignature)) - } else if toolUseInput != "", toolUseName != "", toolUseId != "" { - return .toolUse(ToolUseBlock(id: toolUseId, name: toolUseName, input: try JSON(from: toolUseInput))) - } else if let encryptedReasoning { - return .encryptedReasoning(EncryptedReasoning(encryptedReasoning)) - } else { - throw BedrockLibraryError.streamingError("No content found in ContentSegments to create Content") - } - } -} diff --git a/Sources/Converse/Streaming/ContentSegment.swift b/Sources/Converse/Streaming/ContentSegment.swift deleted file mode 100644 index 94796d66..00000000 --- a/Sources/Converse/Streaming/ContentSegment.swift +++ /dev/null @@ -1,142 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Bedrock Library open source project -// -// Copyright (c) 2025 Amazon.com, Inc. or its affiliates -// and the Swift Bedrock Library project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -@preconcurrency import AWSBedrockRuntime -import Foundation - -public enum ContentSegment: Sendable { - case text(Int, String) - case reasoning(Int, String, String) // index, text, signature - case encryptedReasoning(Int, Data) - case toolUse(Int, ToolUsePart) - - public var index: Int { - switch self { - case .text(let index, _): - return index - case .toolUse(let index, _): - return index - case .reasoning(let index, _, _): - return index - case .encryptedReasoning(let index, _): - return index - } - } - - public var reasoningText: String? { - switch self { - case .reasoning(_, let text, _): - return text - default: - return nil - } - } - - public var reasoningSignature: String? { - switch self { - case .reasoning(_, _, let signature): - return signature - default: - return nil - } - } - - // MARK - Init - - package init( - index: Int, - sdkContentBlockDelta: BedrockRuntimeClientTypes.ContentBlockDelta, - toolUseStarts: [ToolUseStart] - ) throws { - switch sdkContentBlockDelta { - case .text(let text): - self = .text(index, text) - case .tooluse(let toolUseBlockDelta): - guard let input = toolUseBlockDelta.input else { - throw BedrockLibraryError.invalidSDKType("No input found in ToolUseBlockDelta") - } - guard let toolUseStart = toolUseStarts.first(where: { $0.index == index }) - else { - throw BedrockLibraryError.streamingError( - "No ToolUse can be constructed, because no matching name and toolUseId from ContentBlockStart for ToolUseBlockDelta were found " - ) - } - self = .toolUse( - index, - ToolUsePart( - index: index, - name: toolUseStart.name, - toolUseId: toolUseStart.toolUseId, - inputPart: input - ) - ) - case .reasoningcontent(let sdkReasoningBlock): - switch sdkReasoningBlock { - case .text(let reasoningText): - self = .reasoning(index, reasoningText, "") - case .signature(let reasoningSignature): - self = .reasoning(index, "", reasoningSignature) - case .redactedcontent(let data): - self = .encryptedReasoning(index, data) - default: - throw BedrockLibraryError.notImplemented( - "ReasoningBlockContent \(sdkReasoningBlock) is not implemented by BedrockService or not implemented by BedrockRuntimeClientTypes in case of `sdkUnknown`" - ) - } - default: - throw BedrockLibraryError.notImplemented( - "ContentBlockDelta \(sdkContentBlockDelta) is not implemented by BedrockService or not implemented by BedrockRuntimeClientTypes in case of `sdkUnknown`" - ) - } - } - - // MARK - convenience - - public func hasToolUse() -> Bool { - switch self { - case .toolUse: - return true - default: - return false - } - } - - public func hasText() -> Bool { - switch self { - case .text: - return true - default: - return false - } - } - - public func hasReasoning() -> Bool { - switch self { - case .reasoning: - return true - default: - return false - } - } - - public func hasEncryptedReasoning() -> Bool { - switch self { - case .encryptedReasoning: - return true - default: - return false - } - } -} diff --git a/Sources/Converse/Streaming/ConverseReplyStream.swift b/Sources/Converse/Streaming/ConverseReplyStream.swift index 71e8eae7..cf7025c8 100644 --- a/Sources/Converse/Streaming/ConverseReplyStream.swift +++ b/Sources/Converse/Streaming/ConverseReplyStream.swift @@ -13,87 +13,201 @@ // //===----------------------------------------------------------------------===// @preconcurrency import AWSBedrockRuntime +import Foundation +import Logging -// To consider: do we want the developer to use ConverseReplyStream or do we simply use it to return the stream? -// This will determine the visibility -package struct ConverseReplyStream { - package var stream: AsyncThrowingStream +public struct ConverseReplyStream: Sendable { - package init(_ inputStream: AsyncThrowingStream) { + private let logger: Logger - self.stream = AsyncThrowingStream(ConverseStreamElement.self) { continuation in + // This is the stream that the user will consume + public let stream: AsyncThrowingStream + + // This is the stream that the SDK provides, which we will convert to our own stream + // we expose it as a public property to allow demanding developers to access the raw SDK stream if needed + public let sdkStream: AsyncThrowingStream + + package init( + _ inputStream: AsyncThrowingStream, + logger: Logger? = nil + ) throws { + + self.logger = logger ?? .init(label: "ConverseReplyStream") + + // store the sdk-provided stream to expose it to developers if needed + self.sdkStream = inputStream + + // build a new stream that will convert the SDK stream output to our own ConverseStreamElement + self.stream = try ConverseReplyStream.convertToLibraryStream(inputStream, logger: self.logger) + } + + /// Convert the SDK Stream to a highler level stream of ConverseStreamElement + private static func convertToLibraryStream( + _ inputStream: AsyncThrowingStream, + logger: Logger + ) throws -> AsyncThrowingStream { + + AsyncThrowingStream(ConverseStreamElement.self) { continuation in let t = Task { - var indexes: [Int] = [] - var contentParts: [ContentSegment] = [] - var content: [Content] = [] - var toolUseStarts: [ToolUseStart] = [] do { + var state: StreamState! + + // Convert the SDK stream output to our own stream elements for try await output in inputStream { + switch output { - // case .messagestart(_): - // continuation.yield(.messageStart) + case .messagestart(let event): + logger.trace("Message Start", metadata: ["event": "\(event)"]) + + guard let sdkRole = event.role, + let role = try? Role(from: sdkRole) + else { + throw BedrockLibraryError.invalidSDKType("Role is missing in message start event") + } + + state = StreamState(with: role) + continuation.yield(.messageStart(role)) + + // only received at the start of a tool use block + // https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html#conversation-inference-call-response case .contentblockstart(let event): - guard let index = event.contentBlockIndex else { + logger.trace("Content Block Start") + guard state.currentBlockId == -1 else { + // If we already have a block started, this is an error throw BedrockLibraryError.invalidSDKType( - "No contentBlockIndex found in ContentBlockStart" + "ContentBlockStart received while another block is active" ) } - indexes.append(index) - if let start: BedrockRuntimeClientTypes.ContentBlockStart = event.start { - if case .tooluse(let toolUseBlockStart) = start { - toolUseStarts.append( - try ToolUseStart(index: index, sdkToolUseStart: toolUseBlockStart) - ) - } + guard let blockId = event.contentBlockIndex else { + throw BedrockLibraryError.invalidSDKType( + "Block ID is missing in content block start event" + ) } + state.currentBlockId = blockId + state.toolUseStart = try ToolUseStart(index: blockId, sdkEventBlockStart: event.start) + // do not yield an event here, wait for full ToolUse block to arrive + case .contentblockdelta(let event): - guard let index = event.contentBlockIndex else { + logger.trace("Content Block Delta") + guard let blockId = event.contentBlockIndex else { + // when there is no blockId, this is an error throw BedrockLibraryError.invalidSDKType( - "No contentBlockIndex found in ContentBlockDelta" + "Block ID is missing in content block delta event" ) } - if !indexes.contains(index) { - //some models do not send ContentBlockStart before ContentBlockDelta - indexes.append(index) - // continuation.yield(.messageStart) + guard state.currentBlockId == -1 || state.currentBlockId == blockId else { + // when the blockId doesn't match the current block, this is an error + throw BedrockLibraryError.invalidSDKType( + "Block ID mismatch in content block delta event" + ) } - guard let delta = event.delta else { - throw BedrockLibraryError.invalidSDKType("No delta found in ContentBlockDelta") + // for text and reasoning delta, we receive the block id at the first delta event + state.currentBlockId = blockId + + switch event.delta { + case .text(let text): + state.bufferText += text + continuation.yield(.text(blockId, text)) + case .tooluse(let toolUseDelta): + state.bufferToolUse += toolUseDelta.input ?? "" + // do not yield events for tooluse, wait for the full JSON to arrive + case .reasoningcontent(let reasoningDelta): + switch reasoningDelta { + case .text(let text): + state.bufferReasoning += text + continuation.yield(.reasoning(blockId, text)) + case .signature(let signature): + state.bufferReasoningSignature += signature + // do not yield partial signature, wait for full JSON data + case .redactedcontent(let redactedContent): + state.bufferReasoningData.append(redactedContent) + // do not yield partial reasoning data, wait for full JSON data + case .sdkUnknown(let output): + logger.warning( + "Received unknown SDK Reasoning Delta", + metadata: ["reasoning delta": "\(output)"] + ) + } + case .sdkUnknown(let output): + logger.warning( + "Received unknown SDK Event Delta", + metadata: ["delta": "\(output)"] + ) + case .none: + logger.warning("Received none SDK Event Delta") } - let segment = try ContentSegment( - index: index, - sdkContentBlockDelta: delta, - toolUseStarts: toolUseStarts - ) - contentParts.append(segment) - continuation.yield(.contentSegment(segment)) case .contentblockstop(let event): - guard let completedIndex = event.contentBlockIndex else { + logger.trace("Content Block Stop") + guard state.currentBlockId != -1 else { + // If we don't have a block started, this is an error + throw BedrockLibraryError.invalidSDKType( + "ContentBlockStop received while no block is active" + ) + } + guard let blockId = event.contentBlockIndex, + blockId == state.currentBlockId + else { + // If we don't have a block started, this is an error throw BedrockLibraryError.invalidSDKType( - "No contentBlockIndex found in ContentBlockStop" + "ContentBlockStop received while no block is active or block ID mismatch" ) } - guard indexes.contains(completedIndex) else { - throw BedrockLibraryError.streamingError( - "No matching index from ContentBlockStart or ContentBlockDelta found for index from ContentBlockStop" + + // reassemble buffered data and emit top-level event + try ConverseReplyStream.flushContent(state: &state, continuation: continuation) + guard let lastContentBlock = state.lastContentBlock else { + fatalError( + String( + "ContentBlockStop received but no content block was buffered for block ID \(blockId)" + ) ) } - let contentBlock = try Content.getFromSegements(with: completedIndex, from: contentParts) - content.append(contentBlock) - continuation.yield(.contentBlockComplete(completedIndex, contentBlock)) + // just yield ToolUse, the partial text and reasoning are already yielded + if case .toolUse(let toolUse) = lastContentBlock.1 { + continuation.yield(.toolUse(blockId, toolUse)) + } + // buffer this content block + state.contentBlocks[blockId] = lastContentBlock.1 - case .messagestop(_): - let message = Message(from: .assistant, content: content) + // reset the current block ID + state.currentBlockId = -1 + + case .messagestop(let event): + logger.trace("Message Stop") + state.messageComplete = true + + // create a Message with all content blocks + let message = Message( + from: state.role, + content: state.contentBlocks.sorted { $0.key < $1.key }.map { $0.value }, + stopReason: Message.stopReason(fromSDK: event.stopReason) + ) continuation.yield(.messageComplete(message)) - continuation.finish() - default: - break - } - } + case .metadata(let event): + logger.trace("Metadata", metadata: ["event": "\(event)"]) + + // Convert the metadata event to our ResponseMetadata type + let metadata = try ResponseMetadata(from: event) + continuation.yield(.metaData(metadata)) + + case .sdkUnknown(let output): + // Handle unknown SDK output + // This is a catch-all for any future SDK output types that we don't handle yet + // We log it and continue, but we could also throw an error if desired + logger.warning( + "Received unknown SDK ConverseStreamOutput", + metadata: ["output": "\(output)"] + ) + } // switch + + } // for try await + + continuation.finish() + // when we reach here, the stream is finished or the Task is cancelled - // when cancelled, it will throw CancellationError + // when cancelled, it should throw CancellationError // not really necessary as this seems to be handled by the Stream anyway. try Task.checkCancellation() @@ -111,4 +225,104 @@ package struct ConverseReplyStream { } } } + + /// Flushes and processes the buffered content from the stream state + /// + /// This function processes any buffered content in the stream state and creates the appropriate Content type. + /// It performs validation to ensure only one type of content buffer is non-empty at a time. + /// + /// The method is static to avoid callers to capture self, which is not allowed in async contexts. + /// + /// - Parameters: + /// - state: The current stream state containing buffered content + /// - continuation: The stream continuation for emitting events + /// + /// - Returns: A tuple containing the block ID and processed Content, or nil if no content to process + /// + /// - Throws: BedrockLibraryError.invalidSDKType if validation fails or buffers are in an invalid state + private static func flushContent( + state: inout StreamState, + continuation: AsyncThrowingStream.Continuation + ) throws { + guard + isToolUseBufferValid(state) || isReasoningDataBufferValid(state) || isEmptyBufferValid(state) + || isReasoningBufferValid(state) || isTextBufferValid(state) + else { + throw BedrockLibraryError.invalidSDKType("ContentBlockStop received while multiple buffers are not empty") + } + + if !state.bufferText.isEmpty { + state.lastContentBlock = (state.currentBlockId, Content.text(state.bufferText)) + state.bufferText = "" + } + if !state.bufferReasoning.isEmpty { + let signature = state.bufferReasoningSignature == "" ? nil : state.bufferReasoningSignature + state.lastContentBlock = ( + state.currentBlockId, .reasoning(.init(state.bufferReasoning, signature: signature)) + ) + state.bufferReasoning = "" + } + // TODO: encrypted reasoning is not supported at the moment + // if !bufferReasoningData.isEmpty { + // contentBlock[currentBlockId] = .reasoning(bufferReasoningData) + // bufferReasoningData = Data() + // } + if !state.bufferToolUse.isEmpty { + guard let toolUseStart = state.toolUseStart else { + throw BedrockLibraryError.invalidSDKType("Received a tool use delta without tool use start") + } + let json = try JSON(from: state.bufferToolUse) + state.lastContentBlock = ( + state.currentBlockId, .toolUse(.init(id: toolUseStart.id, name: toolUseStart.name, input: json)) + ) + state.bufferToolUse = "" + } + state.currentBlockId = -1 + } + + private static func isToolUseBufferValid(_ state: StreamState) -> Bool { + state.bufferText.isEmpty && state.bufferReasoning.isEmpty && state.bufferReasoningData.isEmpty + && !state.bufferToolUse.isEmpty + } + + private static func isReasoningDataBufferValid(_ state: StreamState) -> Bool { + state.bufferText.isEmpty && state.bufferReasoning.isEmpty && !state.bufferReasoningData.isEmpty + && state.bufferToolUse.isEmpty + } + + private static func isEmptyBufferValid(_ state: StreamState) -> Bool { + state.bufferText.isEmpty && state.bufferReasoning.isEmpty && state.bufferReasoningData.isEmpty + && state.bufferToolUse.isEmpty + } + + private static func isReasoningBufferValid(_ state: StreamState) -> Bool { + state.bufferText.isEmpty && !state.bufferReasoning.isEmpty && state.bufferReasoningData.isEmpty + && state.bufferToolUse.isEmpty + } + + private static func isTextBufferValid(_ state: StreamState) -> Bool { + !state.bufferText.isEmpty && state.bufferReasoning.isEmpty && state.bufferReasoningData.isEmpty + && state.bufferToolUse.isEmpty + } + + // a simple struct to buffer whatever content we receive from the SDK + // until final message is complete + package struct StreamState { + package init(with role: Role) { + self.role = role + } + let role: Role + var messageComplete: Bool = false + var currentBlockId: Int = -1 // -1 means no block is active + var bufferText: String = "" + var bufferReasoning: String = "" + var bufferReasoningSignature = "" + var bufferReasoningData = Data() + var bufferToolUse: String = "" + var toolUseStart: ToolUseStart? = nil + + // list of content blocks to be accumulated while reading the stream + var lastContentBlock: (Int, Content)? = nil + var contentBlocks: [Int: Content] = [:] + } } diff --git a/Sources/Converse/Streaming/ConverseStreamElement.swift b/Sources/Converse/Streaming/ConverseStreamElement.swift index c508729e..428f4b82 100644 --- a/Sources/Converse/Streaming/ConverseStreamElement.swift +++ b/Sources/Converse/Streaming/ConverseStreamElement.swift @@ -14,8 +14,12 @@ //===----------------------------------------------------------------------===// public enum ConverseStreamElement: Sendable { - case messageStart - case contentSegment(ContentSegment) - case contentBlockComplete(Int, Content) - case messageComplete(Message) + case messageStart(Role) // start of a message + case text(Int, String) // partial text + case reasoning(Int, String) // partial reasoning + case toolUse(Int, ToolUseBlock) // a complete tool use response + case messageComplete(Message) // complete text message (with all content blocks and reason for stop) + case metaData(ResponseMetadata) // metadata about the response } + +//TODO: the above struct does not manage encryptedReasoning diff --git a/Sources/Converse/Streaming/ResponseMetaData.swift b/Sources/Converse/Streaming/ResponseMetaData.swift new file mode 100644 index 00000000..aeb786aa --- /dev/null +++ b/Sources/Converse/Streaming/ResponseMetaData.swift @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Bedrock Library open source project +// +// Copyright (c) 2025 Amazon.com, Inc. or its affiliates +// and the Swift Bedrock Library project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AWSBedrockRuntime + +public struct ResponseMetadata: Codable, Sendable { + let metadata: Metadata + + package init(from sdkMetadata: BedrockRuntimeClientTypes.ConverseStreamMetadataEvent) throws { + self.metadata = try .init(usage: sdkMetadata.usage, metrics: sdkMetadata.metrics) + } + + public struct Metadata: Codable, Sendable { + let usage: Usage? + let metrics: Metrics? + // TODO: trace and performance are not implemented yet + + package init( + usage: BedrockRuntimeClientTypes.TokenUsage?, + metrics: BedrockRuntimeClientTypes.ConverseStreamMetrics? + ) throws { + if usage != nil { + self.usage = try .init(from: usage!) + } else { + self.usage = nil + } + + if metrics != nil { + self.metrics = try .init(from: metrics!) + } else { + self.metrics = nil + } + + } + public struct Usage: Codable, Sendable { + package init(from sdkUsage: BedrockRuntimeClientTypes.TokenUsage) throws { + self.inputTokens = sdkUsage.inputTokens ?? 0 + self.outputTokens = sdkUsage.outputTokens ?? 0 + self.totalTokens = sdkUsage.totalTokens ?? 0 + } + + let inputTokens: Int + let outputTokens: Int + let totalTokens: Int + } + + public struct Metrics: Codable, Sendable { + package init(from sdkMetrics: BedrockRuntimeClientTypes.ConverseStreamMetrics) throws { + self.latencyMs = Int(sdkMetrics.latencyMs ?? 0) + } + let latencyMs: Int + } + } +} diff --git a/Sources/Converse/Streaming/ToolUseStart.swift b/Sources/Converse/Streaming/ToolUseStart.swift index b26950a0..84e71fb3 100644 --- a/Sources/Converse/Streaming/ToolUseStart.swift +++ b/Sources/Converse/Streaming/ToolUseStart.swift @@ -18,9 +18,9 @@ package struct ToolUseStart: Sendable { var index: Int var name: String - var toolUseId: String + var id: String - init(index: Int, sdkToolUseStart: BedrockRuntimeClientTypes.ToolUseBlockStart) throws { + private init(index: Int, sdkToolUseStart: BedrockRuntimeClientTypes.ToolUseBlockStart) throws { guard let name = sdkToolUseStart.name else { throw BedrockLibraryError.invalidSDKType("No name found in ToolUseBlockStart") } @@ -29,25 +29,16 @@ package struct ToolUseStart: Sendable { } self.index = index self.name = name - self.toolUseId = toolUseId + self.id = toolUseId + } + package init(index: Int, sdkEventBlockStart: BedrockRuntimeClientTypes.ContentBlockStart?) throws { + guard let sdkEventBlockStart else { + throw BedrockLibraryError.invalidSDKType("No ContentBlockStart found in ToolUseStart") + } + if case .tooluse(let sdkToolUseStart) = sdkEventBlockStart { + try self.init(index: index, sdkToolUseStart: sdkToolUseStart) + } else { + throw BedrockLibraryError.invalidSDKType("ContentBlockStart is not a ToolUseStart") + } } -} - -public struct ToolUsePart: Sendable { - var index: Int - var name: String - var toolUseId: String - var inputPart: String - - // init(index: Int, sdkToolUseStart: BedrockRuntimeClientTypes.ToolUseBlockStart) throws { - // guard let name = sdkToolUseStart.name else { - // throw BedrockLibraryError.invalidSDKType("No name found in ToolUseBlockStart") - // } - // guard let toolUseId = sdkToolUseStart.toolUseId else { - // throw BedrockLibraryError.invalidSDKType("No toolUseId found in ToolUseBlockStart") - // } - // self.index = index - // self.name = name - // self.toolUseId = toolUseId - // } } diff --git a/Tests/ConverseStream/ConverseReplyGenerator.swift b/Tests/ConverseStream/ConverseReplyGenerator.swift new file mode 100644 index 00000000..1f8d6ce4 --- /dev/null +++ b/Tests/ConverseStream/ConverseReplyGenerator.swift @@ -0,0 +1,163 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Bedrock Library open source project +// +// Copyright (c) 2025 Amazon.com, Inc. or its affiliates +// and the Swift Bedrock Library project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@preconcurrency import AWSBedrockRuntime + +extension ConverseReplyStreamTests { + + // Helper function to create a simulated stream with a single text block + func createSingleTextBlockStream() -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + // Message start + let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( + role: .assistant + ) + continuation.yield(.messagestart(messageStartEvent)) + + // Content block delta (first part) + let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text("Hello, ") + let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta1 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) + + // Content block delta (second part) + let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text("this is ") + let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta2 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) + + // Content block delta (third part) + let contentBlockDelta3 = BedrockRuntimeClientTypes.ContentBlockDelta.text("a test message.") + let contentBlockDeltaEvent3 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta3 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent3)) + + // Content block stop + let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 0 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent)) + + // Message stop + let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( + additionalModelResponseFields: nil, + stopReason: nil + ) + continuation.yield(.messagestop(messageStopEvent)) + + continuation.finish() + } + } + + // Helper function to create a simulated stream with multiple content blocks + func createMultipleContentBlocksStream() -> AsyncThrowingStream< + BedrockRuntimeClientTypes.ConverseStreamOutput, Error + > { + AsyncThrowingStream { continuation in + // Message start + let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( + role: .assistant + ) + continuation.yield(.messagestart(messageStartEvent)) + + // First content block + let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text("First block content.") + let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta1 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) + + let contentBlockStopEvent1 = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 0 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent1)) + + // Second content block + let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text("Second block content.") + let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 1, + delta: contentBlockDelta2 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) + + let contentBlockStopEvent2 = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 1 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent2)) + + // Message stop + let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( + additionalModelResponseFields: nil, + stopReason: .endTurn + ) + continuation.yield(.messagestop(messageStopEvent)) + + continuation.finish() + } + } + + // Helper function to create a never-ending stream that will continue indefinitely + func createNeverEndingStream() -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + // Message start + let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( + role: .assistant + ) + continuation.yield(.messagestart(messageStartEvent)) + + // Set up a counter to track how many deltas we've sent + var counter = 0 + + // Create a Task that will continuously send content block deltas + // This simulates a never-ending stream of tokens from the model + let continuousTask = Task { + while !Task.isCancelled { + // Create a content block delta with a counter to track progress + let text = "Token \(counter) " + let contentBlockDelta = BedrockRuntimeClientTypes.ContentBlockDelta.text(text) + let contentBlockDeltaEvent = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta + ) + + // Yield the delta + continuation.yield(.contentblockdelta(contentBlockDeltaEvent)) + + // Increment counter + counter += 1 + + // Add a small delay to avoid overwhelming the system + try await Task.sleep(nanoseconds: 10_000_000) // 10ms + } + + // If we get here, the task was cancelled + continuation.finish(throwing: CancellationError()) + } + + // When the stream is terminated, cancel our continuous task + // this is not necessary for the test, but it's a good practice + continuation.onTermination = { @Sendable _ in + continuousTask.cancel() + } + } + } +} diff --git a/Tests/ConverseStream/ConverseReplyStreamTests.swift b/Tests/ConverseStream/ConverseReplyStreamTests.swift index b0aeaf9e..1f4bcf6c 100644 --- a/Tests/ConverseStream/ConverseReplyStreamTests.swift +++ b/Tests/ConverseStream/ConverseReplyStreamTests.swift @@ -21,127 +21,19 @@ import Testing @Suite("ConverseReplyStreamTests") struct ConverseReplyStreamTests { - // Helper function to create a simulated stream with a single text block - func createSingleTextBlockStream() -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - // Message start - let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( - role: .assistant - ) - continuation.yield(.messagestart(messageStartEvent)) + let bedrock: BedrockService - // Content block start - let contentBlockStartEvent = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 0, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent)) - - // Content block delta (first part) - let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text("Hello, ") - let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta1 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) - - // Content block delta (second part) - let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text("this is ") - let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta2 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) - - // Content block delta (third part) - let contentBlockDelta3 = BedrockRuntimeClientTypes.ContentBlockDelta.text("a test message.") - let contentBlockDeltaEvent3 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta3 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent3)) - - // Content block stop - let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 0 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent)) - - // Message stop - let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( - additionalModelResponseFields: nil, - stopReason: nil - ) - continuation.yield(.messagestop(messageStopEvent)) - - continuation.finish() - } - } - - // Helper function to create a simulated stream with multiple content blocks - func createMultipleContentBlocksStream() -> AsyncThrowingStream< - BedrockRuntimeClientTypes.ConverseStreamOutput, Error - > { - AsyncThrowingStream { continuation in - // Message start - let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( - role: .assistant - ) - continuation.yield(.messagestart(messageStartEvent)) - - // First content block - let contentBlockStartEvent1 = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 0, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent1)) - - let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text("First block content.") - let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta1 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) - - let contentBlockStopEvent1 = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 0 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent1)) - - // Second content block - let contentBlockStartEvent2 = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 1, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent2)) - - let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text("Second block content.") - let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 1, - delta: contentBlockDelta2 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) - - let contentBlockStopEvent2 = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 1 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent2)) - - // Message stop - let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( - additionalModelResponseFields: nil, - stopReason: nil - ) - continuation.yield(.messagestop(messageStopEvent)) - - continuation.finish() - } + init() async throws { + self.bedrock = try await BedrockService( + bedrockClient: MockBedrockClient(), + bedrockRuntimeClient: MockBedrockRuntimeClient() + ) } @Test("Test streaming text response") func testStreamingTextResponse() async throws { // Create the ConverseReplyStream from the simulated stream - let converseReplyStream = ConverseReplyStream(createSingleTextBlockStream()) + let converseReplyStream = try ConverseReplyStream(createSingleTextBlockStream()) // Collect all the stream elements var streamElements: [ConverseStreamElement] = [] @@ -153,52 +45,22 @@ struct ConverseReplyStreamTests { #expect(streamElements.count == 5) // Check content segments - if case .contentSegment(let segment1) = streamElements[0] { - if case .text(let index1, let text1) = segment1 { - #expect(index1 == 0) - #expect(text1 == "Hello, ") - } else { - Issue.record("Expected text segment") - } + if case .messageStart(let segment1) = streamElements[0] { + #expect(segment1 == .assistant) } else { - Issue.record("Expected contentSegment") + Issue.record("Expected messageStart") } - if case .contentSegment(let segment2) = streamElements[1] { - if case .text(let index2, let text2) = segment2 { - #expect(index2 == 0) - #expect(text2 == "this is ") - } else { - Issue.record("Expected text segment") - } + if case .text(let blockId, let textDelta) = streamElements[1] { + #expect(blockId == 0) + #expect(textDelta == "Hello, ") } else { - Issue.record("Expected contentSegment") + Issue.record("Expected text segment") } - if case .contentSegment(let segment3) = streamElements[2] { - if case .text(let index3, let text3) = segment3 { - #expect(index3 == 0) - #expect(text3 == "a test message.") - } else { - Issue.record("Expected text segment") - } - } else { - Issue.record("Expected contentSegment") - } + // no need t test each text delta, let's skip to ful message // Check content block complete - if case .contentBlockComplete(let index, let content) = streamElements[3] { - #expect(index == 0) - if case .text(let text) = content { - #expect(text == "Hello, this is a test message.") - } else { - Issue.record("Expected text content") - } - } else { - Issue.record("Expected contentBlockComplete") - } - - // Check message complete if case .messageComplete(let message) = streamElements[4] { #expect(message.role == .assistant) #expect(message.content.count == 1) @@ -208,14 +70,14 @@ struct ConverseReplyStreamTests { Issue.record("Expected text content in message") } } else { - Issue.record("Expected messageComplete") + Issue.record("Expected a full message") } } @Test("Test multiple content blocks") func testMultipleContentBlocks() async throws { // Create the ConverseReplyStream from the simulated stream - let converseReplyStream = ConverseReplyStream(createMultipleContentBlocksStream()) + let converseReplyStream = try ConverseReplyStream(createMultipleContentBlocksStream()) // Collect all the stream elements var streamElements: [ConverseStreamElement] = [] @@ -224,58 +86,33 @@ struct ConverseReplyStreamTests { } // Verify the stream elements - #expect(streamElements.count == 5) + #expect(streamElements.count == 4) - // Check first content segment - if case .contentSegment(let segment1) = streamElements[0] { - if case .text(let index1, let text1) = segment1 { - #expect(index1 == 0) - #expect(text1 == "First block content.") - } else { - Issue.record("Expected text segment") - } + // Check first event + if case .messageStart(let segment1) = streamElements[0] { + #expect(segment1 == .assistant) } else { - Issue.record("Expected contentSegment") + Issue.record("Expected messageStart") } - // Check first content block complete - if case .contentBlockComplete(let index1, let content1) = streamElements[1] { + // Check first content segment + if case .text(let index1, let content1) = streamElements[1] { #expect(index1 == 0) - if case .text(let text1) = content1 { - #expect(text1 == "First block content.") - } else { - Issue.record("Expected text content") - } + #expect(content1 == "First block content.") } else { Issue.record("Expected contentBlockComplete") } // Check second content segment - if case .contentSegment(let segment2) = streamElements[2] { - if case .text(let index2, let text2) = segment2 { - #expect(index2 == 1) - #expect(text2 == "Second block content.") - } else { - Issue.record("Expected text segment") - } - } else { - Issue.record("Expected contentSegment") - } - - // Check second content block complete - if case .contentBlockComplete(let index2, let content2) = streamElements[3] { - #expect(index2 == 1) - if case .text(let text2) = content2 { - #expect(text2 == "Second block content.") - } else { - Issue.record("Expected text content") - } + if case .text(let index1, let content1) = streamElements[2] { + #expect(index1 == 1) + #expect(content1 == "Second block content.") } else { Issue.record("Expected contentBlockComplete") } // Check message complete - if case .messageComplete(let message) = streamElements[4] { + if case .messageComplete(let message) = streamElements[3] { #expect(message.role == .assistant) #expect(message.content.count == 2) if case .text(let text1) = message.content[0] { @@ -293,69 +130,16 @@ struct ConverseReplyStreamTests { } } - // Helper function to create a never-ending stream that will continue indefinitely - func createNeverEndingStream() -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - // Message start - let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( - role: .assistant - ) - continuation.yield(.messagestart(messageStartEvent)) - - // Content block start - let contentBlockStartEvent = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 0, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent)) - - // Set up a counter to track how many deltas we've sent - var counter = 0 - - // Create a Task that will continuously send content block deltas - // This simulates a never-ending stream of tokens from the model - let continuousTask = Task { - while !Task.isCancelled { - // Create a content block delta with a counter to track progress - let text = "Token \(counter) " - let contentBlockDelta = BedrockRuntimeClientTypes.ContentBlockDelta.text(text) - let contentBlockDeltaEvent = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta - ) - - // Yield the delta - continuation.yield(.contentblockdelta(contentBlockDeltaEvent)) - - // Increment counter - counter += 1 - - // Add a small delay to avoid overwhelming the system - try await Task.sleep(nanoseconds: 10_000_000) // 10ms - } - - // If we get here, the task was cancelled - continuation.finish(throwing: CancellationError()) - } - - // When the stream is terminated, cancel our continuous task - // this is not necessary for the test, but it's a good practice - continuation.onTermination = { @Sendable _ in - continuousTask.cancel() - } - } - } - @Test("Test cancellation of never-ending stream") func testCancellationOfNeverEndingStream() async throws { // Create the ConverseReplyStream from the simulated never-ending stream - let converseReplyStream = ConverseReplyStream(createNeverEndingStream()) + let converseReplyStream = try ConverseReplyStream(createNeverEndingStream()) // Create a task to consume the stream let consumptionTask = Task { var count = 0 for try await element in converseReplyStream.stream { - if case .contentSegment = element { + if case .text = element { count += 1 } } diff --git a/Tests/ConverseStream/ConverseStreamDocumentTests.swift b/Tests/ConverseStream/ConverseStreamDocumentTests.swift deleted file mode 100644 index 6fafb3db..00000000 --- a/Tests/ConverseStream/ConverseStreamDocumentTests.swift +++ /dev/null @@ -1,106 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Bedrock Library open source project -// -// Copyright (c) 2025 Amazon.com, Inc. or its affiliates -// and the Swift Bedrock Library project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Testing - -@testable import BedrockService - -// MARK - Streaming converse document input - -extension BedrockServiceTests { - - @Test("Continue streaming conversation with document") - func converseStreamWithDocument() async throws { - let source = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" - var builder = try ConverseRequestBuilder(with: .nova_lite) - .withPrompt("First prompt") - .withMaxTokens(100) - .withTemperature(0.5) - .withTopP(0.5) - .withStopSequence("\n\nHuman:") - .withSystemPrompt("You are a helpful assistant.") - .withDocument(name: "document", format: .md, source: source) - - #expect(builder.prompt == "First prompt") - #expect(builder.image == nil) - #expect(builder.document != nil) - #expect(builder.maxTokens == 100) - #expect(builder.temperature == 0.5) - #expect(builder.topP == 0.5) - #expect(builder.stopSequences == ["\n\nHuman:"]) - #expect(builder.systemPrompts == ["You are a helpful assistant."]) - - var stream = try await bedrock.converseStream(with: builder) - - // Collect all the stream elements - var streamElements: [ConverseStreamElement] = [] - for try await element in stream { - streamElements.append(element) - } - - // Verify the stream elements - #expect(streamElements.count == 6) - - var message: Message = Message("") - if case .messageComplete(let msg) = streamElements.last { - message = msg - } else { - Issue.record("Expected message") - } - - #expect(message.content.count == 1) - #expect(message.role == .assistant) - - if case .text(let text) = message.content.last { - #expect(text == "Hello, your prompt was: Document received") - } - - builder = try ConverseRequestBuilder(from: builder, with: message) - .withPrompt("Second prompt") - #expect(builder.prompt == "Second prompt") - #expect(builder.image == nil) - #expect(builder.document == nil) - #expect(builder.maxTokens == 100) - #expect(builder.temperature == 0.5) - #expect(builder.topP == 0.5) - #expect(builder.stopSequences == ["\n\nHuman:"]) - #expect(builder.systemPrompts == ["You are a helpful assistant."]) - #expect(builder.history.count == 2) - - stream = try await bedrock.converseStream(with: builder) - // Collect all the stream elements - streamElements = [] - for try await element in stream { - streamElements.append(element) - } - - // Verify the stream elements - #expect(streamElements.count == 6) - - message = Message("") - if case .messageComplete(let msg) = streamElements.last { - message = msg - } else { - Issue.record("Expected message") - } - - #expect(message.content.count == 1) - #expect(message.role == .assistant) - - if case .text(let text) = message.content.last { - #expect(text == "Hello, your prompt was: Second prompt") - } - } -} diff --git a/Tests/ConverseStream/ConverseStreamReasoningTests.swift b/Tests/ConverseStream/ConverseStreamReasoningTests.swift index 920904c0..5d2b8625 100644 --- a/Tests/ConverseStream/ConverseStreamReasoningTests.swift +++ b/Tests/ConverseStream/ConverseStreamReasoningTests.swift @@ -19,7 +19,7 @@ import Testing // MARK - Streaming converse tekst -extension BedrockServiceTests { +extension ConverseReplyStreamTests { @Test("Streaming converse with reasoning") func streamingConverseReasoning() async throws { @@ -35,8 +35,8 @@ extension BedrockServiceTests { #expect(builder.maxReasoningTokens == 4096) #expect(builder.history.count == 0) - var stream = try await bedrock.converseStream(with: builder) - var message: Message = try await validateStream(stream, elementsCount: 6) + var reply = try await bedrock.converseStream(with: builder) + var message: Message = try await validateStream(reply.stream, elementsCount: 6) try checkReasoningContent(message) try checkTextContent(message, prompt: prompt) @@ -51,8 +51,8 @@ extension BedrockServiceTests { #expect(builder.maxReasoningTokens == 4096) #expect(builder.history.count == 2) - stream = try await bedrock.converseStream(with: builder) - message = try await validateStream(stream, elementsCount: 6) + reply = try await bedrock.converseStream(with: builder) + message = try await validateStream(reply.stream, elementsCount: 6) try checkReasoningContent(message) try checkTextContent(message, prompt: prompt) @@ -68,8 +68,8 @@ extension BedrockServiceTests { #expect(builder.maxReasoningTokens == nil) #expect(builder.history.count == 4) - stream = try await bedrock.converseStream(with: builder) - message = try await validateStream(stream, elementsCount: 6, contentCount: 1) + reply = try await bedrock.converseStream(with: builder) + message = try await validateStream(reply.stream, elementsCount: 6, contentCount: 1) try checkTextContent(message, prompt: prompt) try checkReasoningContent(message, hasReasoningContent: false) } diff --git a/Tests/ConverseStream/ConverseStreamTextTests.swift b/Tests/ConverseStream/ConverseStreamTextTests.swift deleted file mode 100644 index 1e33ab40..00000000 --- a/Tests/ConverseStream/ConverseStreamTextTests.swift +++ /dev/null @@ -1,100 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Bedrock Library open source project -// -// Copyright (c) 2025 Amazon.com, Inc. or its affiliates -// and the Swift Bedrock Library project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Testing - -@testable import BedrockService - -// MARK - Streaming converse tekst - -extension BedrockServiceTests { - - @Test("Continue conversation reusing builder") - func converseStreamWithReusedBuilder() async throws { - var builder = try ConverseRequestBuilder(with: .nova_lite) - .withPrompt("First prompt") - .withMaxTokens(100) - .withTemperature(0.5) - .withTopP(0.5) - .withStopSequence("\n\nHuman:") - .withSystemPrompt("You are a helpful assistant.") - - #expect(builder.prompt == "First prompt") - #expect(builder.maxTokens == 100) - #expect(builder.temperature == 0.5) - #expect(builder.topP == 0.5) - #expect(builder.stopSequences == ["\n\nHuman:"]) - #expect(builder.systemPrompts == ["You are a helpful assistant."]) - - var stream = try await bedrock.converseStream(with: builder) - - // Collect all the stream elements - var streamElements: [ConverseStreamElement] = [] - for try await element in stream { - streamElements.append(element) - } - - // Verify the stream elements - #expect(streamElements.count == 6) - - var message: Message = Message("") - if case .messageComplete(let msg) = streamElements.last { - message = msg - } else { - Issue.record("Expected message") - } - - #expect(message.content.count == 1) - #expect(message.role == .assistant) - - if case .text(let text) = message.content.last { - #expect(text == "Hello, your prompt was: First prompt") - } - - builder = try ConverseRequestBuilder(from: builder, with: message) - .withPrompt("Second prompt") - #expect(builder.prompt == "Second prompt") - #expect(builder.maxTokens == 100) - #expect(builder.temperature == 0.5) - #expect(builder.topP == 0.5) - #expect(builder.stopSequences == ["\n\nHuman:"]) - #expect(builder.systemPrompts == ["You are a helpful assistant."]) - #expect(builder.history.count == 2) - - stream = try await bedrock.converseStream(with: builder) - // Collect all the stream elements - streamElements = [] - for try await element in stream { - streamElements.append(element) - } - - // Verify the stream elements - #expect(streamElements.count == 6) - - message = Message("") - if case .messageComplete(let msg) = streamElements.last { - message = msg - } else { - Issue.record("Expected message") - } - - #expect(message.content.count == 1) - #expect(message.role == .assistant) - - if case .text(let text) = message.content.last { - #expect(text == "Hello, your prompt was: Second prompt") - } - } -} diff --git a/Tests/ConverseStream/ConverseStreamToolTests.swift b/Tests/ConverseStream/ConverseStreamToolTests.swift index f58d57b4..d02d2f9c 100644 --- a/Tests/ConverseStream/ConverseStreamToolTests.swift +++ b/Tests/ConverseStream/ConverseStreamToolTests.swift @@ -19,7 +19,7 @@ import Testing // MARK - Streaming conversetooluse -extension BedrockServiceTests { +extension ConverseReplyStreamTests { @Test("Continue conversation with tool use") func converseStreamWithToolUse() async throws { let tool = try Tool( @@ -44,11 +44,11 @@ extension BedrockServiceTests { #expect(builder.systemPrompts == ["You are a helpful assistant."]) #expect(builder.tools != nil) - var stream = try await bedrock.converseStream(with: builder) + var reply = try await bedrock.converseStream(with: builder) // Collect all the stream elements var streamElements: [ConverseStreamElement] = [] - for try await element in stream { + for try await element in reply.stream { streamElements.append(element) } @@ -88,10 +88,10 @@ extension BedrockServiceTests { #expect(builder.history.count == 2) #expect(builder.tools != nil) - stream = try await bedrock.converseStream(with: builder) + reply = try await bedrock.converseStream(with: builder) // Collect all the stream elements streamElements = [] - for try await element in stream { + for try await element in reply.stream { streamElements.append(element) } diff --git a/Tests/ConverseStream/ConverseStreamVisionTests.swift b/Tests/ConverseStream/ConverseStreamVisionTests.swift deleted file mode 100644 index d9670b0b..00000000 --- a/Tests/ConverseStream/ConverseStreamVisionTests.swift +++ /dev/null @@ -1,104 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Swift Bedrock Library open source project -// -// Copyright (c) 2025 Amazon.com, Inc. or its affiliates -// and the Swift Bedrock Library project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import Testing - -@testable import BedrockService - -// MARK - Streaming converse vision - -extension BedrockServiceTests { - - @Test("Continue conversation with vision") - func converseStreamWithVision() async throws { - let source = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" - var builder = try ConverseRequestBuilder(with: .nova_lite) - .withPrompt("First prompt") - .withMaxTokens(100) - .withTemperature(0.5) - .withTopP(0.5) - .withStopSequence("\n\nHuman:") - .withSystemPrompt("You are a helpful assistant.") - .withImage(format: .jpeg, source: source) - - #expect(builder.prompt == "First prompt") - #expect(builder.image != nil) - #expect(builder.maxTokens == 100) - #expect(builder.temperature == 0.5) - #expect(builder.topP == 0.5) - #expect(builder.stopSequences == ["\n\nHuman:"]) - #expect(builder.systemPrompts == ["You are a helpful assistant."]) - - var stream = try await bedrock.converseStream(with: builder) - - // Collect all the stream elements - var streamElements: [ConverseStreamElement] = [] - for try await element in stream { - streamElements.append(element) - } - - // Verify the stream elements - #expect(streamElements.count == 6) - - var message: Message = Message("") - if case .messageComplete(let msg) = streamElements.last { - message = msg - } else { - Issue.record("Expected message") - } - - #expect(message.content.count == 1) - #expect(message.role == .assistant) - - if case .text(let text) = message.content.last { - #expect(text == "Hello, your prompt was: Image received") - } - - builder = try ConverseRequestBuilder(from: builder, with: message) - .withPrompt("Second prompt") - #expect(builder.prompt == "Second prompt") - #expect(builder.image == nil) - #expect(builder.maxTokens == 100) - #expect(builder.temperature == 0.5) - #expect(builder.topP == 0.5) - #expect(builder.stopSequences == ["\n\nHuman:"]) - #expect(builder.systemPrompts == ["You are a helpful assistant."]) - #expect(builder.history.count == 2) - - stream = try await bedrock.converseStream(with: builder) - // Collect all the stream elements - streamElements = [] - for try await element in stream { - streamElements.append(element) - } - - // Verify the stream elements - #expect(streamElements.count == 6) - - message = Message("") - if case .messageComplete(let msg) = streamElements.last { - message = msg - } else { - Issue.record("Expected message") - } - - #expect(message.content.count == 1) - #expect(message.role == .assistant) - - if case .text(let text) = message.content.last { - #expect(text == "Hello, your prompt was: Second prompt") - } - } -} diff --git a/Tests/ConverseStream/MockBedrockRuntimeClient+StreamGenerator.swift b/Tests/ConverseStream/MockBedrockRuntimeClient+StreamGenerator.swift new file mode 100644 index 00000000..2984d23d --- /dev/null +++ b/Tests/ConverseStream/MockBedrockRuntimeClient+StreamGenerator.swift @@ -0,0 +1,223 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Bedrock Library open source project +// +// Copyright (c) 2025 Amazon.com, Inc. or its affiliates +// and the Swift Bedrock Library project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of Swift Bedrock Library project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@preconcurrency import AWSBedrockRuntime + +extension MockBedrockRuntimeClient { + // returns "Hello, your prompt was: \(textPrompt)" + package func getTextStream( + _ textPrompt: String = "Streaming Text" + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + // Message start + let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( + role: .assistant + ) + continuation.yield(.messagestart(messageStartEvent)) + + // Content block delta (first part) + let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + "Hello, " + ) + let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta1 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) + + // Content block delta (second part) + let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + "your prompt " + ) + let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta2 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) + + // Content block delta (third part) + let contentBlockDelta3 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + "was: " + ) + let contentBlockDeltaEvent3 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta3 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent3)) + + // Content block delta (third part) + let contentBlockDelta4 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + textPrompt + ) + let contentBlockDeltaEvent4 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta4 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent4)) + + // Content block stop + let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 0 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent)) + + // Message stop + let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( + additionalModelResponseFields: nil, + stopReason: .endTurn + ) + continuation.yield(.messagestop(messageStopEvent)) + + continuation.finish() + } + } + + package func getToolUseStream( + for toolName: String + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + // Message start + let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( + role: .assistant + ) + continuation.yield(.messagestart(messageStartEvent)) + + // Content block start + let contentBlockStartEvent = BedrockRuntimeClientTypes.ContentBlockStartEvent( + contentBlockIndex: 0, + start: .tooluse(BedrockRuntimeClientTypes.ToolUseBlockStart(name: toolName, toolUseId: "tooluseid")) + ) + continuation.yield(.contentblockstart(contentBlockStartEvent)) + + // Content block delta + let contentBlockDelta = BedrockRuntimeClientTypes.ContentBlockDelta.tooluse( + BedrockRuntimeClientTypes.ToolUseBlockDelta(input: "{\"key\": \"ABC\"}") + ) + let contentBlockDeltaEvent = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDelta + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent)) + + // Content block stop + let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 0 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent)) + + // Message stop + let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( + additionalModelResponseFields: nil, + stopReason: nil + ) + continuation.yield(.messagestop(messageStopEvent)) + + continuation.finish() + } + } + + package func getReasoningStream( + _ textPrompt: String = "Streaming Text" + ) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + // Message start + let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( + role: .assistant + ) + continuation.yield(.messagestart(messageStartEvent)) + + // Content block delta (reasoning - first part) + let contentBlockDeltaReasoning1 = BedrockRuntimeClientTypes.ContentBlockDelta.reasoningcontent( + BedrockRuntimeClientTypes.ReasoningContentBlockDelta.text("reasoning ") + ) + let contentBlockDeltaReasoningEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDeltaReasoning1 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaReasoningEvent1)) + + // Content block delta (reasoning - second part) + let contentBlockDeltaReasoning2 = BedrockRuntimeClientTypes.ContentBlockDelta.reasoningcontent( + BedrockRuntimeClientTypes.ReasoningContentBlockDelta.text("text ") + ) + let contentBlockDeltaReasoningEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 0, + delta: contentBlockDeltaReasoning2 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaReasoningEvent2)) + + // Content block stop + let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 0 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent)) + + // Content block delta (first part) + let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + "Hello, " + ) + let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 1, + delta: contentBlockDelta1 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) + + // Content block delta (second part) + let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + "your prompt " + ) + let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 1, + delta: contentBlockDelta2 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) + + // Content block delta (third part) + let contentBlockDelta3 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + "was: " + ) + let contentBlockDeltaEvent3 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 1, + delta: contentBlockDelta3 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent3)) + + // Content block delta (third part) + let contentBlockDelta4 = BedrockRuntimeClientTypes.ContentBlockDelta.text( + textPrompt + ) + let contentBlockDeltaEvent4 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( + contentBlockIndex: 1, + delta: contentBlockDelta4 + ) + continuation.yield(.contentblockdelta(contentBlockDeltaEvent4)) + + // Content block stop + let contentBlockStopEvent1 = BedrockRuntimeClientTypes.ContentBlockStopEvent( + contentBlockIndex: 1 + ) + continuation.yield(.contentblockstop(contentBlockStopEvent1)) + + // Message stop + let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( + additionalModelResponseFields: nil, + stopReason: nil + ) + continuation.yield(.messagestop(messageStopEvent)) + + continuation.finish() + } + } +} diff --git a/Tests/Mock/MockBedrockRuntimeClient.swift b/Tests/Mock/MockBedrockRuntimeClient.swift index d53b0578..4ff289a3 100644 --- a/Tests/Mock/MockBedrockRuntimeClient.swift +++ b/Tests/Mock/MockBedrockRuntimeClient.swift @@ -79,232 +79,6 @@ public struct MockBedrockRuntimeClient: BedrockRuntimeClientProtocol { return ConverseStreamOutput(stream: stream) } - // returns "Hello, your prompt was: \(textPrompt)" - private func getTextStream( - _ textPrompt: String = "Streaming Text" - ) -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - // Message start - let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( - role: .assistant - ) - continuation.yield(.messagestart(messageStartEvent)) - - // Content block start - let contentBlockStartEvent = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 0, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent)) - - // Content block delta (first part) - let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - "Hello, " - ) - let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta1 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) - - // Content block delta (second part) - let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - "your prompt " - ) - let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta2 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) - - // Content block delta (third part) - let contentBlockDelta3 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - "was: " - ) - let contentBlockDeltaEvent3 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta3 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent3)) - - // Content block delta (third part) - let contentBlockDelta4 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - textPrompt - ) - let contentBlockDeltaEvent4 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta4 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent4)) - - // Content block stop - let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 0 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent)) - - // Message stop - let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( - additionalModelResponseFields: nil, - stopReason: nil - ) - continuation.yield(.messagestop(messageStopEvent)) - - continuation.finish() - } - } - - private func getToolUseStream( - for toolName: String - ) -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - // Message start - let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( - role: .assistant - ) - continuation.yield(.messagestart(messageStartEvent)) - - // Content block start - let contentBlockStartEvent = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 0, - start: .tooluse(BedrockRuntimeClientTypes.ToolUseBlockStart(name: toolName, toolUseId: "tooluseid")) - ) - continuation.yield(.contentblockstart(contentBlockStartEvent)) - - // Content block delta - let contentBlockDelta = BedrockRuntimeClientTypes.ContentBlockDelta.tooluse( - BedrockRuntimeClientTypes.ToolUseBlockDelta(input: "{\"key\": \"ABC\"}") - ) - let contentBlockDeltaEvent = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDelta - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent)) - - // Content block stop - let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 0 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent)) - - // Message stop - let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( - additionalModelResponseFields: nil, - stopReason: nil - ) - continuation.yield(.messagestop(messageStopEvent)) - - continuation.finish() - } - } - - private func getReasoningStream( - _ textPrompt: String = "Streaming Text" - ) -> AsyncThrowingStream { - AsyncThrowingStream { continuation in - // Message start - let messageStartEvent = BedrockRuntimeClientTypes.MessageStartEvent( - role: .assistant - ) - continuation.yield(.messagestart(messageStartEvent)) - - // Content block start 0 - let contentBlockStartEvent = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 0, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent)) - - // Content block delta (reasoning - first part) - let contentBlockDeltaReasoning1 = BedrockRuntimeClientTypes.ContentBlockDelta.reasoningcontent( - BedrockRuntimeClientTypes.ReasoningContentBlockDelta.text("reasoning ") - ) - let contentBlockDeltaReasoningEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDeltaReasoning1 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaReasoningEvent1)) - - // Content block delta (reasoning - second part) - let contentBlockDeltaReasoning2 = BedrockRuntimeClientTypes.ContentBlockDelta.reasoningcontent( - BedrockRuntimeClientTypes.ReasoningContentBlockDelta.text("text ") - ) - let contentBlockDeltaReasoningEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 0, - delta: contentBlockDeltaReasoning2 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaReasoningEvent2)) - - // Content block stop - let contentBlockStopEvent = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 0 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent)) - - // Content block start 1 - let contentBlockStartEvent1 = BedrockRuntimeClientTypes.ContentBlockStartEvent( - contentBlockIndex: 1, - start: nil - ) - continuation.yield(.contentblockstart(contentBlockStartEvent1)) - - // Content block delta (first part) - let contentBlockDelta1 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - "Hello, " - ) - let contentBlockDeltaEvent1 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 1, - delta: contentBlockDelta1 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent1)) - - // Content block delta (second part) - let contentBlockDelta2 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - "your prompt " - ) - let contentBlockDeltaEvent2 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 1, - delta: contentBlockDelta2 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent2)) - - // Content block delta (third part) - let contentBlockDelta3 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - "was: " - ) - let contentBlockDeltaEvent3 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 1, - delta: contentBlockDelta3 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent3)) - - // Content block delta (third part) - let contentBlockDelta4 = BedrockRuntimeClientTypes.ContentBlockDelta.text( - textPrompt - ) - let contentBlockDeltaEvent4 = BedrockRuntimeClientTypes.ContentBlockDeltaEvent( - contentBlockIndex: 1, - delta: contentBlockDelta4 - ) - continuation.yield(.contentblockdelta(contentBlockDeltaEvent4)) - - // Content block stop - let contentBlockStopEvent1 = BedrockRuntimeClientTypes.ContentBlockStopEvent( - contentBlockIndex: 1 - ) - continuation.yield(.contentblockstop(contentBlockStopEvent1)) - - // Message stop - let messageStopEvent = BedrockRuntimeClientTypes.MessageStopEvent( - additionalModelResponseFields: nil, - stopReason: nil - ) - continuation.yield(.messagestop(messageStopEvent)) - - continuation.finish() - } - } - // MARK: converse public func converse(input: ConverseInput) async throws -> ConverseOutput { guard let messages = input.messages,