1
1
import { describe , it , mock } from 'node:test' ;
2
2
import assert from 'node:assert' ;
3
- import { ConversationTurnEvent , ExecutableTool , ToolDefinition } from './types' ;
3
+ import {
4
+ ConversationMessage ,
5
+ ConversationTurnEvent ,
6
+ ExecutableTool ,
7
+ ToolDefinition ,
8
+ } from './types' ;
4
9
import { BedrockConverseAdapter } from './bedrock_converse_adapter' ;
5
10
import {
6
11
BedrockRuntimeClient ,
@@ -13,22 +18,19 @@ import {
13
18
} from '@aws-sdk/client-bedrock-runtime' ;
14
19
import { ConversationTurnEventToolsProvider } from './event-tools-provider' ;
15
20
import { randomBytes , randomUUID } from 'node:crypto' ;
21
+ import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever' ;
16
22
17
23
void describe ( 'Bedrock converse adapter' , ( ) => {
18
24
const commonEvent : Readonly < ConversationTurnEvent > = {
19
25
conversationId : '' ,
20
26
currentMessageId : '' ,
21
27
graphqlApiEndpoint : '' ,
22
- messages : [
23
- {
24
- role : 'user' ,
25
- content : [
26
- {
27
- text : 'event message' ,
28
- } ,
29
- ] ,
30
- } ,
31
- ] ,
28
+ messageHistoryQuery : {
29
+ getQueryName : '' ,
30
+ getQueryInputTypeName : '' ,
31
+ listQueryName : '' ,
32
+ listQueryInputTypeName : '' ,
33
+ } ,
32
34
modelConfiguration : {
33
35
modelId : 'testModelId' ,
34
36
systemPrompt : 'testSystemPrompt' ,
@@ -46,6 +48,27 @@ void describe('Bedrock converse adapter', () => {
46
48
} ,
47
49
} ;
48
50
51
+ const messages : Array < ConversationMessage > = [
52
+ {
53
+ role : 'user' ,
54
+ content : [
55
+ {
56
+ text : 'event message' ,
57
+ } ,
58
+ ] ,
59
+ } ,
60
+ ] ;
61
+ const messageHistoryRetriever = new ConversationMessageHistoryRetriever (
62
+ commonEvent
63
+ ) ;
64
+ const messageHistoryRetrieverMockGetEventMessages = mock . method (
65
+ messageHistoryRetriever ,
66
+ 'getMessageHistory' ,
67
+ ( ) => {
68
+ return Promise . resolve ( messages ) ;
69
+ }
70
+ ) ;
71
+
49
72
void it ( 'calls bedrock to get conversation response' , async ( ) => {
50
73
const event : ConversationTurnEvent = {
51
74
...commonEvent ,
@@ -78,7 +101,9 @@ void describe('Bedrock converse adapter', () => {
78
101
const responseContent = await new BedrockConverseAdapter (
79
102
event ,
80
103
[ ] ,
81
- bedrockClient
104
+ bedrockClient ,
105
+ undefined ,
106
+ messageHistoryRetriever
82
107
) . askBedrock ( ) ;
83
108
84
109
assert . deepStrictEqual (
@@ -90,7 +115,7 @@ void describe('Bedrock converse adapter', () => {
90
115
const bedrockRequest = bedrockClientSendMock . mock . calls [ 0 ]
91
116
. arguments [ 0 ] as unknown as ConverseCommand ;
92
117
const expectedBedrockInput : ConverseCommandInput = {
93
- messages : event . messages as Array < Message > ,
118
+ messages : messages as Array < Message > ,
94
119
modelId : event . modelConfiguration . modelId ,
95
120
inferenceConfig : event . modelConfiguration . inferenceConfiguration ,
96
121
system : [
@@ -211,7 +236,8 @@ void describe('Bedrock converse adapter', () => {
211
236
event ,
212
237
[ additionalTool ] ,
213
238
bedrockClient ,
214
- eventToolsProvider
239
+ eventToolsProvider ,
240
+ messageHistoryRetriever
215
241
) . askBedrock ( ) ;
216
242
217
243
assert . deepStrictEqual (
@@ -251,7 +277,7 @@ void describe('Bedrock converse adapter', () => {
251
277
const bedrockRequest1 = bedrockClientSendMock . mock . calls [ 0 ]
252
278
. arguments [ 0 ] as unknown as ConverseCommand ;
253
279
const expectedBedrockInput1 : ConverseCommandInput = {
254
- messages : event . messages as Array < Message > ,
280
+ messages : messages as Array < Message > ,
255
281
...expectedBedrockInputCommonProperties ,
256
282
} ;
257
283
assert . deepStrictEqual ( bedrockRequest1 . input , expectedBedrockInput1 ) ;
@@ -264,7 +290,7 @@ void describe('Bedrock converse adapter', () => {
264
290
) ;
265
291
const expectedBedrockInput2 : ConverseCommandInput = {
266
292
messages : [
267
- ...( event . messages as Array < Message > ) ,
293
+ ...( messages as Array < Message > ) ,
268
294
additionalToolUseBedrockResponse . output ?. message ,
269
295
{
270
296
role : 'user' ,
@@ -447,7 +473,9 @@ void describe('Bedrock converse adapter', () => {
447
473
const responseContent = await new BedrockConverseAdapter (
448
474
event ,
449
475
[ tool ] ,
450
- bedrockClient
476
+ bedrockClient ,
477
+ undefined ,
478
+ messageHistoryRetriever
451
479
) . askBedrock ( ) ;
452
480
453
481
assert . deepStrictEqual (
@@ -543,7 +571,9 @@ void describe('Bedrock converse adapter', () => {
543
571
const responseContent = await new BedrockConverseAdapter (
544
572
event ,
545
573
[ tool ] ,
546
- bedrockClient
574
+ bedrockClient ,
575
+ undefined ,
576
+ messageHistoryRetriever
547
577
) . askBedrock ( ) ;
548
578
549
579
assert . deepStrictEqual (
@@ -645,7 +675,9 @@ void describe('Bedrock converse adapter', () => {
645
675
const responseContent = await new BedrockConverseAdapter (
646
676
event ,
647
677
[ additionalTool ] ,
648
- bedrockClient
678
+ bedrockClient ,
679
+ undefined ,
680
+ messageHistoryRetriever
649
681
) . askBedrock ( ) ;
650
682
651
683
assert . deepStrictEqual ( responseContent , [ clientToolUseBlock ] ) ;
@@ -682,7 +714,7 @@ void describe('Bedrock converse adapter', () => {
682
714
const bedrockRequest = bedrockClientSendMock . mock . calls [ 0 ]
683
715
. arguments [ 0 ] as unknown as ConverseCommand ;
684
716
const expectedBedrockInput : ConverseCommandInput = {
685
- messages : event . messages as Array < Message > ,
717
+ messages : messages as Array < Message > ,
686
718
...expectedBedrockInputCommonProperties ,
687
719
} ;
688
720
assert . deepStrictEqual ( bedrockRequest . input , expectedBedrockInput ) ;
@@ -695,21 +727,27 @@ void describe('Bedrock converse adapter', () => {
695
727
696
728
const fakeImagePayload = randomBytes ( 32 ) ;
697
729
698
- event . messages = [
699
- {
700
- role : 'user' ,
701
- content : [
730
+ messageHistoryRetrieverMockGetEventMessages . mock . mockImplementationOnce (
731
+ ( ) => {
732
+ return Promise . resolve ( [
702
733
{
703
- image : {
704
- format : 'png' ,
705
- source : {
706
- bytes : fakeImagePayload . toString ( 'base64' ) ,
734
+ id : '' ,
735
+ conversationId : '' ,
736
+ role : 'user' ,
737
+ content : [
738
+ {
739
+ image : {
740
+ format : 'png' ,
741
+ source : {
742
+ bytes : fakeImagePayload . toString ( 'base64' ) ,
743
+ } ,
744
+ } ,
707
745
} ,
708
- } ,
746
+ ] ,
709
747
} ,
710
- ] ,
711
- } ,
712
- ] ;
748
+ ] ) ;
749
+ }
750
+ ) ;
713
751
714
752
const bedrockClient = new BedrockRuntimeClient ( ) ;
715
753
const bedrockResponse : ConverseCommandOutput = {
@@ -735,7 +773,13 @@ void describe('Bedrock converse adapter', () => {
735
773
Promise . resolve ( bedrockResponse )
736
774
) ;
737
775
738
- await new BedrockConverseAdapter ( event , [ ] , bedrockClient ) . askBedrock ( ) ;
776
+ await new BedrockConverseAdapter (
777
+ event ,
778
+ [ ] ,
779
+ bedrockClient ,
780
+ undefined ,
781
+ messageHistoryRetriever
782
+ ) . askBedrock ( ) ;
739
783
740
784
assert . strictEqual ( bedrockClientSendMock . mock . calls . length , 1 ) ;
741
785
const bedrockRequest = bedrockClientSendMock . mock . calls [ 0 ]
0 commit comments