Skip to content

Commit 63fb254

Browse files
author
Kamil Sobol
authored
Include accumulated turn content in chunk mutation (#2149)
1 parent 10ef35d commit 63fb254

12 files changed

+249
-179
lines changed

.changeset/sour-seahorses-walk.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': minor
3+
---
4+
5+
Include accumulated turn content in chunk mutation

package-lock.json

Lines changed: 22 additions & 22 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,46 +105,93 @@ void describe('Bedrock converse adapter', () => {
105105
// See mockConverseStreamCommandOutput below of how split chunks are mocked.
106106
assert.deepStrictEqual(chunks, [
107107
{
108+
accumulatedTurnContent: [
109+
{
110+
text: 'b',
111+
},
112+
],
108113
conversationId: event.conversationId,
109114
associatedUserMessageId: event.currentMessageId,
110115
contentBlockText: 'b',
111116
contentBlockIndex: 0,
112117
contentBlockDeltaIndex: 0,
113118
},
114119
{
120+
accumulatedTurnContent: [
121+
{
122+
text: 'block1',
123+
},
124+
],
115125
conversationId: event.conversationId,
116126
associatedUserMessageId: event.currentMessageId,
117127
contentBlockText: 'lock1',
118128
contentBlockIndex: 0,
119129
contentBlockDeltaIndex: 1,
120130
},
121131
{
132+
accumulatedTurnContent: [
133+
{
134+
text: 'block1',
135+
},
136+
],
122137
conversationId: event.conversationId,
123138
associatedUserMessageId: event.currentMessageId,
124139
contentBlockIndex: 0,
125140
contentBlockDoneAtIndex: 1,
126141
},
127142
{
143+
accumulatedTurnContent: [
144+
{
145+
text: 'block1',
146+
},
147+
{
148+
text: 'b',
149+
},
150+
],
128151
conversationId: event.conversationId,
129152
associatedUserMessageId: event.currentMessageId,
130153
contentBlockText: 'b',
131154
contentBlockIndex: 1,
132155
contentBlockDeltaIndex: 0,
133156
},
134157
{
158+
accumulatedTurnContent: [
159+
{
160+
text: 'block1',
161+
},
162+
{
163+
text: 'block2',
164+
},
165+
],
135166
conversationId: event.conversationId,
136167
associatedUserMessageId: event.currentMessageId,
137168
contentBlockText: 'lock2',
138169
contentBlockIndex: 1,
139170
contentBlockDeltaIndex: 1,
140171
},
141172
{
173+
accumulatedTurnContent: [
174+
{
175+
text: 'block1',
176+
},
177+
{
178+
text: 'block2',
179+
},
180+
],
142181
conversationId: event.conversationId,
143182
associatedUserMessageId: event.currentMessageId,
144183
contentBlockIndex: 1,
145184
contentBlockDoneAtIndex: 1,
146185
},
147186
{
187+
accumulatedTurnContent: [
188+
{
189+
text: 'block1',
190+
},
191+
{
192+
text: 'block2',
193+
},
194+
],
148195
conversationId: event.conversationId,
149196
associatedUserMessageId: event.currentMessageId,
150197
contentBlockIndex: 1,
@@ -648,12 +695,14 @@ void describe('Bedrock converse adapter', () => {
648695
await askBedrockWithStreaming(adapter);
649696
assert.deepStrictEqual(chunks, [
650697
{
698+
accumulatedTurnContent: [{ toolUse: clientToolUse }],
651699
conversationId: event.conversationId,
652700
associatedUserMessageId: event.currentMessageId,
653701
contentBlockIndex: 0,
654702
contentBlockToolUse: JSON.stringify({ toolUse: clientToolUse }),
655703
},
656704
{
705+
accumulatedTurnContent: [{ toolUse: clientToolUse }],
657706
conversationId: event.conversationId,
658707
associatedUserMessageId: event.currentMessageId,
659708
contentBlockIndex: 0,

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
} from './types.js';
2121
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
2222
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';
23+
import * as bedrock from '@aws-sdk/client-bedrock-runtime';
2324

2425
/**
2526
* This class is responsible for interacting with Bedrock Converse API
@@ -173,6 +174,9 @@ export class BedrockConverseAdapter {
173174
let blockIndex = 0;
174175
let lastBlockIndex = 0;
175176
let stopReason = '';
177+
// Accumulates client facing content per turn.
178+
// So that upstream can persist full message at the end of the streaming.
179+
const accumulatedTurnContent: Array<bedrock.ContentBlock> = [];
176180
do {
177181
const toolConfig = this.createToolConfiguration();
178182
const converseCommandInput: ConverseStreamCommandInput = {
@@ -202,10 +206,12 @@ export class BedrockConverseAdapter {
202206
let toolUseInput: string = '';
203207
let blockDeltaIndex = 0;
204208
let lastBlockDeltaIndex = 0;
209+
// Accumulate current message for the tool use loop purpose.
205210
const accumulatedAssistantMessage: Message = {
206211
role: undefined,
207212
content: [],
208213
};
214+
209215
for await (const chunk of bedrockResponse.stream) {
210216
this.logger.debug('Bedrock Converse Stream response chunk:', chunk);
211217
if (chunk.messageStart) {
@@ -230,6 +236,7 @@ export class BedrockConverseAdapter {
230236
} else if (chunk.contentBlockDelta.delta?.text) {
231237
text += chunk.contentBlockDelta.delta.text;
232238
yield {
239+
accumulatedTurnContent: [...accumulatedTurnContent, { text }],
233240
conversationId: this.event.conversationId,
234241
associatedUserMessageId: this.event.currentMessageId,
235242
contentBlockText: chunk.contentBlockDelta.delta.text,
@@ -248,7 +255,9 @@ export class BedrockConverseAdapter {
248255
this.clientToolByName.has(toolUseBlock.toolUse.name)
249256
) {
250257
clientToolsRequested = true;
258+
accumulatedTurnContent.push(toolUseBlock);
251259
yield {
260+
accumulatedTurnContent: [...accumulatedTurnContent],
252261
conversationId: this.event.conversationId,
253262
associatedUserMessageId: this.event.currentMessageId,
254263
contentBlockIndex: blockIndex,
@@ -263,7 +272,9 @@ export class BedrockConverseAdapter {
263272
accumulatedAssistantMessage.content?.push({
264273
text,
265274
});
275+
accumulatedTurnContent.push({ text });
266276
yield {
277+
accumulatedTurnContent: [...accumulatedTurnContent],
267278
conversationId: this.event.conversationId,
268279
associatedUserMessageId: this.event.currentMessageId,
269280
contentBlockIndex: blockIndex,
@@ -285,6 +296,7 @@ export class BedrockConverseAdapter {
285296
// For now if any of client tools is used we ignore executable tools
286297
// and propagate result back to client.
287298
yield {
299+
accumulatedTurnContent: [...accumulatedTurnContent],
288300
conversationId: this.event.conversationId,
289301
associatedUserMessageId: this.event.currentMessageId,
290302
contentBlockIndex: lastBlockIndex,
@@ -313,6 +325,7 @@ export class BedrockConverseAdapter {
313325
} while (stopReason === 'tool_use');
314326

315327
yield {
328+
accumulatedTurnContent: [...accumulatedTurnContent],
316329
conversationId: this.event.conversationId,
317330
associatedUserMessageId: this.event.currentMessageId,
318331
contentBlockIndex: lastBlockIndex,

0 commit comments

Comments
 (0)