Skip to content

Commit b033082

Browse files
Reflect Cross-region inference option in ap-xx region (#1842)
* Fix: Enable cross-region inference for 'ap-xx' region in AwsBedrockHandler.completePrompt * Fix: Enable cross-region inference for 'ap-xx' region in AwsBedrockHandler.createMessage * Create itchy-waves-move.md --------- Co-authored-by: Matt Rubens <[email protected]>
1 parent 842e69c commit b033082

File tree

3 files changed

+327
-1
lines changed

3 files changed

+327
-1
lines changed

.changeset/itchy-waves-move.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"roo-cline": patch
3+
---
4+
5+
Reflect Cross-region inference option in `ap-xx` region

src/api/providers/__tests__/bedrock.test.ts

Lines changed: 316 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,222 @@ describe("AwsBedrockHandler", () => {
165165
)
166166
})
167167

168+
it("should handle cross-region inference for us-xx region", async () => {
169+
const handlerWithProfile = new AwsBedrockHandler({
170+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
171+
awsAccessKey: "test-access-key",
172+
awsSecretKey: "test-secret-key",
173+
awsRegion: "us-east-1",
174+
awsUseCrossRegionInference: true,
175+
})
176+
177+
// Mock AWS SDK invoke
178+
const mockStream = {
179+
[Symbol.asyncIterator]: async function* () {
180+
yield {
181+
metadata: {
182+
usage: {
183+
inputTokens: 10,
184+
outputTokens: 5,
185+
},
186+
},
187+
}
188+
},
189+
}
190+
191+
const mockInvoke = jest.fn().mockResolvedValue({
192+
stream: mockStream,
193+
})
194+
195+
handlerWithProfile["client"] = {
196+
send: mockInvoke,
197+
} as unknown as BedrockRuntimeClient
198+
199+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
200+
const chunks = []
201+
202+
for await (const chunk of stream) {
203+
chunks.push(chunk)
204+
}
205+
206+
expect(chunks.length).toBeGreaterThan(0)
207+
expect(chunks[0]).toEqual({
208+
type: "usage",
209+
inputTokens: 10,
210+
outputTokens: 5,
211+
})
212+
213+
expect(mockInvoke).toHaveBeenCalledWith(
214+
expect.objectContaining({
215+
input: expect.objectContaining({
216+
modelId: "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
217+
}),
218+
}),
219+
)
220+
})
221+
222+
it("should handle cross-region inference for eu-xx region", async () => {
223+
const handlerWithProfile = new AwsBedrockHandler({
224+
apiModelId: "anthropic.claude-3-5-sonnet-20240620-v1:0",
225+
awsAccessKey: "test-access-key",
226+
awsSecretKey: "test-secret-key",
227+
awsRegion: "eu-west-1",
228+
awsUseCrossRegionInference: true,
229+
})
230+
231+
// Mock AWS SDK invoke
232+
const mockStream = {
233+
[Symbol.asyncIterator]: async function* () {
234+
yield {
235+
metadata: {
236+
usage: {
237+
inputTokens: 10,
238+
outputTokens: 5,
239+
},
240+
},
241+
}
242+
},
243+
}
244+
245+
const mockInvoke = jest.fn().mockResolvedValue({
246+
stream: mockStream,
247+
})
248+
249+
handlerWithProfile["client"] = {
250+
send: mockInvoke,
251+
} as unknown as BedrockRuntimeClient
252+
253+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
254+
const chunks = []
255+
256+
for await (const chunk of stream) {
257+
chunks.push(chunk)
258+
}
259+
260+
expect(chunks.length).toBeGreaterThan(0)
261+
expect(chunks[0]).toEqual({
262+
type: "usage",
263+
inputTokens: 10,
264+
outputTokens: 5,
265+
})
266+
267+
expect(mockInvoke).toHaveBeenCalledWith(
268+
expect.objectContaining({
269+
input: expect.objectContaining({
270+
modelId: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
271+
}),
272+
}),
273+
)
274+
})
275+
276+
it("should handle cross-region inference for ap-xx region", async () => {
277+
const handlerWithProfile = new AwsBedrockHandler({
278+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
279+
awsAccessKey: "test-access-key",
280+
awsSecretKey: "test-secret-key",
281+
awsRegion: "ap-northeast-1",
282+
awsUseCrossRegionInference: true,
283+
})
284+
285+
// Mock AWS SDK invoke
286+
const mockStream = {
287+
[Symbol.asyncIterator]: async function* () {
288+
yield {
289+
metadata: {
290+
usage: {
291+
inputTokens: 10,
292+
outputTokens: 5,
293+
},
294+
},
295+
}
296+
},
297+
}
298+
299+
const mockInvoke = jest.fn().mockResolvedValue({
300+
stream: mockStream,
301+
})
302+
303+
handlerWithProfile["client"] = {
304+
send: mockInvoke,
305+
} as unknown as BedrockRuntimeClient
306+
307+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
308+
const chunks = []
309+
310+
for await (const chunk of stream) {
311+
chunks.push(chunk)
312+
}
313+
314+
expect(chunks.length).toBeGreaterThan(0)
315+
expect(chunks[0]).toEqual({
316+
type: "usage",
317+
inputTokens: 10,
318+
outputTokens: 5,
319+
})
320+
321+
expect(mockInvoke).toHaveBeenCalledWith(
322+
expect.objectContaining({
323+
input: expect.objectContaining({
324+
modelId: "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
325+
}),
326+
}),
327+
)
328+
})
329+
330+
it("should handle cross-region inference for other region", async () => {
331+
const handlerWithProfile = new AwsBedrockHandler({
332+
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
333+
awsAccessKey: "test-access-key",
334+
awsSecretKey: "test-secret-key",
335+
awsRegion: "ca-central-1",
336+
awsUseCrossRegionInference: true,
337+
})
338+
339+
// Mock AWS SDK invoke
340+
const mockStream = {
341+
[Symbol.asyncIterator]: async function* () {
342+
yield {
343+
metadata: {
344+
usage: {
345+
inputTokens: 10,
346+
outputTokens: 5,
347+
},
348+
},
349+
}
350+
},
351+
}
352+
353+
const mockInvoke = jest.fn().mockResolvedValue({
354+
stream: mockStream,
355+
})
356+
357+
handlerWithProfile["client"] = {
358+
send: mockInvoke,
359+
} as unknown as BedrockRuntimeClient
360+
361+
const stream = handlerWithProfile.createMessage(systemPrompt, mockMessages)
362+
const chunks = []
363+
364+
for await (const chunk of stream) {
365+
chunks.push(chunk)
366+
}
367+
368+
expect(chunks.length).toBeGreaterThan(0)
369+
expect(chunks[0]).toEqual({
370+
type: "usage",
371+
inputTokens: 10,
372+
outputTokens: 5,
373+
})
374+
375+
expect(mockInvoke).toHaveBeenCalledWith(
376+
expect.objectContaining({
377+
input: expect.objectContaining({
378+
modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
379+
}),
380+
}),
381+
)
382+
})
383+
168384
it("should handle API errors", async () => {
169385
// Mock AWS SDK invoke with error
170386
const mockInvoke = jest.fn().mockRejectedValue(new Error("AWS Bedrock error"))
@@ -260,7 +476,7 @@ describe("AwsBedrockHandler", () => {
260476
expect(result).toBe("")
261477
})
262478

263-
it("should handle cross-region inference", async () => {
479+
it("should handle cross-region inference for us-xx region", async () => {
264480
handler = new AwsBedrockHandler({
265481
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
266482
awsAccessKey: "test-access-key",
@@ -292,6 +508,105 @@ describe("AwsBedrockHandler", () => {
292508
}),
293509
)
294510
})
511+
512+
it("should handle cross-region inference for eu-xx region", async () => {
513+
handler = new AwsBedrockHandler({
514+
apiModelId: "anthropic.claude-3-5-sonnet-20240620-v1:0",
515+
awsAccessKey: "test-access-key",
516+
awsSecretKey: "test-secret-key",
517+
awsRegion: "eu-west-1",
518+
awsUseCrossRegionInference: true,
519+
})
520+
521+
const mockResponse = {
522+
output: new TextEncoder().encode(
523+
JSON.stringify({
524+
content: "Test response",
525+
}),
526+
),
527+
}
528+
529+
const mockSend = jest.fn().mockResolvedValue(mockResponse)
530+
handler["client"] = {
531+
send: mockSend,
532+
} as unknown as BedrockRuntimeClient
533+
534+
const result = await handler.completePrompt("Test prompt")
535+
expect(result).toBe("Test response")
536+
expect(mockSend).toHaveBeenCalledWith(
537+
expect.objectContaining({
538+
input: expect.objectContaining({
539+
modelId: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
540+
}),
541+
}),
542+
)
543+
})
544+
545+
it("should handle cross-region inference for ap-xx region", async () => {
546+
handler = new AwsBedrockHandler({
547+
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
548+
awsAccessKey: "test-access-key",
549+
awsSecretKey: "test-secret-key",
550+
awsRegion: "ap-northeast-1",
551+
awsUseCrossRegionInference: true,
552+
})
553+
554+
const mockResponse = {
555+
output: new TextEncoder().encode(
556+
JSON.stringify({
557+
content: "Test response",
558+
}),
559+
),
560+
}
561+
562+
const mockSend = jest.fn().mockResolvedValue(mockResponse)
563+
handler["client"] = {
564+
send: mockSend,
565+
} as unknown as BedrockRuntimeClient
566+
567+
const result = await handler.completePrompt("Test prompt")
568+
expect(result).toBe("Test response")
569+
expect(mockSend).toHaveBeenCalledWith(
570+
expect.objectContaining({
571+
input: expect.objectContaining({
572+
modelId: "apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
573+
}),
574+
}),
575+
)
576+
})
577+
578+
it("should handle cross-region inference for other regions", async () => {
579+
handler = new AwsBedrockHandler({
580+
apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0",
581+
awsAccessKey: "test-access-key",
582+
awsSecretKey: "test-secret-key",
583+
awsRegion: "ca-central-1",
584+
awsUseCrossRegionInference: true,
585+
})
586+
587+
const mockResponse = {
588+
output: new TextEncoder().encode(
589+
JSON.stringify({
590+
content: "Test response",
591+
}),
592+
),
593+
}
594+
595+
const mockSend = jest.fn().mockResolvedValue(mockResponse)
596+
handler["client"] = {
597+
send: mockSend,
598+
} as unknown as BedrockRuntimeClient
599+
600+
const result = await handler.completePrompt("Test prompt")
601+
expect(result).toBe("Test response")
602+
expect(mockSend).toHaveBeenCalledWith(
603+
expect.objectContaining({
604+
input: expect.objectContaining({
605+
modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
606+
}),
607+
}),
608+
)
609+
})
295610
})
296611

297612
describe("getModel", () => {

src/api/providers/bedrock.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
211211
case "eu-":
212212
modelId = `eu.${modelConfig.id}`
213213
break
214+
case "ap-":
215+
modelId = `apac.${modelConfig.id}`
216+
break
214217
default:
215218
modelId = modelConfig.id
216219
break
@@ -610,6 +613,9 @@ Please check:
610613
case "eu-":
611614
modelId = `eu.${modelConfig.id}`
612615
break
616+
case "ap-":
617+
modelId = `apac.${modelConfig.id}`
618+
break
613619
default:
614620
modelId = modelConfig.id
615621
break

0 commit comments

Comments
 (0)