Skip to content

Commit e4a4685

Browse files
committed
Improves model info detection for custom Bedrock ARNs
Adds heuristics to better estimate model capabilities when using unknown or custom model ARNs, including context window and max tokens. Allows user overrides for key model parameters via provider settings, improving flexibility and reliability for non-standard model integrations. Fixes #3712
1 parent d7ae811 commit e4a4685

File tree

4 files changed

+68
-4
lines changed

4 files changed

+68
-4
lines changed

src/api/providers/bedrock.ts

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import { convertToBedrockConverseMessages as sharedConverter } from "../transfor
2929

3030
const BEDROCK_DEFAULT_TEMPERATURE = 0.3
3131
const BEDROCK_MAX_TOKENS = 4096
32+
const BEDROCK_DEFAULT_CONTEXT = 128_000 // PATCH: used for unknown custom ARNs
3233

3334
/************************************************************************************
3435
*
@@ -186,6 +187,54 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
186187
this.client = new BedrockRuntimeClient(clientConfig)
187188
}
188189

190+
/**
 * Estimates model capabilities from a model identifier that has no entry in bedrockModels.
 *
 * The identifier is lower-cased and checked for known Claude family markers in a fixed
 * order; the first marker found decides the result. Identifiers that match no marker get
 * a fallback built from BEDROCK_MAX_TOKENS and BEDROCK_DEFAULT_CONTEXT with image and
 * prompt-cache support switched off.
 *
 * @param modelId - Model identifier (for example a custom ARN string) to inspect.
 * @returns A newly created partial model info object holding the guessed values.
 */
private guessModelInfoFromId(modelId: string): Partial<SharedModelInfo> {
	const normalizedId = modelId.toLowerCase()

	// Every recognised family shares these values; only maxTokens differs between them.
	// NOTE(review): prompt-cache support is assumed for all four families — confirm against the Bedrock docs.
	const claudeInfo = (maxTokens: number): Partial<SharedModelInfo> => ({
		maxTokens,
		contextWindow: 200_000,
		supportsImages: true,
		supportsPromptCache: true,
	})

	// Ordered [marker, maxTokens] pairs; earlier entries take precedence.
	// Add more heuristics to this table as needed.
	const knownFamilies: ReadonlyArray<readonly [string, number]> = [
		["claude-3-7", 8192],
		["claude-3-5", 8192],
		["claude-3-opus", 4096],
		["claude-3-haiku", 4096],
	]

	for (const [marker, maxTokens] of knownFamilies) {
		if (normalizedId.includes(marker)) {
			return claudeInfo(maxTokens)
		}
	}

	// Nothing matched: fall back to the generic defaults.
	return {
		maxTokens: BEDROCK_MAX_TOKENS,
		contextWindow: BEDROCK_DEFAULT_CONTEXT,
		supportsImages: false,
		supportsPromptCache: false,
	}
}
237+
189238
override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
190239
let modelConfig = this.getModel()
191240
// Handle cross-region inference
@@ -629,16 +678,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
629678
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultPromptRouterModelId])),
630679
}
631680
} else {
681+
// PATCH: Use heuristics for model info, then allow overrides from ProviderSettings
682+
const guessed = this.guessModelInfoFromId(modelId)
632683
model = {
633684
id: bedrockDefaultModelId,
634-
info: JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
685+
info: {
686+
...JSON.parse(JSON.stringify(bedrockModels[bedrockDefaultModelId])),
687+
...guessed,
688+
},
635689
}
636690
}
637691

638-
// If modelMaxTokens is explicitly set in options, override the default
692+
// PATCH: Always allow user to override detected/guessed maxTokens and contextWindow
639693
if (this.options.modelMaxTokens && this.options.modelMaxTokens > 0) {
640694
model.info.maxTokens = this.options.modelMaxTokens
641695
}
696+
if (this.options.awsModelContextWindow && this.options.awsModelContextWindow > 0) {
697+
model.info.contextWindow = this.options.awsModelContextWindow
698+
}
642699

643700
return model
644701
}
@@ -673,8 +730,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH
673730
}
674731
}
675732

676-
modelConfig.info.maxTokens = modelConfig.info.maxTokens || BEDROCK_MAX_TOKENS
677-
733+
// PATCH: Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides)
678734
return modelConfig as { id: BedrockModelId | string; info: SharedModelInfo }
679735
}
680736

src/exports/roo-code.d.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ type ProviderSettings = {
253253
awsProfile?: string | undefined
254254
awsUseProfile?: boolean | undefined
255255
awsCustomArn?: string | undefined
256+
awsModelContextWindow?: number | undefined
256257
vertexKeyFile?: string | undefined
257258
vertexJsonCredentials?: string | undefined
258259
vertexProjectId?: string | undefined
@@ -656,6 +657,7 @@ type IpcMessage =
656657
awsProfile?: string | undefined
657658
awsUseProfile?: boolean | undefined
658659
awsCustomArn?: string | undefined
660+
awsModelContextWindow?: number | undefined
659661
vertexKeyFile?: string | undefined
660662
vertexJsonCredentials?: string | undefined
661663
vertexProjectId?: string | undefined
@@ -1131,6 +1133,7 @@ type TaskCommand =
11311133
awsProfile?: string | undefined
11321134
awsUseProfile?: boolean | undefined
11331135
awsCustomArn?: string | undefined
1136+
awsModelContextWindow?: number | undefined
11341137
vertexKeyFile?: string | undefined
11351138
vertexJsonCredentials?: string | undefined
11361139
vertexProjectId?: string | undefined

src/exports/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ type ProviderSettings = {
257257
awsProfile?: string | undefined
258258
awsUseProfile?: boolean | undefined
259259
awsCustomArn?: string | undefined
260+
awsModelContextWindow?: number | undefined
260261
vertexKeyFile?: string | undefined
261262
vertexJsonCredentials?: string | undefined
262263
vertexProjectId?: string | undefined
@@ -670,6 +671,7 @@ type IpcMessage =
670671
awsProfile?: string | undefined
671672
awsUseProfile?: boolean | undefined
672673
awsCustomArn?: string | undefined
674+
awsModelContextWindow?: number | undefined
673675
vertexKeyFile?: string | undefined
674676
vertexJsonCredentials?: string | undefined
675677
vertexProjectId?: string | undefined
@@ -1147,6 +1149,7 @@ type TaskCommand =
11471149
awsProfile?: string | undefined
11481150
awsUseProfile?: boolean | undefined
11491151
awsCustomArn?: string | undefined
1152+
awsModelContextWindow?: number | undefined
11501153
vertexKeyFile?: string | undefined
11511154
vertexJsonCredentials?: string | undefined
11521155
vertexProjectId?: string | undefined

src/schemas/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ const bedrockSchema = apiModelIdProviderModelSchema.extend({
465465
awsProfile: z.string().optional(),
466466
awsUseProfile: z.boolean().optional(),
467467
awsCustomArn: z.string().optional(),
468+
awsModelContextWindow: z.number().optional(),
468469
})
469470

470471
const vertexSchema = apiModelIdProviderModelSchema.extend({
@@ -651,6 +652,7 @@ const providerSettingsRecord: ProviderSettingsRecord = {
651652
awsProfile: undefined,
652653
awsUseProfile: undefined,
653654
awsCustomArn: undefined,
655+
awsModelContextWindow: undefined,
654656
// Google Vertex
655657
vertexKeyFile: undefined,
656658
vertexJsonCredentials: undefined,

0 commit comments

Comments
 (0)