Skip to content

Commit 562aa88

Browse files
committed
minor API update
1 parent 12df935 commit 562aa88

File tree

5 files changed

+46
-36
lines changed

5 files changed

+46
-36
lines changed

Sources/Converse/BedrockService+ConverseStreaming.swift

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ extension BedrockService {
3434
/// BedrockLibraryError.invalidPrompt if the prompt is empty or too long
3535
/// BedrockLibraryError.invalidModality for invalid modality from the selected model
3636
/// BedrockLibraryError.invalidSDKResponse if the response body is missing
37-
/// - Returns: A stream of ConverseResponseStreaming objects
37+
/// - Returns: A ConverseReplyStream object that gives access to the high-level stream of ConverseStreamElements objects
38+
/// or the low-level stream provided by the AWS SDK.
3839
public func converseStream(
3940
with model: BedrockModel,
4041
conversation: [Message],
@@ -46,7 +47,7 @@ extension BedrockService {
4647
tools: [Tool]? = nil,
4748
enableReasoning: Bool? = false,
4849
maxReasoningTokens: Int? = nil
49-
) async throws -> AsyncThrowingStream<ConverseStreamElement, any Error> {
50+
) async throws -> ConverseReplyStream {
5051
do {
5152
guard model.hasConverseStreamingModality() else {
5253
throw BedrockLibraryError.invalidModality(
@@ -118,18 +119,18 @@ extension BedrockService {
118119
// - message metadata
119120
// see https://github.com/awslabs/aws-sdk-swift/blob/2697fb44f607b9c43ad0ce5ca79867d8d6c545c2/Sources/Services/AWSBedrockRuntime/Sources/AWSBedrockRuntime/Models.swift#L3478
120121
// it will be the responsibility of the user to handle the stream and re-assemble the messages and content
121-
// TODO: should we expose the SDK ConverseStreamOutput from the SDK ? or wrap it (what's the added value) ?
122122

123-
let reply = ConverseReplyStream(sdkStream)
123+
let reply = try ConverseReplyStream(sdkStream)
124124

125125
// this time, a different stream is created from the previous one, this one has the following elements
126-
// - content segment: this contains a ContentSegment, an enum which can be a .text(Int, String),
127-
// the integer is the id for the content block that the content segment is a part of,
128-
// the String is the part of text that is send from the model.
129-
// - content block complete: this includes the id of the completed content block and the complete content block itself
126+
// - messageStart: this is the start of a message, it contains the role (assistant or user)
127+
// - text: this is a delta of the text content, it contains the partial text
128+
// - reasoning: this is a delta of the reasoning content, it contains the partial reasoning text
129+
// - toolUse: this is a buffered tool use response, it contains the tool name and id, and the input parameters
130130
// - message complete: this includes the complete Message, ready to be added to the history and used for future requests
131+
// - metaData: this is the metadata about the response, it contains statitics about the response, such as the number of tokens used and the latency
131132

132-
return reply.stream
133+
return reply
133134

134135
} catch {
135136
try handleCommonError(error, context: "invoking converse stream")
@@ -143,7 +144,7 @@ extension BedrockService {
143144
/// - Returns: A stream of ConverseResponseStreaming objects
144145
public func converseStream(
145146
with builder: ConverseRequestBuilder
146-
) async throws -> AsyncThrowingStream<ConverseStreamElement, any Error> {
147+
) async throws -> ConverseReplyStream {
147148
logger.trace("Conversing and streaming")
148149
do {
149150
var history = builder.history

Sources/Converse/Streaming/ConverseReplyStream.swift

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,24 @@ public struct ConverseReplyStream: Sendable {
3030
package init(
3131
_ inputStream: AsyncThrowingStream<BedrockRuntimeClientTypes.ConverseStreamOutput, Error>,
3232
logger: Logger? = nil
33-
) {
33+
) throws {
3434

3535
self.logger = logger ?? .init(label: "ConverseReplyStream")
3636

3737
// store the sdk-provided stream to expose it to developers if needed
3838
self.sdkStream = inputStream
3939

4040
// build a new stream that will convert the SDK stream output to our own ConverseStreamElement
41-
self.stream = AsyncThrowingStream(ConverseStreamElement.self) { continuation in
41+
self.stream = try ConverseReplyStream.convertToLibraryStream(inputStream, logger: self.logger)
42+
}
43+
44+
/// Convert the SDK Stream to a highler level stream of ConverseStreamElement
45+
private static func convertToLibraryStream(
46+
_ inputStream: AsyncThrowingStream<BedrockRuntimeClientTypes.ConverseStreamOutput, Error>,
47+
logger: Logger
48+
) throws -> AsyncThrowingStream<ConverseStreamElement, Error> {
49+
50+
AsyncThrowingStream(ConverseStreamElement.self) { continuation in
4251
let t = Task {
4352
do {
4453
var state: StreamState!
@@ -48,7 +57,7 @@ public struct ConverseReplyStream: Sendable {
4857

4958
switch output {
5059
case .messagestart(let event):
51-
logger?.trace("Message Start", metadata: ["event": "\(event)"])
60+
logger.trace("Message Start", metadata: ["event": "\(event)"])
5261

5362
guard let sdkRole = event.role,
5463
let role = try? Role(from: sdkRole)
@@ -62,7 +71,7 @@ public struct ConverseReplyStream: Sendable {
6271
// only received at the start of a tool use block
6372
// https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html#conversation-inference-call-response
6473
case .contentblockstart(let event):
65-
logger?.trace("Content Block Start")
74+
logger.trace("Content Block Start")
6675
guard state.currentBlockId == -1 else {
6776
// If we already have a block started, this is an error
6877
throw BedrockLibraryError.invalidSDKType(
@@ -79,7 +88,7 @@ public struct ConverseReplyStream: Sendable {
7988
// do not yield an event here, wait for full ToolUse block to arrive
8089

8190
case .contentblockdelta(let event):
82-
logger?.trace("Content Block Delta")
91+
logger.trace("Content Block Delta")
8392
guard let blockId = event.contentBlockIndex else {
8493
// when there is no blockId, this is an error
8594
throw BedrockLibraryError.invalidSDKType(
@@ -114,22 +123,22 @@ public struct ConverseReplyStream: Sendable {
114123
state.bufferReasoningData.append(redactedContent)
115124
// do not yield partial reasoning data, wait for full JSON data
116125
case .sdkUnknown(let output):
117-
logger?.warning(
126+
logger.warning(
118127
"Received unknown SDK Reasoning Delta",
119128
metadata: ["reasoning delta": "\(output)"]
120129
)
121130
}
122131
case .sdkUnknown(let output):
123-
logger?.warning(
132+
logger.warning(
124133
"Received unknown SDK Event Delta",
125134
metadata: ["delta": "\(output)"]
126135
)
127136
case .none:
128-
logger?.warning("Received none SDK Event Delta")
137+
logger.warning("Received none SDK Event Delta")
129138
}
130139

131140
case .contentblockstop(let event):
132-
logger?.trace("Content Block Stop")
141+
logger.trace("Content Block Stop")
133142
guard state.currentBlockId != -1 else {
134143
// If we don't have a block started, this is an error
135144
throw BedrockLibraryError.invalidSDKType(
@@ -165,7 +174,7 @@ public struct ConverseReplyStream: Sendable {
165174
state.currentBlockId = -1
166175

167176
case .messagestop(let event):
168-
logger?.trace("Message Stop")
177+
logger.trace("Message Stop")
169178
state.messageComplete = true
170179

171180
// create a Message with all content blocks
@@ -177,7 +186,7 @@ public struct ConverseReplyStream: Sendable {
177186
continuation.yield(.messageComplete(message))
178187

179188
case .metadata(let event):
180-
logger?.trace("Metadata", metadata: ["event": "\(event)"])
189+
logger.trace("Metadata", metadata: ["event": "\(event)"])
181190

182191
// Convert the metadata event to our ResponseMetadata type
183192
let metadata = try ResponseMetadata(from: event)
@@ -187,7 +196,7 @@ public struct ConverseReplyStream: Sendable {
187196
// Handle unknown SDK output
188197
// This is a catch-all for any future SDK output types that we don't handle yet
189198
// We log it and continue, but we could also throw an error if desired
190-
logger?.warning(
199+
logger.warning(
191200
"Received unknown SDK ConverseStreamOutput",
192201
metadata: ["output": "\(output)"]
193202
)
@@ -214,7 +223,7 @@ public struct ConverseReplyStream: Sendable {
214223
t.cancel() // Cancel the task when the stream is terminated
215224
}
216225
}
217-
}
226+
}
218227
}
219228

220229
/// Flushes and processes the buffered content from the stream state

Tests/ConverseStream/ConverseReplyStreamTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct ConverseReplyStreamTests {
3333
@Test("Test streaming text response")
3434
func testStreamingTextResponse() async throws {
3535
// Create the ConverseReplyStream from the simulated stream
36-
let converseReplyStream = ConverseReplyStream(createSingleTextBlockStream())
36+
let converseReplyStream = try ConverseReplyStream(createSingleTextBlockStream())
3737

3838
// Collect all the stream elements
3939
var streamElements: [ConverseStreamElement] = []
@@ -77,7 +77,7 @@ struct ConverseReplyStreamTests {
7777
@Test("Test multiple content blocks")
7878
func testMultipleContentBlocks() async throws {
7979
// Create the ConverseReplyStream from the simulated stream
80-
let converseReplyStream = ConverseReplyStream(createMultipleContentBlocksStream())
80+
let converseReplyStream = try ConverseReplyStream(createMultipleContentBlocksStream())
8181

8282
// Collect all the stream elements
8383
var streamElements: [ConverseStreamElement] = []
@@ -133,7 +133,7 @@ struct ConverseReplyStreamTests {
133133
@Test("Test cancellation of never-ending stream")
134134
func testCancellationOfNeverEndingStream() async throws {
135135
// Create the ConverseReplyStream from the simulated never-ending stream
136-
let converseReplyStream = ConverseReplyStream(createNeverEndingStream())
136+
let converseReplyStream = try ConverseReplyStream(createNeverEndingStream())
137137

138138
// Create a task to consume the stream
139139
let consumptionTask = Task {

Tests/ConverseStream/ConverseStreamReasoningTests.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ extension ConverseReplyStreamTests {
3535
#expect(builder.maxReasoningTokens == 4096)
3636
#expect(builder.history.count == 0)
3737

38-
var stream = try await bedrock.converseStream(with: builder)
39-
var message: Message = try await validateStream(stream, elementsCount: 6)
38+
var reply = try await bedrock.converseStream(with: builder)
39+
var message: Message = try await validateStream(reply.stream, elementsCount: 6)
4040

4141
try checkReasoningContent(message)
4242
try checkTextContent(message, prompt: prompt)
@@ -51,8 +51,8 @@ extension ConverseReplyStreamTests {
5151
#expect(builder.maxReasoningTokens == 4096)
5252
#expect(builder.history.count == 2)
5353

54-
stream = try await bedrock.converseStream(with: builder)
55-
message = try await validateStream(stream, elementsCount: 6)
54+
reply = try await bedrock.converseStream(with: builder)
55+
message = try await validateStream(reply.stream, elementsCount: 6)
5656

5757
try checkReasoningContent(message)
5858
try checkTextContent(message, prompt: prompt)
@@ -68,8 +68,8 @@ extension ConverseReplyStreamTests {
6868
#expect(builder.maxReasoningTokens == nil)
6969
#expect(builder.history.count == 4)
7070

71-
stream = try await bedrock.converseStream(with: builder)
72-
message = try await validateStream(stream, elementsCount: 6, contentCount: 1)
71+
reply = try await bedrock.converseStream(with: builder)
72+
message = try await validateStream(reply.stream, elementsCount: 6, contentCount: 1)
7373
try checkTextContent(message, prompt: prompt)
7474
try checkReasoningContent(message, hasReasoningContent: false)
7575
}

Tests/ConverseStream/ConverseStreamToolTests.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ extension ConverseReplyStreamTests {
4444
#expect(builder.systemPrompts == ["You are a helpful assistant."])
4545
#expect(builder.tools != nil)
4646

47-
var stream = try await bedrock.converseStream(with: builder)
47+
var reply = try await bedrock.converseStream(with: builder)
4848

4949
// Collect all the stream elements
5050
var streamElements: [ConverseStreamElement] = []
51-
for try await element in stream {
51+
for try await element in reply.stream {
5252
streamElements.append(element)
5353
}
5454

@@ -88,10 +88,10 @@ extension ConverseReplyStreamTests {
8888
#expect(builder.history.count == 2)
8989
#expect(builder.tools != nil)
9090

91-
stream = try await bedrock.converseStream(with: builder)
91+
reply = try await bedrock.converseStream(with: builder)
9292
// Collect all the stream elements
9393
streamElements = []
94-
for try await element in stream {
94+
for try await element in reply.stream {
9595
streamElements.append(element)
9696
}
9797

0 commit comments

Comments
 (0)