Skip to content

Commit 921fdad

Browse files
dbanksdesign and sobolk authored
Add document support in AI kit (#2566)
* Add document support in AI kit * Update packages/ai-constructs/src/conversation/runtime/types.ts Co-authored-by: Kamil Sobol <[email protected]> * adding e2e tests * adding docx to dictionary * fixing e2e test * Update packages/integration-tests/src/test-project-setup/conversation_handler_project.ts Co-authored-by: Kamil Sobol <[email protected]> --------- Co-authored-by: Kamil Sobol <[email protected]>
1 parent fad46a4 commit 921fdad

File tree

7 files changed

+200
-0
lines changed

7 files changed

+200
-0
lines changed

.changeset/famous-cougars-fold.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@aws-amplify/ai-constructs': minor
3+
---
4+
5+
Adding document support for ai conversation routes

.eslint_dictionary.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"deprecator",
5050
"deserializer",
5151
"disambiguator",
52+
"docx",
5253
"downlevel",
5354
"durations",
5455
"dynamodb",

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,72 @@ void describe('Bedrock converse adapter', () => {
824824
},
825825
]);
826826
});
827+
828+
void it('decodes base64 encoded documents', async () => {
829+
const event: ConversationTurnEvent = {
830+
...commonEvent,
831+
};
832+
833+
const fakeDocumentPayload = randomBytes(32);
834+
835+
messageHistoryRetrieverMockGetEventMessages.mock.mockImplementationOnce(
836+
() => {
837+
return Promise.resolve([
838+
{
839+
id: '',
840+
conversationId: '',
841+
role: 'user',
842+
content: [
843+
{
844+
document: {
845+
name: 'test',
846+
format: 'doc',
847+
source: {
848+
bytes: fakeDocumentPayload.toString('base64'),
849+
},
850+
},
851+
},
852+
],
853+
},
854+
]);
855+
}
856+
);
857+
858+
const bedrockClient = new BedrockRuntimeClient();
859+
const content = [{ text: 'block1' }, { text: 'block2' }];
860+
const bedrockResponse = mockBedrockResponse(content, streamResponse);
861+
const bedrockClientSendMock = mock.method(bedrockClient, 'send', () =>
862+
Promise.resolve(bedrockResponse)
863+
);
864+
865+
await new BedrockConverseAdapter(
866+
event,
867+
[],
868+
bedrockClient,
869+
undefined,
870+
messageHistoryRetriever
871+
).askBedrock();
872+
873+
assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1);
874+
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
875+
.arguments[0] as unknown as ConverseCommand;
876+
assert.deepStrictEqual(bedrockRequest.input.messages, [
877+
{
878+
role: 'user',
879+
content: [
880+
{
881+
document: {
882+
format: 'doc',
883+
name: 'test',
884+
source: {
885+
bytes: fakeDocumentPayload,
886+
},
887+
},
888+
},
889+
],
890+
},
891+
]);
892+
});
827893
});
828894
});
829895

packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,18 @@ export class BedrockConverseAdapter {
385385
},
386386
},
387387
});
388+
} else if (typeof contentElement.document?.source?.bytes === 'string') {
389+
messageContent.push({
390+
document: {
391+
...contentElement.document,
392+
source: {
393+
bytes: Buffer.from(
394+
contentElement.document.source.bytes,
395+
'base64'
396+
),
397+
},
398+
},
399+
});
388400
} else {
389401
// Otherwise type conforms to Bedrock's type and it's safe to cast.
390402
messageContent.push(contentElement as ContentBlock);

packages/ai-constructs/src/conversation/runtime/types.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ export type ConversationMessage = {
2121

2222
export type ConversationMessageContentBlock =
2323
| bedrock.ContentBlock
24+
| {
25+
image?: never;
26+
// These are needed so that union with other content block types works.
27+
// See https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/Package/-aws-sdk-client-bedrock-runtime/TypeAlias/ContentBlock/.
28+
text?: never;
29+
document: Omit<bedrock.DocumentBlock, 'source'> & {
30+
// Upstream (Appsync) may send documents in a form of Base64 encoded strings
31+
source: { bytes: string };
32+
};
33+
toolUse?: never;
34+
toolResult?: never;
35+
guardContent?: never;
36+
$unknown?: never;
37+
}
2438
| {
2539
image: Omit<bedrock.ImageBlock, 'source'> & {
2640
// Upstream (Appsync) may send images in a form of Base64 encoded strings

packages/integration-tests/src/test-project-setup/conversation_handler_project.ts

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ type ConversationMessage = {
6060

6161
type ConversationMessageContentBlock =
6262
| bedrock.ContentBlock
63+
| {
64+
image?: never;
65+
// These are needed so that union with other content block types works.
66+
// See https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/Package/-aws-sdk-client-bedrock-runtime/TypeAlias/ContentBlock/.
67+
text?: never;
68+
document: Omit<bedrock.DocumentBlock, 'source'> & {
69+
// Upstream (Appsync) may send documents in a form of Base64 encoded strings
70+
source: { bytes: string };
71+
};
72+
toolUse?: never;
73+
toolResult?: never;
74+
guardContent?: never;
75+
$unknown?: never;
76+
}
6377
| {
6478
image: Omit<bedrock.ImageBlock, 'source'> & {
6579
// Upstream (Appsync) may send images in a form of Base64 encoded strings
@@ -352,6 +366,26 @@ class ConversationHandlerTestProject extends TestProjectBase {
352366
)
353367
);
354368

369+
await this.executeWithRetry(() =>
370+
this.assertDefaultConversationHandlerCanExecuteTurnWithDocument(
371+
backendId,
372+
authenticatedUserCredentials.accessToken,
373+
dataUrl,
374+
apolloClient,
375+
false
376+
)
377+
);
378+
379+
await this.executeWithRetry(() =>
380+
this.assertDefaultConversationHandlerCanExecuteTurnWithDocument(
381+
backendId,
382+
authenticatedUserCredentials.accessToken,
383+
dataUrl,
384+
apolloClient,
385+
true
386+
)
387+
);
388+
355389
await this.executeWithRetry((attempt) =>
356390
this.assertDefaultConversationHandlerCanPropagateError(
357391
backendId,
@@ -496,6 +530,74 @@ class ConversationHandlerTestProject extends TestProjectBase {
496530
assert.match(response.content, /(aws)|(AWS)|(Amazon Web Services)/);
497531
};
498532

533+
private assertDefaultConversationHandlerCanExecuteTurnWithDocument = async (
534+
backendId: BackendIdentifier,
535+
accessToken: string,
536+
graphqlApiEndpoint: string,
537+
apolloClient: ApolloClient<NormalizedCacheObject>,
538+
streamResponse: boolean
539+
): Promise<void> => {
540+
const defaultConversationHandlerFunction = (
541+
await this.resourceFinder.findByBackendIdentifier(
542+
backendId,
543+
'AWS::Lambda::Function',
544+
(name) => name.includes('default')
545+
)
546+
)[0];
547+
548+
const documentPath = resolve(
549+
fileURLToPath(import.meta.url),
550+
'..',
551+
'..',
552+
'..',
553+
'src',
554+
'test-projects',
555+
'conversation-handler',
556+
'resources',
557+
'sample-document.docx'
558+
);
559+
560+
const documentSource = await fs.readFile(documentPath, 'base64');
561+
562+
const message: CreateConversationMessageChatInput = {
563+
id: randomUUID().toString(),
564+
conversationId: randomUUID().toString(),
565+
role: 'user',
566+
content: [
567+
{
568+
text: 'What is in the attached document?',
569+
},
570+
{
571+
document: {
572+
format: 'docx',
573+
name: 'sample-document',
574+
source: { bytes: documentSource },
575+
},
576+
},
577+
],
578+
};
579+
580+
// send event
581+
const event: ConversationTurnEvent = {
582+
conversationId: message.conversationId,
583+
currentMessageId: message.id,
584+
graphqlApiEndpoint: graphqlApiEndpoint,
585+
request: {
586+
headers: { authorization: accessToken },
587+
},
588+
...this.getCommonEventProperties(streamResponse),
589+
};
590+
await this.insertMessage(apolloClient, message);
591+
const response = await this.executeConversationTurn(
592+
event,
593+
defaultConversationHandlerFunction,
594+
apolloClient
595+
);
596+
// The document contains a hello world string. Responses may vary, but they should always contain statements below.
597+
assert.match(response.content, /document/);
598+
assert.match(response.content, /(H|h)ello (W|w)orld/);
599+
};
600+
499601
private assertDefaultConversationHandlerCanExecuteTurnWithDataTool = async (
500602
backendId: BackendIdentifier,
501603
accessToken: string,
12.9 KB
Binary file not shown.

0 commit comments

Comments
 (0)