Skip to content

Commit 0dabea4

Browse files
Fix: Enable cross-region inference for 'ap-xx' region in AwsBedrockHandler.completePrompt
1 parent f1efeca commit 0dabea4

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ describe("AwsBedrockHandler", () => {
260260
expect(result).toBe("")
261261
})
262262

263-
it("should handle cross-region inference", async () => {
263+
it("should handle cross-region inference for us-xx region", async () => {
264264
handler = new AwsBedrockHandler({
265265
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
266266
awsAccessKey: "test-access-key",
@@ -292,6 +292,105 @@ describe("AwsBedrockHandler", () => {
292292
}),
293293
)
294294
})
295+
296+
it("should handle cross-region inference for eu-xx region", async () => {
297+
handler = new AwsBedrockHandler({
298+
apiModelId: "anthropic.claude-3-5-sonnet-20240620-v1:0",
299+
awsAccessKey: "test-access-key",
300+
awsSecretKey: "test-secret-key",
301+
awsRegion: "eu-west-1",
302+
awsUseCrossRegionInference: true,
303+
})
304+
305+
const mockResponse = {
306+
output: new TextEncoder().encode(
307+
JSON.stringify({
308+
content: "Test response",
309+
}),
310+
),
311+
}
312+
313+
const mockSend = jest.fn().mockResolvedValue(mockResponse)
314+
handler["client"] = {
315+
send: mockSend,
316+
} as unknown as BedrockRuntimeClient
317+
318+
const result = await handler.completePrompt("Test prompt")
319+
expect(result).toBe("Test response")
320+
expect(mockSend).toHaveBeenCalledWith(
321+
expect.objectContaining({
322+
input: expect.objectContaining({
323+
modelId: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
324+
}),
325+
}),
326+
)
327+
})
328+
329+
it("should handle cross-region inference for ap-xx region", async () => {
330+
handler = new AwsBedrockHandler({
331+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
332+
awsAccessKey: "test-access-key",
333+
awsSecretKey: "test-secret-key",
334+
awsRegion: "ap-northeast-1",
335+
awsUseCrossRegionInference: true,
336+
})
337+
338+
const mockResponse = {
339+
output: new TextEncoder().encode(
340+
JSON.stringify({
341+
content: "Test response",
342+
}),
343+
),
344+
}
345+
346+
const mockSend = jest.fn().mockResolvedValue(mockResponse)
347+
handler["client"] = {
348+
send: mockSend,
349+
} as unknown as BedrockRuntimeClient
350+
351+
const result = await handler.completePrompt("Test prompt")
352+
expect(result).toBe("Test response")
353+
expect(mockSend).toHaveBeenCalledWith(
354+
expect.objectContaining({
355+
input: expect.objectContaining({
356+
modelId: "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
357+
}),
358+
}),
359+
)
360+
})
361+
362+
it("should handle cross-region inference for other regions", async () => {
363+
handler = new AwsBedrockHandler({
364+
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
365+
awsAccessKey: "test-access-key",
366+
awsSecretKey: "test-secret-key",
367+
awsRegion: "ca-central-1",
368+
awsUseCrossRegionInference: true,
369+
})
370+
371+
const mockResponse = {
372+
output: new TextEncoder().encode(
373+
JSON.stringify({
374+
content: "Test response",
375+
}),
376+
),
377+
}
378+
379+
const mockSend = jest.fn().mockResolvedValue(mockResponse)
380+
handler["client"] = {
381+
send: mockSend,
382+
} as unknown as BedrockRuntimeClient
383+
384+
const result = await handler.completePrompt("Test prompt")
385+
expect(result).toBe("Test response")
386+
expect(mockSend).toHaveBeenCalledWith(
387+
expect.objectContaining({
388+
input: expect.objectContaining({
389+
modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
390+
}),
391+
}),
392+
)
393+
})
295394
})
296395

297396
describe("getModel", () => {

src/api/providers/bedrock.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,9 @@ Please check:
610610
case "eu-":
611611
modelId = `eu.${modelConfig.id}`
612612
break
613+
case "ap-":
614+
modelId = `apac.${modelConfig.id}`
615+
break
613616
default:
614617
modelId = modelConfig.id
615618
break

0 commit comments

Comments
 (0)