@@ -26,6 +26,7 @@ import assert from 'assert';
26
26
import { NormalizedCacheObject } from '@apollo/client' ;
27
27
import {
28
28
bedrockModelId ,
29
+ expectedRandomNumber ,
29
30
expectedTemperatureInDataToolScenario ,
30
31
expectedTemperaturesInProgrammaticToolScenario ,
31
32
} from '../test-projects/conversation-handler/amplify/constants.js' ;
@@ -278,6 +279,16 @@ class ConversationHandlerTestProject extends TestProjectBase {
278
279
)
279
280
) ;
280
281
282
+ await this . executeWithRetry ( ( ) =>
283
+ this . assertCustomConversationHandlerCanExecuteTurnWithParameterLessTool (
284
+ backendId ,
285
+ authenticatedUserCredentials . accessToken ,
286
+ dataUrl ,
287
+ apolloClient ,
288
+ true
289
+ )
290
+ ) ;
291
+
281
292
await this . executeWithRetry ( ( ) =>
282
293
this . assertDefaultConversationHandlerCanExecuteTurnWithDataTool (
283
294
backendId ,
@@ -853,6 +864,56 @@ class ConversationHandlerTestProject extends TestProjectBase {
853
864
) ;
854
865
} ;
855
866
867
+ private assertCustomConversationHandlerCanExecuteTurnWithParameterLessTool =
868
+ async (
869
+ backendId : BackendIdentifier ,
870
+ accessToken : string ,
871
+ graphqlApiEndpoint : string ,
872
+ apolloClient : ApolloClient < NormalizedCacheObject > ,
873
+ streamResponse : boolean
874
+ ) : Promise < void > => {
875
+ const customConversationHandlerFunction = (
876
+ await this . resourceFinder . findByBackendIdentifier (
877
+ backendId ,
878
+ 'AWS::Lambda::Function' ,
879
+ ( name ) => name . includes ( 'custom' )
880
+ )
881
+ ) [ 0 ] ;
882
+
883
+ const message : CreateConversationMessageChatInput = {
884
+ conversationId : randomUUID ( ) . toString ( ) ,
885
+ id : randomUUID ( ) . toString ( ) ,
886
+ role : 'user' ,
887
+ content : [
888
+ {
889
+ text : 'Give me a random number' ,
890
+ } ,
891
+ ] ,
892
+ } ;
893
+ await this . insertMessage ( apolloClient , message ) ;
894
+
895
+ // send event
896
+ const event : ConversationTurnEvent = {
897
+ conversationId : message . conversationId ,
898
+ currentMessageId : message . id ,
899
+ graphqlApiEndpoint : graphqlApiEndpoint ,
900
+ request : {
901
+ headers : { authorization : accessToken } ,
902
+ } ,
903
+ ...this . getCommonEventProperties ( streamResponse ) ,
904
+ } ;
905
+ const response = await this . executeConversationTurn (
906
+ event ,
907
+ customConversationHandlerFunction ,
908
+ apolloClient
909
+ ) ;
910
+ // Assert that tool was used. I.e. LLM used value provided by the tool.
911
+ assert . match (
912
+ response . content ,
913
+ new RegExp ( expectedRandomNumber . toString ( ) )
914
+ ) ;
915
+ } ;
916
+
856
917
private assertDefaultConversationHandlerCanPropagateError = async (
857
918
backendId : BackendIdentifier ,
858
919
accessToken : string ,
0 commit comments