diff --git a/packages/types/src/providers/bedrock.ts b/packages/types/src/providers/bedrock.ts index a15f041252..58e860dd94 100644 --- a/packages/types/src/providers/bedrock.ts +++ b/packages/types/src/providers/bedrock.ts @@ -360,78 +360,49 @@ export const BEDROCK_MAX_TOKENS = 4096 export const BEDROCK_DEFAULT_CONTEXT = 128_000 -export const BEDROCK_REGION_INFO: Record< - string, - { - regionId: string - description: string - pattern?: string - multiRegion?: boolean - } -> = { - /* - * This JSON generated by AWS's AI assistant - Amazon Q on March 29, 2025 - * - * - Africa (Cape Town) region does not appear to support Amazon Bedrock at this time. - * - Some Asia Pacific regions, such as Asia Pacific (Hong Kong) and Asia Pacific (Jakarta), are not listed among the supported regions for Bedrock services. - * - Middle East regions, including Middle East (Bahrain) and Middle East (UAE), are not mentioned in the list of supported regions for Bedrock. [3] - * - China regions (Beijing and Ningxia) are not listed as supported for Amazon Bedrock. - * - Some newer or specialized AWS regions may not have Bedrock support yet. - */ - "us.": { regionId: "us-east-1", description: "US East (N. Virginia)", pattern: "us-", multiRegion: true }, - "use.": { regionId: "us-east-1", description: "US East (N. Virginia)" }, - "use1.": { regionId: "us-east-1", description: "US East (N. Virginia)" }, - "use2.": { regionId: "us-east-2", description: "US East (Ohio)" }, - "usw.": { regionId: "us-west-2", description: "US West (Oregon)" }, - "usw2.": { regionId: "us-west-2", description: "US West (Oregon)" }, - "ug.": { - regionId: "us-gov-west-1", - description: "AWS GovCloud (US-West)", - pattern: "us-gov-", - multiRegion: true, - }, - "uge1.": { regionId: "us-gov-east-1", description: "AWS GovCloud (US-East)" }, - "ugw1.": { regionId: "us-gov-west-1", description: "AWS GovCloud (US-West)" }, - "eu.": { regionId: "eu-west-1", description: "Europe (Ireland)", pattern: "eu-", multiRegion: true }, - "euw1.": { regionId: "eu-west-1", description: "Europe (Ireland)" }, - "euw2.": { regionId: "eu-west-2", description: "Europe (London)" }, - "euw3.": { regionId: "eu-west-3", description: "Europe (Paris)" }, - "euc1.": { regionId: "eu-central-1", description: "Europe (Frankfurt)" }, - "euc2.": { regionId: "eu-central-2", description: "Europe (Zurich)" }, - "eun1.": { regionId: "eu-north-1", description: "Europe (Stockholm)" }, - "eus1.": { regionId: "eu-south-1", description: "Europe (Milan)" }, - "eus2.": { regionId: "eu-south-2", description: "Europe (Spain)" }, - "ap.": { - regionId: "ap-southeast-1", - description: "Asia Pacific (Singapore)", - pattern: "ap-", - multiRegion: true, - }, - "ape1.": { regionId: "ap-east-1", description: "Asia Pacific (Hong Kong)" }, - "apne1.": { regionId: "ap-northeast-1", description: "Asia Pacific (Tokyo)" }, - "apne2.": { regionId: "ap-northeast-2", description: "Asia Pacific (Seoul)" }, - "apne3.": { regionId: "ap-northeast-3", description: "Asia Pacific (Osaka)" }, - "aps1.": { regionId: "ap-south-1", description: "Asia Pacific (Mumbai)" }, - "aps2.": { regionId: "ap-south-2", description: "Asia Pacific (Hyderabad)" }, - "apse1.": { regionId: "ap-southeast-1", description: "Asia Pacific (Singapore)" }, - "apse2.": { regionId: "ap-southeast-2", description: "Asia Pacific (Sydney)" }, - "ca.": { regionId: "ca-central-1", description: "Canada (Central)", pattern: "ca-", multiRegion: true }, - "cac1.": { regionId: "ca-central-1", description: "Canada (Central)" }, - "sa.": { regionId: "sa-east-1", description: "South America (São Paulo)", pattern: "sa-", multiRegion: true }, - "sae1.": { regionId: "sa-east-1", description: "South America (São Paulo)" }, +// AWS Bedrock Inference Profile mapping based on official documentation +// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +// This mapping is pre-ordered by pattern length (descending) to ensure more specific patterns match first +export const AWS_INFERENCE_PROFILE_MAPPING: Array<[string, string]> = [ + // US Government Cloud → ug. inference profile (most specific prefix first) + ["us-gov-", "ug."], + // Americas regions → us. inference profile + ["us-", "us."], + // Europe regions → eu. inference profile + ["eu-", "eu."], + // Asia Pacific regions → apac. inference profile + ["ap-", "apac."], + // Canada regions → ca. inference profile + ["ca-", "ca."], + // South America regions → sa. inference profile + ["sa-", "sa."], +] - // These are not official - they weren't generated by Amazon Q nor were - // found in the AWS documentation but another Roo contributor found apac. - // Was needed so I've added the pattern of the other geo zones. - "apac.": { regionId: "ap-southeast-1", description: "Default APAC region", pattern: "ap-", multiRegion: true }, - "emea.": { regionId: "eu-west-1", description: "Default EMEA region", pattern: "eu-", multiRegion: true }, - "amer.": { regionId: "us-east-1", description: "Default Americas region", pattern: "us-", multiRegion: true }, -} - -export const BEDROCK_REGIONS = Object.values(BEDROCK_REGION_INFO) - // Extract all region IDs - .map((info) => ({ value: info.regionId, label: info.regionId })) - // Filter to unique region IDs (remove duplicates) - .filter((region, index, self) => index === self.findIndex((r) => r.value === region.value)) - // Sort alphabetically by region ID - .sort((a, b) => a.value.localeCompare(b.value)) +// AWS Bedrock supported regions for the regions dropdown +// Based on official AWS documentation +export const BEDROCK_REGIONS = [ + { value: "us-east-1", label: "us-east-1" }, + { value: "us-east-2", label: "us-east-2" }, + { value: "us-west-1", label: "us-west-1" }, + { value: "us-west-2", label: "us-west-2" }, + { value: "ap-northeast-1", label: "ap-northeast-1" }, + { value: "ap-northeast-2", label: "ap-northeast-2" }, + { value: "ap-northeast-3", label: "ap-northeast-3" }, + { value: "ap-south-1", label: "ap-south-1" }, + { value: "ap-south-2", label: "ap-south-2" }, + { value: "ap-southeast-1", label: "ap-southeast-1" }, + { value: "ap-southeast-2", label: "ap-southeast-2" }, + { value: "ap-east-1", label: "ap-east-1" }, + { value: "eu-central-1", label: "eu-central-1" }, + { value: "eu-central-2", label: "eu-central-2" }, + { value: "eu-west-1", label: "eu-west-1" }, + { value: "eu-west-2", label: "eu-west-2" }, + { value: "eu-west-3", label: "eu-west-3" }, + { value: "eu-north-1", label: "eu-north-1" }, + { value: "eu-south-1", label: "eu-south-1" }, + { value: "eu-south-2", label: "eu-south-2" }, + { value: "ca-central-1", label: "ca-central-1" }, + { value: "sa-east-1", label: "sa-east-1" }, + { value: "us-gov-east-1", label: "us-gov-east-1" }, + { value: "us-gov-west-1", label: "us-gov-west-1" }, +].sort((a, b) => a.value.localeCompare(b.value)) diff --git a/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts b/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts new file mode 100644 index 0000000000..7eef16d241 --- /dev/null +++ b/src/api/providers/__tests__/bedrock-inference-profiles.spec.ts @@ -0,0 +1,249 @@ +// npx vitest run src/api/providers/__tests__/bedrock-inference-profiles.spec.ts + +import { AWS_INFERENCE_PROFILE_MAPPING } from "@roo-code/types" +import { AwsBedrockHandler } from "../bedrock" +import { ApiHandlerOptions } from "../../../shared/api" + +// Mock AWS SDK +vitest.mock("@aws-sdk/client-bedrock-runtime", () => { + return { + BedrockRuntimeClient: vitest.fn().mockImplementation(() => ({ + send: vitest.fn(), + config: { region: "us-east-1" }, + })), + ConverseCommand: vitest.fn(), + ConverseStreamCommand: vitest.fn(), + } +}) + +describe("AWS Bedrock Inference Profiles", () => { + // Helper function to create a handler with specific options + const createHandler = (options: Partial = {}) => { + const defaultOptions: ApiHandlerOptions = { + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + awsRegion: "us-east-1", + ...options, + } + return new AwsBedrockHandler(defaultOptions) + } + + describe("AWS_INFERENCE_PROFILE_MAPPING constant", () => { + it("should contain all expected region mappings", () => { + expect(AWS_INFERENCE_PROFILE_MAPPING).toEqual([ + ["us-gov-", "ug."], + ["us-", "us."], + ["eu-", "eu."], + ["ap-", "apac."], + ["ca-", "ca."], + ["sa-", "sa."], + ]) + }) + + it("should be ordered by pattern length (descending)", () => { + const lengths = AWS_INFERENCE_PROFILE_MAPPING.map(([pattern]) => pattern.length) + const sortedLengths = [...lengths].sort((a, b) => b - a) + expect(lengths).toEqual(sortedLengths) + }) + + it("should have valid inference profile prefixes", () => { + AWS_INFERENCE_PROFILE_MAPPING.forEach(([regionPattern, inferenceProfile]) => { + expect(regionPattern).toMatch(/^[a-z-]+$/) + expect(inferenceProfile).toMatch(/^[a-z]+\.$/) + }) + }) + }) + + describe("getPrefixForRegion function", () => { + it("should return correct prefix for US government regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("us-gov-east-1")).toBe("ug.") + expect((handler as any).constructor.getPrefixForRegion("us-gov-west-1")).toBe("ug.") + }) + + it("should return correct prefix for US commercial regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("us-east-1")).toBe("us.") + expect((handler as any).constructor.getPrefixForRegion("us-west-1")).toBe("us.") + expect((handler as any).constructor.getPrefixForRegion("us-west-2")).toBe("us.") + }) + + it("should return correct prefix for European regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("eu-west-1")).toBe("eu.") + expect((handler as any).constructor.getPrefixForRegion("eu-central-1")).toBe("eu.") + expect((handler as any).constructor.getPrefixForRegion("eu-north-1")).toBe("eu.") + expect((handler as any).constructor.getPrefixForRegion("eu-south-1")).toBe("eu.") + }) + + it("should return correct prefix for Asia Pacific regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("ap-southeast-1")).toBe("apac.") + expect((handler as any).constructor.getPrefixForRegion("ap-northeast-1")).toBe("apac.") + expect((handler as any).constructor.getPrefixForRegion("ap-south-1")).toBe("apac.") + expect((handler as any).constructor.getPrefixForRegion("ap-east-1")).toBe("apac.") + }) + + it("should return correct prefix for Canada regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("ca-central-1")).toBe("ca.") + expect((handler as any).constructor.getPrefixForRegion("ca-west-1")).toBe("ca.") + }) + + it("should return correct prefix for South America regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("sa-east-1")).toBe("sa.") + }) + + it("should return undefined for unsupported regions", () => { + const handler = createHandler() + expect((handler as any).constructor.getPrefixForRegion("af-south-1")).toBeUndefined() + expect((handler as any).constructor.getPrefixForRegion("me-south-1")).toBeUndefined() + expect((handler as any).constructor.getPrefixForRegion("cn-north-1")).toBeUndefined() + expect((handler as any).constructor.getPrefixForRegion("invalid-region")).toBeUndefined() + }) + + it("should prioritize longer patterns over shorter ones", () => { + const handler = createHandler() + // us-gov- should be matched before us- + expect((handler as any).constructor.getPrefixForRegion("us-gov-east-1")).toBe("ug.") + expect((handler as any).constructor.getPrefixForRegion("us-gov-west-1")).toBe("ug.") + + // Regular us- regions should still work + expect((handler as any).constructor.getPrefixForRegion("us-east-1")).toBe("us.") + expect((handler as any).constructor.getPrefixForRegion("us-west-2")).toBe("us.") + }) + }) + + describe("Cross-region inference integration", () => { + it("should apply ug. prefix for US government regions", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "us-gov-east-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("ug.anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should apply us. prefix for US commercial regions", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "us-east-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("us.anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should apply eu. prefix for European regions", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "eu-west-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("eu.anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should apply apac. prefix for Asia Pacific regions", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "ap-southeast-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("apac.anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should apply ca. prefix for Canada regions", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "ca-central-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("ca.anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should apply sa. prefix for South America regions", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "sa-east-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("sa.anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should not apply prefix when cross-region inference is disabled", () => { + const handler = createHandler({ + awsUseCrossRegionInference: false, + awsRegion: "us-gov-east-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + expect(model.id).toBe("anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should handle unsupported regions gracefully", () => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "af-south-1", // Unsupported region + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const model = handler.getModel() + // Should remain unchanged when no prefix is found + expect(model.id).toBe("anthropic.claude-3-sonnet-20240229-v1:0") + }) + + it("should work with different model IDs", () => { + const testModels = [ + "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-opus-20240229-v1:0", + "amazon.nova-pro-v1:0", + "meta.llama3-1-70b-instruct-v1:0", + ] + + testModels.forEach((modelId) => { + const handler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "eu-west-1", + apiModelId: modelId, + }) + + const model = handler.getModel() + expect(model.id).toBe(`eu.${modelId}`) + }) + }) + + it("should prioritize us-gov- over us- in cross-region inference", () => { + // Test that us-gov-east-1 gets ug. prefix, not us. + const govHandler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "us-gov-east-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const govModel = govHandler.getModel() + expect(govModel.id).toBe("ug.anthropic.claude-3-sonnet-20240229-v1:0") + + // Test that regular us-east-1 still gets us. prefix + const usHandler = createHandler({ + awsUseCrossRegionInference: true, + awsRegion: "us-east-1", + apiModelId: "anthropic.claude-3-sonnet-20240229-v1:0", + }) + + const usModel = usHandler.getModel() + expect(usModel.id).toBe("us.anthropic.claude-3-sonnet-20240229-v1:0") + }) + }) +}) diff --git a/src/api/providers/__tests__/bedrock.spec.ts b/src/api/providers/__tests__/bedrock.spec.ts index 80f6338629..ad0ae2bdb5 100644 --- a/src/api/providers/__tests__/bedrock.spec.ts +++ b/src/api/providers/__tests__/bedrock.spec.ts @@ -74,42 +74,6 @@ describe("AwsBedrockHandler", () => { expect(modelInfo.info).toBeDefined() }) - it("should handle inference-profile ARN with apne3 region prefix", () => { - const originalParseArn = AwsBedrockHandler.prototype["parseArn"] - const parseArnMock = vi.fn().mockImplementation(function (this: any, arn: string, region?: string) { - return originalParseArn.call(this, arn, region) - }) - AwsBedrockHandler.prototype["parseArn"] = parseArnMock - - try { - const customArnHandler = new AwsBedrockHandler({ - apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", - awsAccessKey: "test-access-key", - awsSecretKey: "test-secret-key", - awsRegion: "ap-northeast-3", - awsCustomArn: - "arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0", - }) - - const modelInfo = customArnHandler.getModel() - - expect(modelInfo.id).toBe( - "arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0", - ) - expect(modelInfo.info).toBeDefined() - - expect(parseArnMock).toHaveBeenCalledWith( - "arn:aws:bedrock:ap-northeast-3:123456789012:inference-profile/apne3.anthropic.claude-3-5-sonnet-20241022-v2:0", - "ap-northeast-3", - ) - - expect((customArnHandler as any).arnInfo.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") - expect((customArnHandler as any).arnInfo.crossRegionInference).toBe(false) - } finally { - AwsBedrockHandler.prototype["parseArn"] = originalParseArn - } - }) - it("should use default prompt router model when prompt router arn is entered but no model can be identified from the ARN", () => { const customArnHandler = new AwsBedrockHandler({ awsCustomArn: @@ -127,6 +91,270 @@ describe("AwsBedrockHandler", () => { }) }) + describe("region mapping and cross-region inference", () => { + describe("getPrefixForRegion", () => { + it("should return correct prefix for US regions", () => { + // Access private static method using type casting + const getPrefixForRegion = (AwsBedrockHandler as any).getPrefixForRegion + + expect(getPrefixForRegion("us-east-1")).toBe("us.") + expect(getPrefixForRegion("us-west-2")).toBe("us.") + expect(getPrefixForRegion("us-gov-west-1")).toBe("ug.") + }) + + it("should return correct prefix for EU regions", () => { + const getPrefixForRegion = (AwsBedrockHandler as any).getPrefixForRegion + + expect(getPrefixForRegion("eu-west-1")).toBe("eu.") + expect(getPrefixForRegion("eu-central-1")).toBe("eu.") + expect(getPrefixForRegion("eu-north-1")).toBe("eu.") + }) + + it("should return correct prefix for APAC regions", () => { + const getPrefixForRegion = (AwsBedrockHandler as any).getPrefixForRegion + + expect(getPrefixForRegion("ap-southeast-1")).toBe("apac.") + expect(getPrefixForRegion("ap-northeast-1")).toBe("apac.") + expect(getPrefixForRegion("ap-south-1")).toBe("apac.") + }) + + it("should return undefined for unsupported regions", () => { + const getPrefixForRegion = (AwsBedrockHandler as any).getPrefixForRegion + + expect(getPrefixForRegion("unknown-region")).toBeUndefined() + expect(getPrefixForRegion("")).toBeUndefined() + expect(getPrefixForRegion("invalid")).toBeUndefined() + }) + }) + + describe("isSystemInferenceProfile", () => { + it("should return true for AWS inference profile prefixes", () => { + const isSystemInferenceProfile = (AwsBedrockHandler as any).isSystemInferenceProfile + + expect(isSystemInferenceProfile("us.")).toBe(true) + expect(isSystemInferenceProfile("eu.")).toBe(true) + expect(isSystemInferenceProfile("apac.")).toBe(true) + }) + + it("should return false for other prefixes", () => { + const isSystemInferenceProfile = (AwsBedrockHandler as any).isSystemInferenceProfile + + expect(isSystemInferenceProfile("ap.")).toBe(false) + expect(isSystemInferenceProfile("apne1.")).toBe(false) + expect(isSystemInferenceProfile("use1.")).toBe(false) + expect(isSystemInferenceProfile("custom.")).toBe(false) + expect(isSystemInferenceProfile("")).toBe(false) + }) + }) + + describe("parseBaseModelId", () => { + it("should remove defined inference profile prefixes", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + // Access private method using type casting + const parseBaseModelId = (handler as any).parseBaseModelId.bind(handler) + + expect(parseBaseModelId("us.anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + expect(parseBaseModelId("eu.anthropic.claude-3-haiku-20240307-v1:0")).toBe( + "anthropic.claude-3-haiku-20240307-v1:0", + ) + expect(parseBaseModelId("apac.anthropic.claude-3-opus-20240229-v1:0")).toBe( + "anthropic.claude-3-opus-20240229-v1:0", + ) + }) + + it("should not modify model IDs without defined prefixes", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseBaseModelId = (handler as any).parseBaseModelId.bind(handler) + + expect(parseBaseModelId("anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + expect(parseBaseModelId("amazon.titan-text-express-v1")).toBe("amazon.titan-text-express-v1") + }) + + it("should not modify model IDs with other prefixes", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseBaseModelId = (handler as any).parseBaseModelId.bind(handler) + + // Other prefixes should be preserved as part of the model ID + expect(parseBaseModelId("ap.anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe( + "ap.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + expect(parseBaseModelId("apne1.anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe( + "apne1.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + expect(parseBaseModelId("use1.anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe( + "use1.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + }) + }) + + describe("cross-region inference integration", () => { + it("should apply correct prefix when cross-region inference is enabled", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + awsUseCrossRegionInference: true, + }) + + const model = handler.getModel() + expect(model.id).toBe("us.anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should apply correct prefix for different regions", () => { + const euHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "eu-west-1", + awsUseCrossRegionInference: true, + }) + + const apacHandler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "ap-southeast-1", + awsUseCrossRegionInference: true, + }) + + expect(euHandler.getModel().id).toBe("eu.anthropic.claude-3-5-sonnet-20241022-v2:0") + expect(apacHandler.getModel().id).toBe("apac.anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should not apply prefix when cross-region inference is disabled", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + awsUseCrossRegionInference: false, + }) + + const model = handler.getModel() + expect(model.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should not apply prefix for unsupported regions", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "unknown-region", + awsUseCrossRegionInference: true, + }) + + const model = handler.getModel() + expect(model.id).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + }) + + describe("ARN parsing with inference profiles", () => { + it("should detect cross-region inference from ARN model ID", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + const result = parseArn( + "arn:aws:bedrock:us-east-1:123456789012:foundation-model/us.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + expect(result.isValid).toBe(true) + expect(result.crossRegionInference).toBe(true) + expect(result.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should not detect cross-region inference for non-prefixed models", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + const result = parseArn( + "arn:aws:bedrock:us-east-1:123456789012:foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + expect(result.isValid).toBe(true) + expect(result.crossRegionInference).toBe(false) + expect(result.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should detect cross-region inference for defined prefixes", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + const euResult = parseArn( + "arn:aws:bedrock:eu-west-1:123456789012:foundation-model/eu.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + const apacResult = parseArn( + "arn:aws:bedrock:ap-southeast-1:123456789012:foundation-model/apac.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + expect(euResult.crossRegionInference).toBe(true) + expect(euResult.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + + expect(apacResult.crossRegionInference).toBe(true) + expect(apacResult.modelId).toBe("anthropic.claude-3-5-sonnet-20241022-v2:0") + }) + + it("should not detect cross-region inference for other prefixes", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + + // Other prefixes should not trigger cross-region inference detection + const result = parseArn( + "arn:aws:bedrock:us-east-1:123456789012:foundation-model/ap.anthropic.claude-3-5-sonnet-20241022-v2:0", + ) + + expect(result.crossRegionInference).toBe(false) + expect(result.modelId).toBe("ap.anthropic.claude-3-5-sonnet-20241022-v2:0") // Should be preserved as-is + }) + }) + }) + describe("image handling", () => { const mockImageData = Buffer.from("test-image-data").toString("base64") @@ -242,4 +470,98 @@ describe("AwsBedrockHandler", () => { expect(secondImage.image).toHaveProperty("format", "png") }) }) + + describe("error handling and validation", () => { + it("should handle invalid regions gracefully", () => { + expect(() => { + new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "", // Empty region + }) + }).not.toThrow() + }) + + it("should validate ARN format and provide helpful error messages", () => { + expect(() => { + new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + awsCustomArn: "invalid-arn-format", + }) + }).toThrow(/INVALID_ARN_FORMAT/) + }) + + it("should handle malformed ARNs with missing components", () => { + expect(() => { + new AwsBedrockHandler({ + apiModelId: "test", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + awsCustomArn: "arn:aws:bedrock:us-east-1", + }) + }).toThrow(/INVALID_ARN_FORMAT/) + }) + }) + + describe("model information and configuration", () => { + it("should preserve model information after applying cross-region prefixes", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + awsUseCrossRegionInference: true, + }) + + const model = handler.getModel() + + // Model ID should have prefix + expect(model.id).toBe("us.anthropic.claude-3-5-sonnet-20241022-v2:0") + + // But model info should remain the same + expect(model.info.maxTokens).toBe(8192) + expect(model.info.contextWindow).toBe(200_000) + expect(model.info.supportsImages).toBe(true) + expect(model.info.supportsPromptCache).toBe(true) + }) + + it("should handle model configuration overrides correctly", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + modelMaxTokens: 4096, + awsModelContextWindow: 100_000, + }) + + const model = handler.getModel() + + // Should use override values + expect(model.info.maxTokens).toBe(4096) + expect(model.info.contextWindow).toBe(100_000) + }) + + it("should handle unknown models with sensible defaults", () => { + const handler = new AwsBedrockHandler({ + apiModelId: "unknown.model.id", + awsAccessKey: "test", + awsSecretKey: "test", + awsRegion: "us-east-1", + }) + + const model = handler.getModel() + + // Should fall back to default model info + expect(model.info.maxTokens).toBeDefined() + expect(model.info.contextWindow).toBeDefined() + expect(typeof model.info.supportsImages).toBe("boolean") + expect(typeof model.info.supportsPromptCache).toBe("boolean") + }) + }) }) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index b5474cce50..beba8ccaf3 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -20,7 +20,7 @@ import { BEDROCK_DEFAULT_TEMPERATURE, BEDROCK_MAX_TOKENS, BEDROCK_DEFAULT_CONTEXT, - BEDROCK_REGION_INFO, + AWS_INFERENCE_PROFILE_MAPPING, } from "@roo-code/types" import { ApiStream } from "../transform/stream" @@ -482,7 +482,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH if (streamEvent.contentBlockStart) { const cbStart = streamEvent.contentBlockStart - // Check if this is a reasoning block (official AWS SDK structure) + // Check if this is a reasoning block (AWS SDK structure) if (cbStart.contentBlock?.reasoningContent) { if (cbStart.contentBlockIndex && cbStart.contentBlockIndex > 0) { yield { type: "reasoning", text: "\n" } @@ -493,7 +493,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } // Check for thinking block - handle both possible AWS SDK structures - // cbStart.contentBlock: newer/official structure + // cbStart.contentBlock: newer structure // cbStart.content_block: alternative structure seen in some AWS SDK versions else if (cbStart.contentBlock?.type === "thinking" || cbStart.content_block?.type === "thinking") { const contentBlock = cbStart.contentBlock || cbStart.content_block @@ -522,11 +522,11 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH // Process reasoning and text content deltas // Multiple structures are supported for AWS SDK compatibility: - // - delta.reasoningContent.text: official AWS docs structure for reasoning + // - delta.reasoningContent.text: AWS docs structure for reasoning // - delta.thinking: alternative structure for thinking content // - delta.text: standard text content if (delta) { - // Check for reasoningContent property (official AWS SDK structure) + // Check for reasoningContent property (AWS SDK structure) if (delta.reasoningContent?.text) { yield { type: "reasoning", @@ -777,7 +777,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH if (originalModelId && result.modelId !== originalModelId) { // If the model ID changed after parsing, it had a region prefix let prefix = originalModelId.replace(result.modelId, "") - result.crossRegionInference = AwsBedrockHandler.prefixIsMultiRegion(prefix) + result.crossRegionInference = AwsBedrockHandler.isSystemInferenceProfile(prefix) } // Check if region in ARN matches provided region (if specified) @@ -801,29 +801,21 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } //This strips any region prefix that used on cross-region model inference ARNs - private parseBaseModelId(modelId: string) { + private parseBaseModelId(modelId: string): string { if (!modelId) { return modelId } - const knownRegionPrefixes = AwsBedrockHandler.getPrefixList() - - // Find if the model ID starts with any known region prefix - const matchedPrefix = knownRegionPrefixes.find((prefix) => modelId.startsWith(prefix)) - - if (matchedPrefix) { - // Remove the region prefix from the model ID - return modelId.substring(matchedPrefix.length) - } else { - // If no known prefix was found, check for a generic pattern - // Look for a pattern where the first segment before a dot doesn't contain dots or colons - // and the remaining parts still contain at least one dot - const genericPrefixMatch = modelId.match(/^([^.:]+)\.(.+\..+)$/) - - if (genericPrefixMatch) { - return genericPrefixMatch[2] + // Remove AWS cross-region inference profile prefixes + // as defined in AWS_INFERENCE_PROFILE_MAPPING + for (const [_, inferenceProfile] of AWS_INFERENCE_PROFILE_MAPPING) { + if (modelId.startsWith(inferenceProfile)) { + // Remove the inference profile prefix from the model ID + return modelId.substring(inferenceProfile.length) } } + + // Return the model ID as-is for all other cases return modelId } @@ -900,14 +892,12 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH //a model was selected from the drop down modelConfig = this.getModelById(this.options.apiModelId as string) - if (this.options.awsUseCrossRegionInference) { - // Get the current region - const region = this.options.awsRegion || "" - // Use the helper method to get the appropriate prefix for this region - const prefix = AwsBedrockHandler.getPrefixForRegion(region) - - // Apply the prefix if one was found, otherwise use the model ID as is - modelConfig.id = prefix ? `${prefix}${modelConfig.id}` : modelConfig.id + // Add cross-region inference prefix if enabled + if (this.options.awsUseCrossRegionInference && this.options.awsRegion) { + const prefix = AwsBedrockHandler.getPrefixForRegion(this.options.awsRegion) + if (prefix) { + modelConfig.id = `${prefix}${modelConfig.id}` + } } } @@ -973,24 +963,23 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH * *************************************************************************************/ - private static getPrefixList(): string[] { - return Object.keys(BEDROCK_REGION_INFO) - } - private static getPrefixForRegion(region: string): string | undefined { - for (const [prefix, info] of Object.entries(BEDROCK_REGION_INFO)) { - if (info.pattern && region.startsWith(info.pattern)) { - return prefix + // Use AWS recommended inference profile prefixes + // Array is pre-sorted by pattern length (descending) to ensure more specific patterns match first + for (const [regionPattern, inferenceProfile] of AWS_INFERENCE_PROFILE_MAPPING) { + if (region.startsWith(regionPattern)) { + return inferenceProfile } } + return undefined } - private static prefixIsMultiRegion(arnPrefix: string): boolean { - for (const [prefix, info] of Object.entries(BEDROCK_REGION_INFO)) { - if (arnPrefix === prefix) { - if (info?.multiRegion) return info.multiRegion - else return false + private static isSystemInferenceProfile(prefix: string): boolean { + // Check if the prefix is defined in AWS_INFERENCE_PROFILE_MAPPING + for (const [_, inferenceProfile] of AWS_INFERENCE_PROFILE_MAPPING) { + if (prefix === inferenceProfile) { + return true } } return false