Skip to content

Commit d032db2

Browse files
Fix: Enable cross-region inference for 'ap-xx' region in AwsBedrockHandler.createMessage
1 parent 0dabea4 commit d032db2

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,222 @@ describe("AwsBedrockHandler", () => {
165165
)
166166
})
167167

168+
it("should handle cross-region inference for us-xx region", async () => {
169+
const handlerWithProfile = new AwsBedrockHandler({
170+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
171+
awsAccessKey: "test-access-key",
172+
awsSecretKey: "test-secret-key",
173+
awsRegion: "us-east-1",
174+
awsUseCrossRegionInference: true,
175+
})
176+
177+
// Mock AWS SDK invoke
178+
const mockStream = {
179+
[Symbol.asyncIterator]: async function* () {
180+
yield {
181+
metadata: {
182+
usage: {
183+
inputTokens: 10,
184+
outputTokens: 5,
185+
},
186+
},
187+
}
188+
},
189+
}
190+
191+
const mockInvoke = jest.fn().mockResolvedValue({
192+
stream: mockStream,
193+
})
194+
195+
handlerWithProfile["client"] = {
196+
send: mockInvoke,
197+
} as unknown as BedrockRuntimeClient
198+
199+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
200+
const chunks = []
201+
202+
for await (const chunk of stream) {
203+
chunks.push(chunk)
204+
}
205+
206+
expect(chunks.length).toBeGreaterThan(0)
207+
expect(chunks[0]).toEqual({
208+
type: "usage",
209+
inputTokens: 10,
210+
outputTokens: 5,
211+
})
212+
213+
expect(mockInvoke).toHaveBeenCalledWith(
214+
expect.objectContaining({
215+
input: expect.objectContaining({
216+
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
217+
}),
218+
}),
219+
)
220+
})
221+
222+
it("should handle cross-region inference for eu-xx region", async () => {
223+
const handlerWithProfile = new AwsBedrockHandler({
224+
apiModelId: "anthropic.claude-3-5-sonnet-20240620-v1:0",
225+
awsAccessKey: "test-access-key",
226+
awsSecretKey: "test-secret-key",
227+
awsRegion: "eu-west-1",
228+
awsUseCrossRegionInference: true,
229+
})
230+
231+
// Mock AWS SDK invoke
232+
const mockStream = {
233+
[Symbol.asyncIterator]: async function* () {
234+
yield {
235+
metadata: {
236+
usage: {
237+
inputTokens: 10,
238+
outputTokens: 5,
239+
},
240+
},
241+
}
242+
},
243+
}
244+
245+
const mockInvoke = jest.fn().mockResolvedValue({
246+
stream: mockStream,
247+
})
248+
249+
handlerWithProfile["client"] = {
250+
send: mockInvoke,
251+
} as unknown as BedrockRuntimeClient
252+
253+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
254+
const chunks = []
255+
256+
for await (const chunk of stream) {
257+
chunks.push(chunk)
258+
}
259+
260+
expect(chunks.length).toBeGreaterThan(0)
261+
expect(chunks[0]).toEqual({
262+
type: "usage",
263+
inputTokens: 10,
264+
outputTokens: 5,
265+
})
266+
267+
expect(mockInvoke).toHaveBeenCalledWith(
268+
expect.objectContaining({
269+
input: expect.objectContaining({
270+
modelId: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
271+
}),
272+
}),
273+
)
274+
})
275+
276+
it("should handle cross-region inference for ap-xx region", async () => {
277+
const handlerWithProfile = new AwsBedrockHandler({
278+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
279+
awsAccessKey: "test-access-key",
280+
awsSecretKey: "test-secret-key",
281+
awsRegion: "ap-northeast-1",
282+
awsUseCrossRegionInference: true,
283+
})
284+
285+
// Mock AWS SDK invoke
286+
const mockStream = {
287+
[Symbol.asyncIterator]: async function* () {
288+
yield {
289+
metadata: {
290+
usage: {
291+
inputTokens: 10,
292+
outputTokens: 5,
293+
},
294+
},
295+
}
296+
},
297+
}
298+
299+
const mockInvoke = jest.fn().mockResolvedValue({
300+
stream: mockStream,
301+
})
302+
303+
handlerWithProfile["client"] = {
304+
send: mockInvoke,
305+
} as unknown as BedrockRuntimeClient
306+
307+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
308+
const chunks = []
309+
310+
for await (const chunk of stream) {
311+
chunks.push(chunk)
312+
}
313+
314+
expect(chunks.length).toBeGreaterThan(0)
315+
expect(chunks[0]).toEqual({
316+
type: "usage",
317+
inputTokens: 10,
318+
outputTokens: 5,
319+
})
320+
321+
expect(mockInvoke).toHaveBeenCalledWith(
322+
expect.objectContaining({
323+
input: expect.objectContaining({
324+
modelId: "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
325+
}),
326+
}),
327+
)
328+
})
329+
330+
it("should handle cross-region inference for other region", async () => {
331+
const handlerWithProfile = new AwsBedrockHandler({
332+
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
333+
awsAccessKey: "test-access-key",
334+
awsSecretKey: "test-secret-key",
335+
awsRegion: "ca-central-1",
336+
awsUseCrossRegionInference: true,
337+
})
338+
339+
// Mock AWS SDK invoke
340+
const mockStream = {
341+
[Symbol.asyncIterator]: async function* () {
342+
yield {
343+
metadata: {
344+
usage: {
345+
inputTokens: 10,
346+
outputTokens: 5,
347+
},
348+
},
349+
}
350+
},
351+
}
352+
353+
const mockInvoke = jest.fn().mockResolvedValue({
354+
stream: mockStream,
355+
})
356+
357+
handlerWithProfile["client"] = {
358+
send: mockInvoke,
359+
} as unknown as BedrockRuntimeClient
360+
361+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
362+
const chunks = []
363+
364+
for await (const chunk of stream) {
365+
chunks.push(chunk)
366+
}
367+
368+
expect(chunks.length).toBeGreaterThan(0)
369+
expect(chunks[0]).toEqual({
370+
type: "usage",
371+
inputTokens: 10,
372+
outputTokens: 5,
373+
})
374+
375+
expect(mockInvoke).toHaveBeenCalledWith(
376+
expect.objectContaining({
377+
input: expect.objectContaining({
378+
modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
379+
}),
380+
}),
381+
)
382+
})
383+
168384
it("should handle API errors", async () => {
169385
// Mock AWS SDK invoke with error
170386
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))

src/api/providers/bedrock.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
211211
case "eu-":
212212
modelId = `eu.${modelConfig.id}`
213213
break
214+
case "ap-":
215+
modelId = `apac.${modelConfig.id}`
216+
break
214217
default:
215218
modelId = modelConfig.id
216219
break

0 commit comments

Comments
 (0)