Skip to content

Commit fd8759d

Browse files
author
Kamil Sobol
authored
Fix a case when Bedrock throws validation error if tool input is not persisted in history (#2230)
* Fix a case when Bedrock throws validation error if tool input is not persisted in history * some prompt engineering
1 parent 5d873a1 commit fd8759d

File tree

6 files changed

+100
-4
lines changed

6 files changed

+100
-4
lines changed

.changeset/clever-emus-hope.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': patch
3+
---
4+
5+
Fix a case when Bedrock throws validation error if tool input is not persisted in history

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -901,8 +901,8 @@ void describe('Bedrock converse adapter', () => {
901901
assert.strictEqual(responseText, 'finalResponse');
902902

903903
assert.strictEqual(toolExecuteMock.mock.calls.length, 2);
904-
assert.strictEqual(toolExecuteMock.mock.calls[0].arguments[0], undefined);
905-
assert.strictEqual(toolExecuteMock.mock.calls[1].arguments[0], undefined);
904+
assert.deepStrictEqual(toolExecuteMock.mock.calls[0].arguments[0], {});
905+
assert.deepStrictEqual(toolExecuteMock.mock.calls[1].arguments[0], {});
906906
});
907907

908908
void it('throws if tool is duplicated', () => {

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ export class BedrockConverseAdapter {
252252
if (toolUseBlock) {
253253
if (toolUseInput) {
254254
toolUseBlock.toolUse.input = JSON.parse(toolUseInput);
255+
} else {
256+
// Bedrock API requires tool input to be non-null in message history.
257+
// Therefore, falling back to empty object.
258+
toolUseBlock.toolUse.input = {};
255259
}
256260
accumulatedAssistantMessage.content?.push(toolUseBlock);
257261
if (

packages/integration-tests/src/test-project-setup/conversation_handler_project.ts

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import assert from 'assert';
2626
import { NormalizedCacheObject } from '@apollo/client';
2727
import {
2828
bedrockModelId,
29+
expectedRandomNumber,
2930
expectedTemperatureInDataToolScenario,
3031
expectedTemperaturesInProgrammaticToolScenario,
3132
} from '../test-projects/conversation-handler/amplify/constants.js';
@@ -278,6 +279,16 @@ class ConversationHandlerTestProject extends TestProjectBase {
278279
)
279280
);
280281

282+
await this.executeWithRetry(() =>
283+
this.assertCustomConversationHandlerCanExecuteTurnWithParameterLessTool(
284+
backendId,
285+
authenticatedUserCredentials.accessToken,
286+
dataUrl,
287+
apolloClient,
288+
true
289+
)
290+
);
291+
281292
await this.executeWithRetry(() =>
282293
this.assertDefaultConversationHandlerCanExecuteTurnWithDataTool(
283294
backendId,
@@ -853,6 +864,56 @@ class ConversationHandlerTestProject extends TestProjectBase {
853864
);
854865
};
855866

867+
private assertCustomConversationHandlerCanExecuteTurnWithParameterLessTool =
868+
async (
869+
backendId: BackendIdentifier,
870+
accessToken: string,
871+
graphqlApiEndpoint: string,
872+
apolloClient: ApolloClient<NormalizedCacheObject>,
873+
streamResponse: boolean
874+
): Promise<void> => {
875+
const customConversationHandlerFunction = (
876+
await this.resourceFinder.findByBackendIdentifier(
877+
backendId,
878+
'AWS::Lambda::Function',
879+
(name) => name.includes('custom')
880+
)
881+
)[0];
882+
883+
const message: CreateConversationMessageChatInput = {
884+
conversationId: randomUUID().toString(),
885+
id: randomUUID().toString(),
886+
role: 'user',
887+
content: [
888+
{
889+
text: 'Give me a random number',
890+
},
891+
],
892+
};
893+
await this.insertMessage(apolloClient, message);
894+
895+
// send event
896+
const event: ConversationTurnEvent = {
897+
conversationId: message.conversationId,
898+
currentMessageId: message.id,
899+
graphqlApiEndpoint: graphqlApiEndpoint,
900+
request: {
901+
headers: { authorization: accessToken },
902+
},
903+
...this.getCommonEventProperties(streamResponse),
904+
};
905+
const response = await this.executeConversationTurn(
906+
event,
907+
customConversationHandlerFunction,
908+
apolloClient
909+
);
910+
// Assert that tool was used. I.e. LLM used value provided by the tool.
911+
assert.match(
912+
response.content,
913+
new RegExp(expectedRandomNumber.toString())
914+
);
915+
};
916+
856917
private assertDefaultConversationHandlerCanPropagateError = async (
857918
backendId: BackendIdentifier,
858919
accessToken: string,

packages/integration-tests/src/test-projects/conversation-handler/amplify/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ export const expectedTemperaturesInProgrammaticToolScenario = {
1414
};
1515

1616
export const expectedTemperatureInDataToolScenario = 85;
17+
18+
export const expectedRandomNumber = 42;

packages/integration-tests/src/test-projects/conversation-handler/amplify/custom-conversation-handler/custom_handler.ts

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ import {
33
createExecutableTool,
44
handleConversationTurnEvent,
55
} from '@aws-amplify/backend-ai/conversation/runtime';
6-
import { expectedTemperaturesInProgrammaticToolScenario } from '../constants.js';
6+
import {
7+
expectedRandomNumber,
8+
expectedTemperaturesInProgrammaticToolScenario,
9+
} from '../constants.js';
710

811
const thermometerInputSchema = {
912
type: 'object',
@@ -33,11 +36,32 @@ const thermometer = createExecutableTool(
3336
}
3437
);
3538

39+
// Parameter-less tool.
40+
const randomNumberGeneratorInputSchema = {
41+
type: 'object',
42+
properties: {},
43+
required: [],
44+
} as const;
45+
46+
const randomNumberGenerator = createExecutableTool(
47+
'randomNumberGenerator',
48+
'Returns a random number',
49+
{
50+
json: randomNumberGeneratorInputSchema,
51+
},
52+
() => {
53+
return Promise.resolve({
54+
// We use this value in test assertion.
55+
text: `${expectedRandomNumber}`,
56+
});
57+
}
58+
);
59+
3660
/**
3761
* Handler with simple tool.
3862
*/
3963
export const handler = async (event: ConversationTurnEvent) => {
4064
await handleConversationTurnEvent(event, {
41-
tools: [thermometer],
65+
tools: [randomNumberGenerator, thermometer],
4266
});
4367
};

0 commit comments

Comments
 (0)