@@ -318,6 +318,115 @@ describe("AwsBedrockHandler", () => {
318318 } )
319319 } )
320320
321+ describe ( "credential refresh functionality" , ( ) => {
322+ it ( "should refresh client and retry when SSO session expires" , async ( ) => {
323+ // Create a handler with SSO auth
324+ const handler = new AwsBedrockHandler ( {
325+ apiModelId : "anthropic.claude-3-5-sonnet-20241022-v2:0" ,
326+ awsRegion : "us-east-1" ,
327+ awsUseSso : true ,
328+ awsProfile : "test-profile" ,
329+ } )
330+
331+ // Create a spy for the refreshClient method but don't actually call it
332+ const refreshClientSpy = jest . spyOn ( handler as any , "refreshClient" ) . mockImplementation ( ( ) => {
333+ // Do nothing - we're just testing if it gets called
334+ } )
335+
336+ // Mock the client.send method to fail with an SSO error on first call, then succeed on second call
337+ let callCount = 0
338+ const mockSend = jest . fn ( ) . mockImplementation ( ( ) => {
339+ if ( callCount === 0 ) {
340+ callCount ++
341+ throw new Error ( "The SSO session associated with this profile has expired" )
342+ }
343+ return {
344+ output : new TextEncoder ( ) . encode ( JSON . stringify ( { content : "Success after refresh" } ) ) ,
345+ }
346+ } )
347+
348+ handler [ "client" ] = {
349+ send : mockSend ,
350+ } as unknown as BedrockRuntimeClient
351+
352+ // Call completePrompt
353+ const result = await handler . completePrompt ( "Test prompt" )
354+
355+ // Verify that refreshClient was called
356+ expect ( refreshClientSpy ) . toHaveBeenCalledTimes ( 1 )
357+
358+ // Verify that send was called twice (once before refresh, once after)
359+ expect ( mockSend ) . toHaveBeenCalledTimes ( 2 )
360+
361+ // Verify the result
362+ expect ( result ) . toBe ( "Success after refresh" )
363+ } )
364+
365+ it ( "should refresh client and retry when createMessage encounters SSO session expiration" , async ( ) => {
366+ // Create a handler with SSO auth
367+ const handler = new AwsBedrockHandler ( {
368+ apiModelId : "anthropic.claude-3-5-sonnet-20241022-v2:0" ,
369+ awsRegion : "us-east-1" ,
370+ awsUseSso : true ,
371+ awsProfile : "test-profile" ,
372+ } )
373+
374+ // Create a spy for the refreshClient method but don't actually call it
375+ const refreshClientSpy = jest . spyOn ( handler as any , "refreshClient" ) . mockImplementation ( ( ) => {
376+ // Do nothing - we're just testing if it gets called
377+ } )
378+
379+ // Mock the client.send method to fail with an SSO error on first call, then succeed on second call
380+ let callCount = 0
381+ const mockSend = jest . fn ( ) . mockImplementation ( ( ) => {
382+ if ( callCount === 0 ) {
383+ callCount ++
384+ throw new Error ( "The SSO session associated with this profile has expired" )
385+ }
386+ return {
387+ stream : {
388+ [ Symbol . asyncIterator ] : async function * ( ) {
389+ yield {
390+ metadata : {
391+ usage : {
392+ inputTokens : 10 ,
393+ outputTokens : 5 ,
394+ } ,
395+ } ,
396+ }
397+ } ,
398+ } ,
399+ }
400+ } )
401+
402+ handler [ "client" ] = {
403+ send : mockSend ,
404+ } as unknown as BedrockRuntimeClient
405+
406+ // Call createMessage
407+ const stream = handler . createMessage ( "System prompt" , [ { role : "user" , content : "Test message" } ] )
408+ const chunks = [ ]
409+
410+ for await ( const chunk of stream ) {
411+ chunks . push ( chunk )
412+ }
413+
414+ // Verify that refreshClient was called
415+ expect ( refreshClientSpy ) . toHaveBeenCalledTimes ( 1 )
416+
417+ // Verify that send was called twice (once before refresh, once after)
418+ expect ( mockSend ) . toHaveBeenCalledTimes ( 2 )
419+
420+ // Verify the result
421+ expect ( chunks . length ) . toBeGreaterThan ( 0 )
422+ expect ( chunks [ 0 ] ) . toEqual ( {
423+ type : "usage" ,
424+ inputTokens : 10 ,
425+ outputTokens : 5 ,
426+ } )
427+ } )
428+ } )
429+
321430 describe ( "completePrompt" , ( ) => {
322431 it ( "should complete prompt successfully" , async ( ) => {
323432 const mockResponse = {
0 commit comments