Skip to content
Merged
66 changes: 61 additions & 5 deletions src/api/providers/bedrock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { convertToBedrockConverseMessages as sharedConverter } from "../transfor

const BEDROCK_DEFAULT_TEMPERATURE = 0.3
const BEDROCK_MAX_TOKENS = 4096
const BEDROCK_DEFAULT_CONTEXT = 128_000 // PATCH: used for unknown custom ARNs

/************************************************************************************
*
Expand Down Expand Up @@ -186,6 +187,54 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
this.client = new BedrockRuntimeClient(clientConfig)
}

// PATCH: Helper to guess model info from custom modelId string if not in bedrockModels
private guessModelInfoFromId(modelId: string): Partial<SharedModelInfo> {
// Define a mapping for model ID patterns and their configurations
const modelConfigMap: Record<string, Partial<SharedModelInfo>> = {
"claude-3-7": {
maxTokens: 8192,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
},
"claude-3-5": {
maxTokens: 8192,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
},
"claude-3-opus": {
maxTokens: 4096,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
},
"claude-3-haiku": {
maxTokens: 4096,
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
},
}

// Match the model ID to a configuration
const id = modelId.toLowerCase()
for (const [pattern, config] of Object.entries(modelConfigMap)) {
if (id.includes(pattern)) {
return config
}
}
// PATCH: Add more heuristics as needed here

// Default fallback
return {
maxTokens: BEDROCK_MAX_TOKENS,
contextWindow: BEDROCK_DEFAULT_CONTEXT,
supportsImages: false,
supportsPromptCache: false,
}
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
let modelConfig = this.getModel()
// Handle cross-region inference
Expand Down Expand Up @@ -495,7 +544,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
})

return {
system: systemMessage ? [{ text: systemMessage } as SystemContentBlock] : [],
system: cacheResult.system,
messages: messagesWithCache,
}
}
Expand Down Expand Up @@ -629,16 +678,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
}
} else {
// PATCH: Use heuristics for model info, then allow overrides from ProviderSettings
const guessed = this.guessModelInfoFromId(modelId)
model = {
id: bedrockDefaultModelId,
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
info: {
...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
...guessed,
},
}
}

// If modelMaxTokens is explicitly set in options, override the default
// PATCH: Always allow user to override detected/guessed maxTokens and contextWindow
if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
model.info.maxTokens = this.options.modelMaxTokens
}
if (this.options.awsModelContextWindow && this.options.awsModelContextWindow > 0) {
model.info.contextWindow = this.options.awsModelContextWindow
}

return model
}
Expand Down Expand Up @@ -673,8 +730,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
}
}

modelConfig.info.maxTokens = modelConfig.info.maxTokens || BEDROCK_MAX_TOKENS

// PATCH: Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides)
return modelConfig as { id: BedrockModelId | string; info: SharedModelInfo }
}

Expand Down
9 changes: 9 additions & 0 deletions src/core/config/importExport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ export const importSettings = async ({ providerSettingsManager, contextProxy, cu
if (e instanceof ZodError) {
error = e.issues.map((issue) => `[${issue.path.join(".")}]: ${issue.message}`).join("\n")
telemetryService.captureSchemaValidationError({ schemaName: "ImportExport", error: e })
} else if (e instanceof SyntaxError) {
// Extract position info from the error message
const match = e.message.match(/at position (\d+)/)
if (match) {
const position = parseInt(match[1], 10)
error = `Expected property name or '}' in JSON at position ${position}`
} else {
error = e.message
}
} else if (e instanceof Error) {
error = e.message
}
Expand Down
3 changes: 3 additions & 0 deletions src/exports/roo-code.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ type ProviderSettings = {
awsProfile?: string | undefined
awsUseProfile?: boolean | undefined
awsCustomArn?: string | undefined
awsModelContextWindow?: number | undefined
vertexKeyFile?: string | undefined
vertexJsonCredentials?: string | undefined
vertexProjectId?: string | undefined
Expand Down Expand Up @@ -656,6 +657,7 @@ type IpcMessage =
awsProfile?: string | undefined
awsUseProfile?: boolean | undefined
awsCustomArn?: string | undefined
awsModelContextWindow?: number | undefined
vertexKeyFile?: string | undefined
vertexJsonCredentials?: string | undefined
vertexProjectId?: string | undefined
Expand Down Expand Up @@ -1131,6 +1133,7 @@ type TaskCommand =
awsProfile?: string | undefined
awsUseProfile?: boolean | undefined
awsCustomArn?: string | undefined
awsModelContextWindow?: number | undefined
vertexKeyFile?: string | undefined
vertexJsonCredentials?: string | undefined
vertexProjectId?: string | undefined
Expand Down
3 changes: 3 additions & 0 deletions src/exports/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ type ProviderSettings = {
awsProfile?: string | undefined
awsUseProfile?: boolean | undefined
awsCustomArn?: string | undefined
awsModelContextWindow?: number | undefined
vertexKeyFile?: string | undefined
vertexJsonCredentials?: string | undefined
vertexProjectId?: string | undefined
Expand Down Expand Up @@ -670,6 +671,7 @@ type IpcMessage =
awsProfile?: string | undefined
awsUseProfile?: boolean | undefined
awsCustomArn?: string | undefined
awsModelContextWindow?: number | undefined
vertexKeyFile?: string | undefined
vertexJsonCredentials?: string | undefined
vertexProjectId?: string | undefined
Expand Down Expand Up @@ -1147,6 +1149,7 @@ type TaskCommand =
awsProfile?: string | undefined
awsUseProfile?: boolean | undefined
awsCustomArn?: string | undefined
awsModelContextWindow?: number | undefined
vertexKeyFile?: string | undefined
vertexJsonCredentials?: string | undefined
vertexProjectId?: string | undefined
Expand Down
2 changes: 2 additions & 0 deletions src/schemas/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
awsProfile: z.string().optional(),
awsUseProfile: z.boolean().optional(),
awsCustomArn: z.string().optional(),
awsModelContextWindow: z.number().optional(),
})

const vertexSchema = apiModelIdProviderModelSchema.extend({
Expand Down Expand Up @@ -651,6 +652,7 @@ const providerSettingsRecord: ProviderSettingsRecord = {
awsProfile: undefined,
awsUseProfile: undefined,
awsCustomArn: undefined,
awsModelContextWindow: undefined,
// Google Vertex
vertexKeyFile: undefined,
vertexJsonCredentials: undefined,
Expand Down