Skip to content

Commit 6e4a62f

Browse files
author
Kamil Sobol
authored
Fix case when multiple tools are requested by Bedrock. (#2083)
* Limit tool usage in conversation turn * Limit tool usage in conversation turn
1 parent 45fe8b7 commit 6e4a62f

File tree

6 files changed

+95
-42
lines changed

6 files changed

+95
-42
lines changed

.changeset/tender-camels-agree.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': patch
3+
---
4+
5+
Fix multi tool usage in single turn.

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,14 @@ void describe('Bedrock converse adapter', () => {
174174
toolUse: {
175175
toolUseId: randomUUID().toString(),
176176
name: additionalTool.name,
177-
input: 'additionalToolInput',
177+
input: 'additionalToolInput1',
178+
},
179+
},
180+
{
181+
toolUse: {
182+
toolUseId: randomUUID().toString(),
183+
name: additionalTool.name,
184+
input: 'additionalToolInput2',
178185
},
179186
},
180187
],
@@ -195,7 +202,14 @@ void describe('Bedrock converse adapter', () => {
195202
toolUse: {
196203
toolUseId: randomUUID().toString(),
197204
name: eventTool.name,
198-
input: 'eventToolToolInput',
205+
input: 'eventToolToolInput1',
206+
},
207+
},
208+
{
209+
toolUse: {
210+
toolUseId: randomUUID().toString(),
211+
name: eventTool.name,
212+
input: 'eventToolToolInput2',
199213
},
200214
},
201215
],
@@ -289,6 +303,10 @@ void describe('Bedrock converse adapter', () => {
289303
additionalToolUseBedrockResponse.output?.message?.content[0].toolUse
290304
?.toolUseId
291305
);
306+
assert.ok(
307+
additionalToolUseBedrockResponse.output?.message?.content[1].toolUse
308+
?.toolUseId
309+
);
292310
const expectedBedrockInput2: ConverseCommandInput = {
293311
messages: [
294312
...(messages as Array<Message>),
@@ -305,6 +323,15 @@ void describe('Bedrock converse adapter', () => {
305323
.toolUse.toolUseId,
306324
},
307325
},
326+
{
327+
toolResult: {
328+
content: [additionalToolOutput],
329+
status: 'success',
330+
toolUseId:
331+
additionalToolUseBedrockResponse.output?.message.content[1]
332+
.toolUse.toolUseId,
333+
},
334+
},
308335
],
309336
},
310337
],
@@ -317,6 +344,9 @@ void describe('Bedrock converse adapter', () => {
317344
assert.ok(
318345
eventToolUseBedrockResponse.output?.message?.content[0].toolUse?.toolUseId
319346
);
347+
assert.ok(
348+
eventToolUseBedrockResponse.output?.message?.content[1].toolUse?.toolUseId
349+
);
320350
assert.ok(expectedBedrockInput2.messages);
321351
const expectedBedrockInput3: ConverseCommandInput = {
322352
messages: [
@@ -334,6 +364,15 @@ void describe('Bedrock converse adapter', () => {
334364
.toolUseId,
335365
},
336366
},
367+
{
368+
toolResult: {
369+
content: [eventToolOutput],
370+
status: 'success',
371+
toolUseId:
372+
eventToolUseBedrockResponse.output?.message.content[1].toolUse
373+
.toolUseId,
374+
},
375+
},
337376
],
338377
},
339378
],

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,17 @@ export class BedrockConverseAdapter {
121121
// and propagate result back to client.
122122
return clientToolUseBlocks;
123123
}
124+
const toolResponseContentBlocks: Array<ContentBlock> = [];
124125
for (const responseContentBlock of toolUseBlocks) {
125126
const toolUseBlock =
126127
responseContentBlock as ContentBlock.ToolUseMember;
127-
const toolMessage = await this.executeTool(toolUseBlock);
128-
messages.push(toolMessage);
128+
const toolResultContentBlock = await this.executeTool(toolUseBlock);
129+
toolResponseContentBlocks.push(toolResultContentBlock);
129130
}
131+
messages.push({
132+
role: 'user',
133+
content: toolResponseContentBlocks,
134+
});
130135
}
131136
} while (bedrockResponse.stopReason === 'tool_use');
132137

@@ -191,7 +196,7 @@ export class BedrockConverseAdapter {
191196

192197
private executeTool = async (
193198
toolUseBlock: ContentBlock.ToolUseMember
194-
): Promise<Message> => {
199+
): Promise<ContentBlock> => {
195200
if (!toolUseBlock.toolUse.name) {
196201
throw Error('Bedrock tool use response is missing a tool name');
197202
}
@@ -208,43 +213,28 @@ export class BedrockConverseAdapter {
208213
this.logger.info(`Received response from ${tool.name} tool`);
209214
this.logger.debug(toolResponse);
210215
return {
211-
role: 'user',
212-
content: [
213-
{
214-
toolResult: {
215-
toolUseId: toolUseBlock.toolUse.toolUseId,
216-
content: [toolResponse],
217-
status: 'success',
218-
},
219-
},
220-
],
216+
toolResult: {
217+
toolUseId: toolUseBlock.toolUse.toolUseId,
218+
content: [toolResponse],
219+
status: 'success',
220+
},
221221
};
222222
} catch (e) {
223223
if (e instanceof Error) {
224224
return {
225-
role: 'user',
226-
content: [
227-
{
228-
toolResult: {
229-
toolUseId: toolUseBlock.toolUse.toolUseId,
230-
content: [{ text: e.toString() }],
231-
status: 'error',
232-
},
233-
},
234-
],
225+
toolResult: {
226+
toolUseId: toolUseBlock.toolUse.toolUseId,
227+
content: [{ text: e.toString() }],
228+
status: 'error',
229+
},
235230
};
236231
}
237232
return {
238-
role: 'user',
239-
content: [
240-
{
241-
toolResult: {
242-
toolUseId: toolUseBlock.toolUse.toolUseId,
243-
content: [{ text: 'unknown error occurred' }],
244-
status: 'error',
245-
},
246-
},
247-
],
233+
toolResult: {
234+
toolUseId: toolUseBlock.toolUse.toolUseId,
235+
content: [{ text: 'unknown error occurred' }],
236+
status: 'error',
237+
},
248238
};
249239
}
250240
};

packages/integration-tests/src/test-project-setup/conversation_handler_project.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import { NormalizedCacheObject } from '@apollo/client';
3030
import {
3131
bedrockModelId,
3232
expectedTemperatureInDataToolScenario,
33-
expectedTemperatureInProgrammaticToolScenario,
33+
expectedTemperaturesInProgrammaticToolScenario,
3434
} from '../test-projects/conversation-handler/amplify/constants.js';
3535
import { resolve } from 'path';
3636
import { fileURLToPath } from 'url';
@@ -581,7 +581,7 @@ class ConversationHandlerTestProject extends TestProjectBase {
581581
role: 'user',
582582
content: [
583583
{
584-
text: 'What is the temperature in Seattle?',
584+
text: 'What is the temperature in Seattle, Boston and Miami?',
585585
},
586586
],
587587
};
@@ -605,7 +605,21 @@ class ConversationHandlerTestProject extends TestProjectBase {
605605
// Assert that tool was used. I.e. LLM used value provided by the tool.
606606
assert.match(
607607
response.content,
608-
new RegExp(expectedTemperatureInProgrammaticToolScenario.toString())
608+
new RegExp(
609+
expectedTemperaturesInProgrammaticToolScenario.Seattle.toString()
610+
)
611+
);
612+
assert.match(
613+
response.content,
614+
new RegExp(
615+
expectedTemperaturesInProgrammaticToolScenario.Boston.toString()
616+
)
617+
);
618+
assert.match(
619+
response.content,
620+
new RegExp(
621+
expectedTemperaturesInProgrammaticToolScenario.Miami.toString()
622+
)
609623
);
610624
};
611625

packages/integration-tests/src/test-projects/conversation-handler/amplify/constants.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
*/
99
export const bedrockModelId = 'anthropic.claude-3-haiku-20240307-v1:0';
1010

11-
export const expectedTemperatureInProgrammaticToolScenario = 75;
11+
export const expectedTemperaturesInProgrammaticToolScenario = {
12+
Seattle: 75,
13+
Boston: 58,
14+
Miami: 97,
15+
};
1216

1317
export const expectedTemperatureInDataToolScenario = 85;

packages/integration-tests/src/test-projects/conversation-handler/amplify/custom-conversation-handler/custom_handler.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import {
33
createExecutableTool,
44
handleConversationTurnEvent,
55
} from '@aws-amplify/backend-ai/conversation/runtime';
6-
import { expectedTemperatureInProgrammaticToolScenario } from '../constants.js';
6+
import { expectedTemperaturesInProgrammaticToolScenario } from '../constants.js';
77

88
const thermometerInputSchema = {
99
type: 'object',
@@ -20,12 +20,13 @@ const thermometer = createExecutableTool(
2020
json: thermometerInputSchema,
2121
},
2222
(input) => {
23-
if (input.city === 'Seattle') {
23+
const city = input.city;
24+
if (city === 'Seattle' || city === 'Boston' || city === 'Miami') {
2425
return Promise.resolve({
2526
// We use this value in test assertion.
2627
// LLM uses tool to get temperature and serves this value in final response.
2728
// We're matching number only as LLM may translate unit to something more descriptive.
28-
text: `${expectedTemperatureInProgrammaticToolScenario}F`,
29+
text: `${expectedTemperaturesInProgrammaticToolScenario[city]}F`,
2930
});
3031
}
3132
throw new Error(`Unknown city ${input.city}`);

0 commit comments

Comments
 (0)