Skip to content

Commit d1a8809

Browse files
committed
Fix AWS credentials caching issue
Fixes #2469 Add logic to refresh AWS credentials when they expire. * **src/api/providers/bedrock.ts** - Add logic to refresh AWS credentials using `fromIni` from `@aws-sdk/credential-providers` when a request fails due to expired credentials. - Update `createMessage` and `completePrompt` methods to handle credential refresh. - Add a private method `refreshCredentials` to handle the actual credential refresh process. * **src/api/providers/__tests__/bedrock.test.ts** - Add tests to verify that refreshed credentials are automatically detected and used without requiring a restart. * **src/api/providers/__tests__/bedrock-invokedModelId.test.ts** - Add tests for detecting and using refreshed credentials. * **src/api/providers/__tests__/bedrock-custom-arn.test.ts** - Add tests for detecting and using refreshed credentials. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/RooVetGit/Roo-Code/issues/2469?shareId=XXXX-XXXX-XXXX-XXXX).
1 parent 5352beb commit d1a8809

File tree

4 files changed

+414
-26
lines changed

4 files changed

+414
-26
lines changed

src/api/providers/__tests__/bedrock-custom-arn.test.ts

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,55 @@ describe("Bedrock ARN Handling", () => {
249249
arnRegion: "eu-west-1",
250250
}),
251251
)
252+
infoSpy.mockRestore()
253+
})
254+
255+
it("should refresh AWS credentials when they expire", async () => {
256+
// Mock the send method to simulate expired credentials
257+
const mockSend = jest.fn().mockImplementationOnce(async () => {
258+
const error = new Error("The security token included in the request is expired")
259+
error.name = "ExpiredTokenException"
260+
throw error
261+
})
262+
263+
// Mock the BedrockRuntimeClient to use the custom send method
264+
bedrockMock.MockBedrockRuntimeClient.prototype.send = mockSend
265+
266+
// Create a handler with profile-based credentials
267+
const profileHandler = new AwsBedrockHandler({
268+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
269+
awsProfile: "test-profile",
270+
awsUseProfile: true,
271+
awsRegion: "us-east-1",
272+
})
273+
274+
// Mock the fromIni method to simulate refreshed credentials
275+
const mockFromIni = jest.fn().mockReturnValue({
276+
accessKeyId: "refreshed-access-key",
277+
secretAccessKey: "refreshed-secret-key",
278+
})
279+
const { fromIni } = require("@aws-sdk/credential-providers")
280+
fromIni.mockImplementation(mockFromIni)
281+
282+
// Attempt to create a message, which should trigger credential refresh
283+
const messageGenerator = profileHandler.createMessage("system prompt", [
284+
{ role: "user", content: "user message" },
285+
])
286+
287+
// Consume the generator to trigger the send method
288+
try {
289+
for await (const _ of messageGenerator) {
290+
// Just consume the messages
291+
}
292+
} catch (error) {
293+
// Ignore errors for this test
294+
}
295+
296+
// Verify that fromIni was called to refresh credentials
297+
expect(mockFromIni).toHaveBeenCalledWith({ profile: "test-profile" })
298+
299+
// Verify that the send method was called twice (once for the initial attempt and once after refresh)
300+
expect(mockSend).toHaveBeenCalledTimes(2)
252301
})
253302
})
254303
})

src/api/providers/__tests__/bedrock-invokedModelId.test.ts

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import { AwsBedrockHandler } from "../bedrock"
2+
import { ApiHandlerOptions } from "../../../shared/api"
3+
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
4+
const { fromIni } = require("@aws-sdk/credential-providers")
5+
16
// Mock AWS SDK credential providers and Bedrock client
27
jest.mock("@aws-sdk/credential-providers", () => ({
38
fromIni: jest.fn().mockReturnValue({
@@ -62,11 +67,6 @@ jest.mock("@aws-sdk/client-bedrock-runtime", () => {
6267
}
6368
})
6469

65-
import { AwsBedrockHandler, StreamEvent } from "../bedrock"
66-
import { ApiHandlerOptions } from "../../../shared/api"
67-
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
68-
const { fromIni } = require("@aws-sdk/credential-providers")
69-
7070
describe("AwsBedrockHandler with invokedModelId", () => {
7171
let mockSend: jest.Mock
7272

@@ -361,4 +361,51 @@ describe("AwsBedrockHandler with invokedModelId", () => {
361361
const costModel = handler.getModel()
362362
expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0")
363363
})
364+
365+
it("should refresh AWS credentials when they expire", async () => {
366+
// Mock the send method to simulate expired credentials
367+
const mockSend = jest.fn().mockImplementationOnce(async () => {
368+
const error = new Error("The security token included in the request is expired")
369+
error.name = "ExpiredTokenException"
370+
throw error
371+
})
372+
373+
// Mock the BedrockRuntimeClient to use the custom send method
374+
BedrockRuntimeClient.prototype.send = mockSend
375+
376+
// Create a handler with profile-based credentials
377+
const profileHandler = new AwsBedrockHandler({
378+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
379+
awsProfile: "test-profile",
380+
awsUseProfile: true,
381+
awsRegion: "us-east-1",
382+
})
383+
384+
// Mock the fromIni method to simulate refreshed credentials
385+
const mockFromIni = jest.fn().mockReturnValue({
386+
accessKeyId: "refreshed-access-key",
387+
secretAccessKey: "refreshed-secret-key",
388+
})
389+
fromIni.mockImplementation(mockFromIni)
390+
391+
// Attempt to create a message, which should trigger credential refresh
392+
const messageGenerator = profileHandler.createMessage("system prompt", [
393+
{ role: "user", content: "user message" },
394+
])
395+
396+
// Consume the generator to trigger the send method
397+
try {
398+
for await (const _ of messageGenerator) {
399+
// Just consume the messages
400+
}
401+
} catch (error) {
402+
// Ignore errors for this test
403+
}
404+
405+
// Verify that fromIni was called to refresh credentials
406+
expect(mockFromIni).toHaveBeenCalledWith({ profile: "test-profile" })
407+
408+
// Verify that the send method was called twice (once for the initial attempt and once after refresh)
409+
expect(mockSend).toHaveBeenCalledTimes(2)
410+
})
364411
})

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 106 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
import { AwsBedrockHandler } from "../bedrock"
2+
import { MessageContent } from "../../../shared/api"
3+
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
4+
import { Anthropic } from "@anthropic-ai/sdk"
5+
const { fromIni } = require("@aws-sdk/credential-providers")
6+
import { logger } from "../../../utils/logging"
7+
18
// Mock AWS SDK credential providers
29
jest.mock("@aws-sdk/credential-providers", () => {
310
const mockFromIni = jest.fn().mockReturnValue({
@@ -7,12 +14,56 @@ jest.mock("@aws-sdk/credential-providers", () => {
714
return { fromIni: mockFromIni }
815
})
916

10-
import { AwsBedrockHandler } from "../bedrock"
11-
import { MessageContent } from "../../../shared/api"
12-
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
13-
import { Anthropic } from "@anthropic-ai/sdk"
14-
const { fromIni } = require("@aws-sdk/credential-providers")
15-
import { logger } from "../../../utils/logging"
17+
// Mock AWS SDK modules
18+
jest.mock("@aws-sdk/client-bedrock-runtime", () => {
19+
const mockSend = jest.fn().mockImplementation(async () => {
20+
return {
21+
$metadata: {
22+
httpStatusCode: 200,
23+
requestId: "mock-request-id",
24+
},
25+
stream: {
26+
[Symbol.asyncIterator]: async function* () {
27+
yield {
28+
metadata: {
29+
usage: {
30+
inputTokens: 100,
31+
outputTokens: 200,
32+
},
33+
},
34+
}
35+
},
36+
},
37+
}
38+
})
39+
40+
return {
41+
BedrockRuntimeClient: jest.fn().mockImplementation(() => ({
42+
send: mockSend,
43+
config: { region: "us-east-1" },
44+
middlewareStack: {
45+
clone: () => ({ resolve: () => {} }),
46+
use: () => {},
47+
},
48+
})),
49+
ConverseStreamCommand: jest.fn((params) => ({
50+
...params,
51+
input: params,
52+
middlewareStack: {
53+
clone: () => ({ resolve: () => {} }),
54+
use: () => {},
55+
},
56+
})),
57+
ConverseCommand: jest.fn((params) => ({
58+
...params,
59+
input: params,
60+
middlewareStack: {
61+
clone: () => ({ resolve: () => {} }),
62+
use: () => {},
63+
},
64+
})),
65+
}
66+
})
1667

1768
describe("AwsBedrockHandler", () => {
1869
let handler: AwsBedrockHandler
@@ -117,4 +168,53 @@ describe("AwsBedrockHandler", () => {
117168
expect(modelInfo.info.maxTokens).toBe(4096)
118169
})
119170
})
171+
172+
describe("AWS Credential Refresh", () => {
173+
it("should refresh AWS credentials when they expire", async () => {
174+
// Mock the send method to simulate expired credentials
175+
const mockSend = jest.fn().mockImplementationOnce(async () => {
176+
const error = new Error("The security token included in the request is expired")
177+
error.name = "ExpiredTokenException"
178+
throw error
179+
})
180+
181+
// Mock the BedrockRuntimeClient to use the custom send method
182+
BedrockRuntimeClient.prototype.send = mockSend
183+
184+
// Create a handler with profile-based credentials
185+
const profileHandler = new AwsBedrockHandler({
186+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
187+
awsProfile: "test-profile",
188+
awsUseProfile: true,
189+
awsRegion: "us-east-1",
190+
})
191+
192+
// Mock the fromIni method to simulate refreshed credentials
193+
const mockFromIni = jest.fn().mockReturnValue({
194+
accessKeyId: "refreshed-access-key",
195+
secretAccessKey: "refreshed-secret-key",
196+
})
197+
fromIni.mockImplementation(mockFromIni)
198+
199+
// Attempt to create a message, which should trigger credential refresh
200+
const messageGenerator = profileHandler.createMessage("system prompt", [
201+
{ role: "user", content: "user message" },
202+
])
203+
204+
// Consume the generator to trigger the send method
205+
try {
206+
for await (const _ of messageGenerator) {
207+
// Just consume the messages
208+
}
209+
} catch (error) {
210+
// Ignore errors for this test
211+
}
212+
213+
// Verify that fromIni was called to refresh credentials
214+
expect(mockFromIni).toHaveBeenCalledWith({ profile: "test-profile" })
215+
216+
// Verify that the send method was called twice (once for the initial attempt and once after refresh)
217+
expect(mockSend).toHaveBeenCalledTimes(2)
218+
})
219+
})
120220
})

0 commit comments

Comments
 (0)