diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 08a328379d..39479737ed 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -100,6 +100,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
 	awsProfile: z.string().optional(),
 	awsUseProfile: z.boolean().optional(),
 	awsCustomArn: z.string().optional(),
+	awsModelContextWindow: z.number().optional(),
 	awsBedrockEndpointEnabled: z.boolean().optional(),
 	awsBedrockEndpoint: z.string().optional(),
 })
@@ -285,6 +286,7 @@ export const PROVIDER_SETTINGS_KEYS = keysOf()([
 	"awsProfile",
 	"awsUseProfile",
 	"awsCustomArn",
+	"awsModelContextWindow",
 	"awsBedrockEndpointEnabled",
 	"awsBedrockEndpoint",
 	// Google Vertex
diff --git a/packages/types/src/providers/bedrock.ts b/packages/types/src/providers/bedrock.ts
index f40dc8c8f6..ce5ea28e95 100644
--- a/packages/types/src/providers/bedrock.ts
+++ b/packages/types/src/providers/bedrock.ts
@@ -355,6 +355,8 @@ export const BEDROCK_DEFAULT_TEMPERATURE = 0.3
 
 export const BEDROCK_MAX_TOKENS = 4096
 
+export const BEDROCK_DEFAULT_CONTEXT = 128_000
+
 export const BEDROCK_REGION_INFO: Record<
 	string,
 	{
diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts
index 0e335755ec..16ce3289aa 100644
--- a/src/api/providers/bedrock.ts
+++ b/src/api/providers/bedrock.ts
@@ -19,6 +19,7 @@ import {
 	bedrockDefaultPromptRouterModelId,
 	BEDROCK_DEFAULT_TEMPERATURE,
 	BEDROCK_MAX_TOKENS,
+	BEDROCK_DEFAULT_CONTEXT,
 	BEDROCK_REGION_INFO,
 } from "@roo-code/types"
 
@@ -192,6 +193,75 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 		this.client = new BedrockRuntimeClient(clientConfig)
 	}
 
+	/**
+	 * Guesses model capabilities from a custom model ID that is not in bedrockModels.
+	 *
+	 * The lowercased ID is tested against known Claude family substrings in
+	 * declaration order and the first match wins. IDs matching no pattern get
+	 * fallback defaults built from BEDROCK_MAX_TOKENS and BEDROCK_DEFAULT_CONTEXT.
+	 *
+	 * @param modelId - Model identifier to inspect (lowercased before matching).
+	 * @returns Partial model info for the first matching pattern, or the defaults.
+	 */
+	private guessModelInfoFromId(modelId: string): Partial<ModelInfo> {
+		// Patterns are matched with includes(), so a more specific pattern must be declared
+		// before any shorter pattern it contains (e.g. "claude-4-opus" before "claude-4").
+		const modelConfigMap: Record<string, Partial<ModelInfo>> = {
+			"claude-4-opus": {
+				maxTokens: 4096,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-4": {
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-7": {
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-5": {
+				maxTokens: 8192,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-opus": {
+				maxTokens: 4096,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+			"claude-3-haiku": {
+				maxTokens: 4096,
+				contextWindow: 200_000,
+				supportsImages: true,
+				supportsPromptCache: true,
+			},
+		}
+
+		// Scan in declaration order; the first pattern contained in the ID decides the config
+		const id = modelId.toLowerCase()
+		for (const [pattern, config] of Object.entries(modelConfigMap)) {
+			if (id.includes(pattern)) {
+				return config
+			}
+		}
+
+		// No known family matched: fall back to the generic Bedrock limits
+		return {
+			maxTokens: BEDROCK_MAX_TOKENS,
+			contextWindow: BEDROCK_DEFAULT_CONTEXT,
+			supportsImages: false,
+			supportsPromptCache: false,
+		}
+	}
+
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
@@ -640,16 +710,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 				info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
 			}
 		} else {
+			// Use heuristics for model info, then allow overrides from ProviderSettings
+			const guessed = this.guessModelInfoFromId(modelId)
 			model = {
 				id: bedrockDefaultModelId,
-				info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
+				info: {
+					...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
+					...guessed,
+				},
 			}
 		}
 
-		// If modelMaxTokens is explicitly set in options, override the default
+		// Always allow user to override detected/guessed maxTokens and contextWindow
 		if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
 			model.info.maxTokens = this.options.modelMaxTokens
 		}
+		if (this.options.awsModelContextWindow && this.options.awsModelContextWindow > 0) {
+			model.info.contextWindow = this.options.awsModelContextWindow
+		}
 
 		return model
 	}
@@ -684,8 +762,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
 			}
 		}
 
-		modelConfig.info.maxTokens = modelConfig.info.maxTokens || BEDROCK_MAX_TOKENS
-
+		// Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides)
 		return modelConfig as { id: BedrockModelId | string; info: ModelInfo }
 	}
 