Skip to content

Commit 7abfeb5

Browse files
committed
refactor: Update AWS credential handling to create new BedrockRuntimeClient instances with refreshed credentials
1 parent d1a8809 commit 7abfeb5

File tree

4 files changed

+139
-57
lines changed

4 files changed

+139
-57
lines changed

src/api/providers/__tests__/bedrock-custom-arn.test.ts

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@ import { AwsBedrockHandler } from "../bedrock"
22
import { ApiHandlerOptions } from "../../../shared/api"
33
import { logger } from "../../../utils/logging"
44

5+
// Mock AWS SDK credential providers
6+
jest.mock("@aws-sdk/credential-providers", () => {
7+
const mockFromIni = jest.fn().mockImplementation(() => {
8+
return async () => ({
9+
accessKeyId: "profile-access-key",
10+
secretAccessKey: "profile-secret-key",
11+
})
12+
})
13+
return { fromIni: mockFromIni }
14+
})
15+
516
// Mock the logger
617
jest.mock("../../../utils/logging", () => ({
718
logger: {
@@ -253,15 +264,21 @@ describe("Bedrock ARN Handling", () => {
253264
})
254265

255266
it("should refresh AWS credentials when they expire", async () => {
256-
// Mock the send method to simulate expired credentials
257-
const mockSend = jest.fn().mockImplementationOnce(async () => {
258-
const error = new Error("The security token included in the request is expired")
259-
error.name = "ExpiredTokenException"
260-
throw error
261-
})
262-
263-
// Mock the BedrockRuntimeClient to use the custom send method
264-
bedrockMock.MockBedrockRuntimeClient.prototype.send = mockSend
267+
// Get the mock send function
268+
const mockSend = bedrockMock.mockSend
269+
270+
// Configure mockSend to throw an error on first call and succeed on second call
271+
mockSend
272+
.mockImplementationOnce(async () => {
273+
const error = new Error("The security token included in the request is expired")
274+
error.name = "ExpiredTokenException"
275+
throw error
276+
})
277+
.mockImplementationOnce(async () => {
278+
return {
279+
output: new TextEncoder().encode(JSON.stringify({ content: "Test response" })),
280+
}
281+
})
265282

266283
// Create a handler with profile-based credentials
267284
const profileHandler = new AwsBedrockHandler({
@@ -271,13 +288,16 @@ describe("Bedrock ARN Handling", () => {
271288
awsRegion: "us-east-1",
272289
})
273290

291+
// Import fromIni after mocking
292+
const { fromIni } = require("@aws-sdk/credential-providers")
293+
274294
// Mock the fromIni method to simulate refreshed credentials
275-
const mockFromIni = jest.fn().mockReturnValue({
276-
accessKeyId: "refreshed-access-key",
277-
secretAccessKey: "refreshed-secret-key",
295+
fromIni.mockImplementation(() => {
296+
return async () => ({
297+
accessKeyId: "refreshed-access-key",
298+
secretAccessKey: "refreshed-secret-key",
299+
})
278300
})
279-
const { fromIni } = require("@aws-sdk/credential-providers")
280-
fromIni.mockImplementation(mockFromIni)
281301

282302
// Attempt to create a message, which should trigger credential refresh
283303
const messageGenerator = profileHandler.createMessage("system prompt", [
@@ -294,7 +314,7 @@ describe("Bedrock ARN Handling", () => {
294314
}
295315

296316
// Verify that fromIni was called to refresh credentials
297-
expect(mockFromIni).toHaveBeenCalledWith({ profile: "test-profile" })
317+
expect(fromIni).toHaveBeenCalledWith({ profile: "test-profile" })
298318

299319
// Verify that the send method was called twice (once for the initial attempt and once after refresh)
300320
expect(mockSend).toHaveBeenCalledTimes(2)

src/api/providers/__tests__/bedrock-invokedModelId.test.ts

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import { AwsBedrockHandler } from "../bedrock"
22
import { ApiHandlerOptions } from "../../../shared/api"
33
import { BedrockRuntimeClient } from "@aws-sdk/client-bedrock-runtime"
4-
const { fromIni } = require("@aws-sdk/credential-providers")
5-
64
// Mock AWS SDK credential providers and Bedrock client
7-
jest.mock("@aws-sdk/credential-providers", () => ({
8-
fromIni: jest.fn().mockReturnValue({
9-
accessKeyId: "profile-access-key",
10-
secretAccessKey: "profile-secret-key",
11-
}),
12-
}))
5+
jest.mock("@aws-sdk/credential-providers", () => {
6+
const mockFromIni = jest.fn().mockImplementation(() => {
7+
return async () => ({
8+
accessKeyId: "profile-access-key",
9+
secretAccessKey: "profile-secret-key",
10+
})
11+
})
12+
return { fromIni: mockFromIni }
13+
})
14+
15+
const { fromIni } = require("@aws-sdk/credential-providers")
1316

1417
// Mock Smithy client
1518
jest.mock("@smithy/smithy-client", () => ({
@@ -78,7 +81,7 @@ describe("AwsBedrockHandler with invokedModelId", () => {
7881
})
7982

8083
// Helper function to create a mock async iterable stream
81-
function createMockStream(events: StreamEvent[]) {
84+
function createMockStream(events: any[]) {
8285
return {
8386
[Symbol.asyncIterator]: async function* () {
8487
for (const event of events) {
@@ -363,15 +366,37 @@ describe("AwsBedrockHandler with invokedModelId", () => {
363366
})
364367

365368
it("should refresh AWS credentials when they expire", async () => {
366-
// Mock the send method to simulate expired credentials
367-
const mockSend = jest.fn().mockImplementationOnce(async () => {
368-
const error = new Error("The security token included in the request is expired")
369-
error.name = "ExpiredTokenException"
370-
throw error
371-
})
372-
373-
// Mock the BedrockRuntimeClient to use the custom send method
374-
BedrockRuntimeClient.prototype.send = mockSend
369+
// Get the mock send function from our mocked module
370+
const { BedrockRuntimeClient } = require("@aws-sdk/client-bedrock-runtime")
371+
const mockSend = BedrockRuntimeClient().send
372+
373+
// Configure mockSend to throw an error on first call and succeed on second call
374+
mockSend
375+
.mockImplementationOnce(async () => {
376+
const error = new Error("The security token included in the request is expired")
377+
error.name = "ExpiredTokenException"
378+
throw error
379+
})
380+
.mockImplementationOnce(async () => {
381+
return {
382+
$metadata: {
383+
httpStatusCode: 200,
384+
requestId: "mock-request-id",
385+
},
386+
stream: {
387+
[Symbol.asyncIterator]: async function* () {
388+
yield {
389+
metadata: {
390+
usage: {
391+
inputTokens: 100,
392+
outputTokens: 200,
393+
},
394+
},
395+
}
396+
},
397+
},
398+
}
399+
})
375400

376401
// Create a handler with profile-based credentials
377402
const profileHandler = new AwsBedrockHandler({
@@ -382,11 +407,12 @@ describe("AwsBedrockHandler with invokedModelId", () => {
382407
})
383408

384409
// Mock the fromIni method to simulate refreshed credentials
385-
const mockFromIni = jest.fn().mockReturnValue({
386-
accessKeyId: "refreshed-access-key",
387-
secretAccessKey: "refreshed-secret-key",
410+
fromIni.mockImplementation(() => {
411+
return async () => ({
412+
accessKeyId: "refreshed-access-key",
413+
secretAccessKey: "refreshed-secret-key",
414+
})
388415
})
389-
fromIni.mockImplementation(mockFromIni)
390416

391417
// Attempt to create a message, which should trigger credential refresh
392418
const messageGenerator = profileHandler.createMessage("system prompt", [
@@ -403,7 +429,7 @@ describe("AwsBedrockHandler with invokedModelId", () => {
403429
}
404430

405431
// Verify that fromIni was called to refresh credentials
406-
expect(mockFromIni).toHaveBeenCalledWith({ profile: "test-profile" })
432+
expect(fromIni).toHaveBeenCalledWith({ profile: "test-profile" })
407433

408434
// Verify that the send method was called twice (once for the initial attempt and once after refresh)
409435
expect(mockSend).toHaveBeenCalledTimes(2)

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import { logger } from "../../../utils/logging"
77

88
// Mock AWS SDK credential providers
99
jest.mock("@aws-sdk/credential-providers", () => {
10-
const mockFromIni = jest.fn().mockReturnValue({
11-
accessKeyId: "profile-access-key",
12-
secretAccessKey: "profile-secret-key",
10+
const mockFromIni = jest.fn().mockImplementation(() => {
11+
return async () => ({
12+
accessKeyId: "profile-access-key",
13+
secretAccessKey: "profile-secret-key",
14+
})
1315
})
1416
return { fromIni: mockFromIni }
1517
})
@@ -171,15 +173,37 @@ describe("AwsBedrockHandler", () => {
171173

172174
describe("AWS Credential Refresh", () => {
173175
it("should refresh AWS credentials when they expire", async () => {
174-
// Mock the send method to simulate expired credentials
175-
const mockSend = jest.fn().mockImplementationOnce(async () => {
176-
const error = new Error("The security token included in the request is expired")
177-
error.name = "ExpiredTokenException"
178-
throw error
179-
})
180-
181-
// Mock the BedrockRuntimeClient to use the custom send method
182-
BedrockRuntimeClient.prototype.send = mockSend
176+
// Get the mock send function from our mocked module
177+
const { BedrockRuntimeClient } = require("@aws-sdk/client-bedrock-runtime")
178+
const mockSend = BedrockRuntimeClient().send
179+
180+
// Configure mockSend to throw an error on first call and succeed on second call
181+
mockSend
182+
.mockImplementationOnce(async () => {
183+
const error = new Error("The security token included in the request is expired")
184+
error.name = "ExpiredTokenException"
185+
throw error
186+
})
187+
.mockImplementationOnce(async () => {
188+
return {
189+
$metadata: {
190+
httpStatusCode: 200,
191+
requestId: "mock-request-id",
192+
},
193+
stream: {
194+
[Symbol.asyncIterator]: async function* () {
195+
yield {
196+
metadata: {
197+
usage: {
198+
inputTokens: 100,
199+
outputTokens: 200,
200+
},
201+
},
202+
}
203+
},
204+
},
205+
}
206+
})
183207

184208
// Create a handler with profile-based credentials
185209
const profileHandler = new AwsBedrockHandler({
@@ -190,11 +214,12 @@ describe("AwsBedrockHandler", () => {
190214
})
191215

192216
// Mock the fromIni method to simulate refreshed credentials
193-
const mockFromIni = jest.fn().mockReturnValue({
194-
accessKeyId: "refreshed-access-key",
195-
secretAccessKey: "refreshed-secret-key",
217+
fromIni.mockImplementation(() => {
218+
return async () => ({
219+
accessKeyId: "refreshed-access-key",
220+
secretAccessKey: "refreshed-secret-key",
221+
})
196222
})
197-
fromIni.mockImplementation(mockFromIni)
198223

199224
// Attempt to create a message, which should trigger credential refresh
200225
const messageGenerator = profileHandler.createMessage("system prompt", [
@@ -211,7 +236,7 @@ describe("AwsBedrockHandler", () => {
211236
}
212237

213238
// Verify that fromIni was called to refresh credentials
214-
expect(mockFromIni).toHaveBeenCalledWith({ profile: "test-profile" })
239+
expect(fromIni).toHaveBeenCalledWith({ profile: "test-profile" })
215240

216241
// Verify that the send method was called twice (once for the initial attempt and once after refresh)
217242
expect(mockSend).toHaveBeenCalledTimes(2)

src/api/providers/bedrock.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,13 +1132,24 @@ Suggestions:
11321132
const refreshedCredentials = await fromIni({
11331133
profile: this.options.awsProfile,
11341134
})()
1135-
this.client.config.credentials = refreshedCredentials
1135+
// Create a new client with the refreshed credentials
1136+
const clientConfig: BedrockRuntimeClientConfig = {
1137+
region: this.options.awsRegion,
1138+
credentials: refreshedCredentials,
1139+
}
1140+
this.client = new BedrockRuntimeClient(clientConfig)
11361141
} else if (this.options.awsAccessKey && this.options.awsSecretKey) {
1137-
this.client.config.credentials = {
1142+
const newCredentials = {
11381143
accessKeyId: this.options.awsAccessKey,
11391144
secretAccessKey: this.options.awsSecretKey,
11401145
...(this.options.awsSessionToken ? { sessionToken: this.options.awsSessionToken } : {}),
11411146
}
1147+
// Create a new client with the new credentials
1148+
const clientConfig: BedrockRuntimeClientConfig = {
1149+
region: this.options.awsRegion,
1150+
credentials: newCredentials,
1151+
}
1152+
this.client = new BedrockRuntimeClient(clientConfig)
11421153
}
11431154
}
11441155
}

0 commit comments

Comments
 (0)