Skip to content

Commit fc33123

Browse files
auto refresh AWS SSO sesthe AWS provider will automatically refresh its credentials and retry the request without requiring a restart of Visual Studio Code
1 parent 4e623bc commit fc33123

File tree

2 files changed

+295
-123
lines changed

2 files changed

+295
-123
lines changed

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,115 @@ describe("AwsBedrockHandler", () => {
318318
})
319319
})
320320

321+
describe("credential refresh functionality", () => {
322+
it("should refresh client and retry when SSO session expires", async () => {
323+
// Create a handler with SSO auth
324+
const handler = new AwsBedrockHandler({
325+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
326+
awsRegion: "us-east-1",
327+
awsUseSso: true,
328+
awsProfile: "test-profile",
329+
})
330+
331+
// Create a spy for the refreshClient method but don't actually call it
332+
const refreshClientSpy = jest.spyOn(handler as any, "refreshClient").mockImplementation(() => {
333+
// Do nothing - we're just testing if it gets called
334+
})
335+
336+
// Mock the client.send method to fail with an SSO error on first call, then succeed on second call
337+
let callCount = 0
338+
const mockSend = jest.fn().mockImplementation(() => {
339+
if (callCount === 0) {
340+
callCount++
341+
throw new Error("The SSO session associated with this profile has expired")
342+
}
343+
return {
344+
output: new TextEncoder().encode(JSON.stringify({ content: "Success after refresh" })),
345+
}
346+
})
347+
348+
handler["client"] = {
349+
send: mockSend,
350+
} as unknown as BedrockRuntimeClient
351+
352+
// Call completePrompt
353+
const result = await handler.completePrompt("Test prompt")
354+
355+
// Verify that refreshClient was called
356+
expect(refreshClientSpy).toHaveBeenCalledTimes(1)
357+
358+
// Verify that send was called twice (once before refresh, once after)
359+
expect(mockSend).toHaveBeenCalledTimes(2)
360+
361+
// Verify the result
362+
expect(result).toBe("Success after refresh")
363+
})
364+
365+
it("should refresh client and retry when createMessage encounters SSO session expiration", async () => {
366+
// Create a handler with SSO auth
367+
const handler = new AwsBedrockHandler({
368+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
369+
awsRegion: "us-east-1",
370+
awsUseSso: true,
371+
awsProfile: "test-profile",
372+
})
373+
374+
// Create a spy for the refreshClient method but don't actually call it
375+
const refreshClientSpy = jest.spyOn(handler as any, "refreshClient").mockImplementation(() => {
376+
// Do nothing - we're just testing if it gets called
377+
})
378+
379+
// Mock the client.send method to fail with an SSO error on first call, then succeed on second call
380+
let callCount = 0
381+
const mockSend = jest.fn().mockImplementation(() => {
382+
if (callCount === 0) {
383+
callCount++
384+
throw new Error("The SSO session associated with this profile has expired")
385+
}
386+
return {
387+
stream: {
388+
[Symbol.asyncIterator]: async function* () {
389+
yield {
390+
metadata: {
391+
usage: {
392+
inputTokens: 10,
393+
outputTokens: 5,
394+
},
395+
},
396+
}
397+
},
398+
},
399+
}
400+
})
401+
402+
handler["client"] = {
403+
send: mockSend,
404+
} as unknown as BedrockRuntimeClient
405+
406+
// Call createMessage
407+
const stream = handler.createMessage("System prompt", [{ role: "user", content: "Test message" }])
408+
const chunks = []
409+
410+
for await (const chunk of stream) {
411+
chunks.push(chunk)
412+
}
413+
414+
// Verify that refreshClient was called
415+
expect(refreshClientSpy).toHaveBeenCalledTimes(1)
416+
417+
// Verify that send was called twice (once before refresh, once after)
418+
expect(mockSend).toHaveBeenCalledTimes(2)
419+
420+
// Verify the result
421+
expect(chunks.length).toBeGreaterThan(0)
422+
expect(chunks[0]).toEqual({
423+
type: "usage",
424+
inputTokens: 10,
425+
outputTokens: 5,
426+
})
427+
})
428+
})
429+
321430
describe("completePrompt", () => {
322431
it("should complete prompt successfully", async () => {
323432
const mockResponse = {

0 commit comments

Comments
 (0)