Skip to content

Commit d0a90b1

Browse files
author
Kamil Sobol
authored
Use message history instead of event payload for conversation handler (#2047)
* Use message history instead of event payload for conversational route * refactor e2e * refactor gql requests * fallback * lint * add test for retriever * refactor that * todo comments * lint * refactor that * rename * process history * process history test * more tests * more tests
1 parent d538ecc commit d0a90b1

17 files changed

+1311
-342
lines changed

.changeset/plenty-wombats-fry.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': minor
3+
---
4+
5+
Use message history instead of event payload for conversational route

packages/ai-constructs/API.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,14 @@ type ConversationTurnEvent = {
9191
authorization: string;
9292
};
9393
};
94-
messages: Array<ConversationMessage>;
94+
messages?: Array<ConversationMessage>;
95+
messageHistoryQuery: {
96+
getQueryName: string;
97+
getQueryInputTypeName: string;
98+
listQueryName: string;
99+
listQueryInputTypeName: string;
100+
listQueryLimit?: number;
101+
};
95102
toolsConfiguration?: {
96103
dataTools?: Array<ToolDefinition & {
97104
graphqlRequestInputDescriptor: {

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 77 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import { describe, it, mock } from 'node:test';
22
import assert from 'node:assert';
3-
import { ConversationTurnEvent, ExecutableTool, ToolDefinition } from './types';
3+
import {
4+
ConversationMessage,
5+
ConversationTurnEvent,
6+
ExecutableTool,
7+
ToolDefinition,
8+
} from './types';
49
import { BedrockConverseAdapter } from './bedrock_converse_adapter';
510
import {
611
BedrockRuntimeClient,
@@ -13,22 +18,19 @@ import {
1318
} from '@aws-sdk/client-bedrock-runtime';
1419
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
1520
import { randomBytes, randomUUID } from 'node:crypto';
21+
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';
1622

1723
void describe('Bedrock converse adapter', () => {
1824
const commonEvent: Readonly<ConversationTurnEvent> = {
1925
conversationId: '',
2026
currentMessageId: '',
2127
graphqlApiEndpoint: '',
22-
messages: [
23-
{
24-
role: 'user',
25-
content: [
26-
{
27-
text: 'event message',
28-
},
29-
],
30-
},
31-
],
28+
messageHistoryQuery: {
29+
getQueryName: '',
30+
getQueryInputTypeName: '',
31+
listQueryName: '',
32+
listQueryInputTypeName: '',
33+
},
3234
modelConfiguration: {
3335
modelId: 'testModelId',
3436
systemPrompt: 'testSystemPrompt',
@@ -46,6 +48,27 @@ void describe('Bedrock converse adapter', () => {
4648
},
4749
};
4850

51+
const messages: Array<ConversationMessage> = [
52+
{
53+
role: 'user',
54+
content: [
55+
{
56+
text: 'event message',
57+
},
58+
],
59+
},
60+
];
61+
const messageHistoryRetriever = new ConversationMessageHistoryRetriever(
62+
commonEvent
63+
);
64+
const messageHistoryRetrieverMockGetEventMessages = mock.method(
65+
messageHistoryRetriever,
66+
'getMessageHistory',
67+
() => {
68+
return Promise.resolve(messages);
69+
}
70+
);
71+
4972
void it('calls bedrock to get conversation response', async () => {
5073
const event: ConversationTurnEvent = {
5174
...commonEvent,
@@ -78,7 +101,9 @@ void describe('Bedrock converse adapter', () => {
78101
const responseContent = await new BedrockConverseAdapter(
79102
event,
80103
[],
81-
bedrockClient
104+
bedrockClient,
105+
undefined,
106+
messageHistoryRetriever
82107
).askBedrock();
83108

84109
assert.deepStrictEqual(
@@ -90,7 +115,7 @@ void describe('Bedrock converse adapter', () => {
90115
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
91116
.arguments[0] as unknown as ConverseCommand;
92117
const expectedBedrockInput: ConverseCommandInput = {
93-
messages: event.messages as Array<Message>,
118+
messages: messages as Array<Message>,
94119
modelId: event.modelConfiguration.modelId,
95120
inferenceConfig: event.modelConfiguration.inferenceConfiguration,
96121
system: [
@@ -211,7 +236,8 @@ void describe('Bedrock converse adapter', () => {
211236
event,
212237
[additionalTool],
213238
bedrockClient,
214-
eventToolsProvider
239+
eventToolsProvider,
240+
messageHistoryRetriever
215241
).askBedrock();
216242

217243
assert.deepStrictEqual(
@@ -251,7 +277,7 @@ void describe('Bedrock converse adapter', () => {
251277
const bedrockRequest1 = bedrockClientSendMock.mock.calls[0]
252278
.arguments[0] as unknown as ConverseCommand;
253279
const expectedBedrockInput1: ConverseCommandInput = {
254-
messages: event.messages as Array<Message>,
280+
messages: messages as Array<Message>,
255281
...expectedBedrockInputCommonProperties,
256282
};
257283
assert.deepStrictEqual(bedrockRequest1.input, expectedBedrockInput1);
@@ -264,7 +290,7 @@ void describe('Bedrock converse adapter', () => {
264290
);
265291
const expectedBedrockInput2: ConverseCommandInput = {
266292
messages: [
267-
...(event.messages as Array<Message>),
293+
...(messages as Array<Message>),
268294
additionalToolUseBedrockResponse.output?.message,
269295
{
270296
role: 'user',
@@ -447,7 +473,9 @@ void describe('Bedrock converse adapter', () => {
447473
const responseContent = await new BedrockConverseAdapter(
448474
event,
449475
[tool],
450-
bedrockClient
476+
bedrockClient,
477+
undefined,
478+
messageHistoryRetriever
451479
).askBedrock();
452480

453481
assert.deepStrictEqual(
@@ -543,7 +571,9 @@ void describe('Bedrock converse adapter', () => {
543571
const responseContent = await new BedrockConverseAdapter(
544572
event,
545573
[tool],
546-
bedrockClient
574+
bedrockClient,
575+
undefined,
576+
messageHistoryRetriever
547577
).askBedrock();
548578

549579
assert.deepStrictEqual(
@@ -645,7 +675,9 @@ void describe('Bedrock converse adapter', () => {
645675
const responseContent = await new BedrockConverseAdapter(
646676
event,
647677
[additionalTool],
648-
bedrockClient
678+
bedrockClient,
679+
undefined,
680+
messageHistoryRetriever
649681
).askBedrock();
650682

651683
assert.deepStrictEqual(responseContent, [clientToolUseBlock]);
@@ -682,7 +714,7 @@ void describe('Bedrock converse adapter', () => {
682714
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
683715
.arguments[0] as unknown as ConverseCommand;
684716
const expectedBedrockInput: ConverseCommandInput = {
685-
messages: event.messages as Array<Message>,
717+
messages: messages as Array<Message>,
686718
...expectedBedrockInputCommonProperties,
687719
};
688720
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
@@ -695,21 +727,27 @@ void describe('Bedrock converse adapter', () => {
695727

696728
const fakeImagePayload = randomBytes(32);
697729

698-
event.messages = [
699-
{
700-
role: 'user',
701-
content: [
730+
messageHistoryRetrieverMockGetEventMessages.mock.mockImplementationOnce(
731+
() => {
732+
return Promise.resolve([
702733
{
703-
image: {
704-
format: 'png',
705-
source: {
706-
bytes: fakeImagePayload.toString('base64'),
734+
id: '',
735+
conversationId: '',
736+
role: 'user',
737+
content: [
738+
{
739+
image: {
740+
format: 'png',
741+
source: {
742+
bytes: fakeImagePayload.toString('base64'),
743+
},
744+
},
707745
},
708-
},
746+
],
709747
},
710-
],
711-
},
712-
];
748+
]);
749+
}
750+
);
713751

714752
const bedrockClient = new BedrockRuntimeClient();
715753
const bedrockResponse: ConverseCommandOutput = {
@@ -735,7 +773,13 @@ void describe('Bedrock converse adapter', () => {
735773
Promise.resolve(bedrockResponse)
736774
);
737775

738-
await new BedrockConverseAdapter(event, [], bedrockClient).askBedrock();
776+
await new BedrockConverseAdapter(
777+
event,
778+
[],
779+
bedrockClient,
780+
undefined,
781+
messageHistoryRetriever
782+
).askBedrock();
739783

740784
assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1);
741785
const bedrockRequest = bedrockClientSendMock.mock.calls[0]

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import {
1414
ToolDefinition,
1515
} from './types.js';
1616
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
17+
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';
1718

1819
/**
1920
* This class is responsible for interacting with Bedrock Converse API
@@ -36,7 +37,10 @@ export class BedrockConverseAdapter {
3637
private readonly bedrockClient: BedrockRuntimeClient = new BedrockRuntimeClient(
3738
{ region: event.modelConfiguration.region }
3839
),
39-
eventToolsProvider = new ConversationTurnEventToolsProvider(event)
40+
eventToolsProvider = new ConversationTurnEventToolsProvider(event),
41+
private readonly messageHistoryRetriever = new ConversationMessageHistoryRetriever(
42+
event
43+
)
4044
) {
4145
this.executableTools = [
4246
...eventToolsProvider.getEventTools(),
@@ -73,7 +77,8 @@ export class BedrockConverseAdapter {
7377
const { modelId, systemPrompt, inferenceConfiguration } =
7478
this.event.modelConfiguration;
7579

76-
const messages: Array<Message> = this.getEventMessagesAsBedrockMessages();
80+
const messages: Array<Message> =
81+
await this.getEventMessagesAsBedrockMessages();
7782

7883
let bedrockResponse: ConverseCommandOutput;
7984
do {
@@ -124,9 +129,13 @@ export class BedrockConverseAdapter {
124129
* 1. Makes a copy so that we don't mutate event.
125130
* 2. Decodes Base64 encoded images.
126131
*/
127-
private getEventMessagesAsBedrockMessages = (): Array<Message> => {
132+
private getEventMessagesAsBedrockMessages = async (): Promise<
133+
Array<Message>
134+
> => {
128135
const messages: Array<Message> = [];
129-
for (const message of this.event.messages) {
136+
const eventMessages =
137+
await this.messageHistoryRetriever.getMessageHistory();
138+
for (const message of eventMessages) {
130139
const messageContent: Array<ContentBlock> = [];
131140
for (const contentElement of message.content) {
132141
if (typeof contentElement.image?.source?.bytes === 'string') {

0 commit comments

Comments
 (0)