Skip to content

Commit 79b76fd

Browse files
authored
feat: Support AWS Bedrock Application Inference Profiles for Cost Tracking (RooCodeInc#2078)
* feat: Add support for custom model ID in AWS Bedrock provider * preserve settings when switching Act-Plan modes * Use base model ID for ApiHandler behavior determination when using a custom model on AWS Bedrock.
1 parent 19cc8bc commit 79b76fd

File tree

7 files changed

+175
-11
lines changed

7 files changed

+175
-11
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"claude-dev": minor
3+
---
4+
5+
Add support for custom model ID in AWS Bedrock provider, enabling use of Application Inference Profile.

src/api/providers/bedrock.ts

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,26 @@ export class AwsBedrockHandler implements ApiHandler {
2828
const modelId = await this.getModelId()
2929
const model = this.getModel()
3030

31+
// This baseModelId is used to indicate the capabilities of the model.
32+
// If the user selects a custom model, baseModelId will be set to the base model ID of the custom model.
33+
// Otherwise, baseModelId will be the same as modelId.
34+
const baseModelId =
35+
(this.options.awsBedrockCustomSelected ? this.options.awsBedrockCustomModelBaseId : modelId) || modelId
36+
3137
// Check if this is an Amazon Nova model
32-
if (modelId.includes("amazon.nova")) {
38+
if (baseModelId.includes("amazon.nova")) {
3339
yield* this.createNovaMessage(systemPrompt, messages, modelId, model)
3440
return
3541
}
3642

3743
// Check if this is a Deepseek model
38-
if (modelId.includes("deepseek")) {
44+
if (baseModelId.includes("deepseek")) {
3945
yield* this.createDeepseekMessage(systemPrompt, messages, modelId, model)
4046
return
4147
}
4248

4349
const budget_tokens = this.options.thinkingBudgetTokens || 0
44-
const reasoningOn = modelId.includes("3-7") && budget_tokens !== 0 ? true : false
50+
const reasoningOn = baseModelId.includes("3-7") && budget_tokens !== 0 ? true : false
4551

4652
// Get model info and message indices for caching
4753
const userMsgIndices = messages.reduce((acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), [] as number[])
@@ -167,12 +173,23 @@ export class AwsBedrockHandler implements ApiHandler {
167173
}
168174
}
169175

170-
getModel(): { id: BedrockModelId; info: ModelInfo } {
176+
getModel(): { id: string; info: ModelInfo } {
171177
const modelId = this.options.apiModelId
172178
if (modelId && modelId in bedrockModels) {
173179
const id = modelId as BedrockModelId
174180
return { id, info: bedrockModels[id] }
175181
}
182+
183+
const customSelected = this.options.awsBedrockCustomSelected
184+
const baseModel = this.options.awsBedrockCustomModelBaseId
185+
if (customSelected && modelId && baseModel && baseModel in bedrockModels) {
186+
// Use the user-input model ID but inherit capabilities from the base model
187+
return {
188+
id: modelId,
189+
info: bedrockModels[baseModel],
190+
}
191+
}
192+
176193
return {
177194
id: bedrockDefaultModelId,
178195
info: bedrockModels[bedrockDefaultModelId],
@@ -290,7 +307,7 @@ export class AwsBedrockHandler implements ApiHandler {
290307
systemPrompt: string,
291308
messages: Anthropic.Messages.MessageParam[],
292309
modelId: string,
293-
model: { id: BedrockModelId; info: ModelInfo },
310+
model: { id: string; info: ModelInfo },
294311
): ApiStream {
295312
// Get Bedrock client with proper credentials
296313
const client = await this.getBedrockClient()
@@ -482,7 +499,7 @@ export class AwsBedrockHandler implements ApiHandler {
482499
systemPrompt: string,
483500
messages: Anthropic.Messages.MessageParam[],
484501
modelId: string,
485-
model: { id: BedrockModelId; info: ModelInfo },
502+
model: { id: string; info: ModelInfo },
486503
): ApiStream {
487504
// Get Bedrock client with proper credentials
488505
const client = await this.getBedrockClient()

src/core/controller/index.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,8 @@ export class Controller {
832832
previousModeVsCodeLmModelSelector: newVsCodeLmModelSelector,
833833
previousModeThinkingBudgetTokens: newThinkingBudgetTokens,
834834
previousModeReasoningEffort: newReasoningEffort,
835+
previousModeAwsBedrockCustomSelected: newAwsBedrockCustomSelected,
836+
previousModeAwsBedrockCustomModelBaseId: newAwsBedrockCustomModelBaseId,
835837
planActSeparateModelsSetting,
836838
} = await getAllExtensionState(this.context)
837839

@@ -844,7 +846,6 @@ export class Controller {
844846
await updateGlobalState(this.context, "previousModeReasoningEffort", apiConfiguration.reasoningEffort)
845847
switch (apiConfiguration.apiProvider) {
846848
case "anthropic":
847-
case "bedrock":
848849
case "vertex":
849850
case "gemini":
850851
case "asksage":
@@ -854,6 +855,19 @@ export class Controller {
854855
case "xai":
855856
await updateGlobalState(this.context, "previousModeModelId", apiConfiguration.apiModelId)
856857
break
858+
case "bedrock":
859+
await updateGlobalState(this.context, "previousModeModelId", apiConfiguration.apiModelId)
860+
await updateGlobalState(
861+
this.context,
862+
"previousModeAwsBedrockCustomSelected",
863+
apiConfiguration.awsBedrockCustomSelected,
864+
)
865+
await updateGlobalState(
866+
this.context,
867+
"previousModeAwsBedrockCustomModelBaseId",
868+
apiConfiguration.awsBedrockCustomModelBaseId,
869+
)
870+
break
857871
case "openrouter":
858872
case "cline":
859873
await updateGlobalState(this.context, "previousModeModelId", apiConfiguration.openRouterModelId)
@@ -899,7 +913,6 @@ export class Controller {
899913
await updateGlobalState(this.context, "reasoningEffort", newReasoningEffort)
900914
switch (newApiProvider) {
901915
case "anthropic":
902-
case "bedrock":
903916
case "vertex":
904917
case "gemini":
905918
case "asksage":
@@ -909,6 +922,11 @@ export class Controller {
909922
case "xai":
910923
await updateGlobalState(this.context, "apiModelId", newModelId)
911924
break
925+
case "bedrock":
926+
await updateGlobalState(this.context, "apiModelId", newModelId)
927+
await updateGlobalState(this.context, "awsBedrockCustomSelected", newAwsBedrockCustomSelected)
928+
await updateGlobalState(this.context, "awsBedrockCustomModelBaseId", newAwsBedrockCustomModelBaseId)
929+
break
912930
case "openrouter":
913931
case "cline":
914932
await updateGlobalState(this.context, "openRouterModelId", newModelId)

src/core/storage/state-keys.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ export type GlobalStateKey =
2929
| "awsBedrockEndpoint"
3030
| "awsProfile"
3131
| "awsUseProfile"
32+
| "awsBedrockCustomSelected"
33+
| "awsBedrockCustomModelBaseId"
3234
| "vertexProjectId"
3335
| "vertexRegion"
3436
| "lastShownAnnouncementId"
@@ -60,6 +62,8 @@ export type GlobalStateKey =
6062
| "previousModeThinkingBudgetTokens"
6163
| "previousModeReasoningEffort"
6264
| "previousModeVsCodeLmModelSelector"
65+
| "previousModeAwsBedrockCustomSelected"
66+
| "previousModeAwsBedrockCustomModelBaseId"
6367
| "previousModeModelInfo"
6468
| "liteLlmBaseUrl"
6569
| "liteLlmModelId"

src/core/storage/state.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { DEFAULT_CHAT_SETTINGS } from "@shared/ChatSettings"
33
import { DEFAULT_BROWSER_SETTINGS } from "@shared/BrowserSettings"
44
import { DEFAULT_AUTO_APPROVAL_SETTINGS } from "@shared/AutoApprovalSettings"
55
import { GlobalStateKey, SecretKey } from "./state-keys"
6-
import { ApiConfiguration, ApiProvider, ModelInfo } from "@shared/api"
6+
import { ApiConfiguration, ApiProvider, BedrockModelId, ModelInfo } from "@shared/api"
77
import { HistoryItem } from "@shared/HistoryItem"
88
import { AutoApprovalSettings } from "@shared/AutoApprovalSettings"
99
import { BrowserSettings } from "@shared/BrowserSettings"
@@ -67,6 +67,8 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
6767
awsBedrockEndpoint,
6868
awsProfile,
6969
awsUseProfile,
70+
awsBedrockCustomSelected,
71+
awsBedrockCustomModelBaseId,
7072
vertexProjectId,
7173
vertexRegion,
7274
openAiBaseUrl,
@@ -113,6 +115,8 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
113115
previousModeVsCodeLmModelSelector,
114116
previousModeThinkingBudgetTokens,
115117
previousModeReasoningEffort,
118+
previousModeAwsBedrockCustomSelected,
119+
previousModeAwsBedrockCustomModelBaseId,
116120
qwenApiLine,
117121
liteLlmApiKey,
118122
telemetrySetting,
@@ -142,6 +146,8 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
142146
getGlobalState(context, "awsBedrockEndpoint") as Promise<string | undefined>,
143147
getGlobalState(context, "awsProfile") as Promise<string | undefined>,
144148
getGlobalState(context, "awsUseProfile") as Promise<boolean | undefined>,
149+
getGlobalState(context, "awsBedrockCustomSelected") as Promise<boolean | undefined>,
150+
getGlobalState(context, "awsBedrockCustomModelBaseId") as Promise<BedrockModelId | undefined>,
145151
getGlobalState(context, "vertexProjectId") as Promise<string | undefined>,
146152
getGlobalState(context, "vertexRegion") as Promise<string | undefined>,
147153
getGlobalState(context, "openAiBaseUrl") as Promise<string | undefined>,
@@ -188,6 +194,8 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
188194
getGlobalState(context, "previousModeVsCodeLmModelSelector") as Promise<vscode.LanguageModelChatSelector | undefined>,
189195
getGlobalState(context, "previousModeThinkingBudgetTokens") as Promise<number | undefined>,
190196
getGlobalState(context, "previousModeReasoningEffort") as Promise<string | undefined>,
197+
getGlobalState(context, "previousModeAwsBedrockCustomSelected") as Promise<boolean | undefined>,
198+
getGlobalState(context, "previousModeAwsBedrockCustomModelBaseId") as Promise<BedrockModelId | undefined>,
191199
getGlobalState(context, "qwenApiLine") as Promise<string | undefined>,
192200
getSecret(context, "liteLlmApiKey") as Promise<string | undefined>,
193201
getGlobalState(context, "telemetrySetting") as Promise<TelemetrySetting | undefined>,
@@ -258,6 +266,8 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
258266
awsBedrockEndpoint,
259267
awsProfile,
260268
awsUseProfile,
269+
awsBedrockCustomSelected,
270+
awsBedrockCustomModelBaseId,
261271
vertexProjectId,
262272
vertexRegion,
263273
openAiBaseUrl,
@@ -318,6 +328,8 @@ export async function getAllExtensionState(context: vscode.ExtensionContext) {
318328
previousModeVsCodeLmModelSelector,
319329
previousModeThinkingBudgetTokens,
320330
previousModeReasoningEffort,
331+
previousModeAwsBedrockCustomSelected,
332+
previousModeAwsBedrockCustomModelBaseId,
321333
mcpMarketplaceEnabled,
322334
telemetrySetting: telemetrySetting || "unset",
323335
planActSeparateModelsSetting,
@@ -340,6 +352,8 @@ export async function updateApiConfiguration(context: vscode.ExtensionContext, a
340352
awsBedrockEndpoint,
341353
awsProfile,
342354
awsUseProfile,
355+
awsBedrockCustomSelected,
356+
awsBedrockCustomModelBaseId,
343357
vertexProjectId,
344358
vertexRegion,
345359
openAiBaseUrl,
@@ -397,6 +411,8 @@ export async function updateApiConfiguration(context: vscode.ExtensionContext, a
397411
await updateGlobalState(context, "awsBedrockEndpoint", awsBedrockEndpoint)
398412
await updateGlobalState(context, "awsProfile", awsProfile)
399413
await updateGlobalState(context, "awsUseProfile", awsUseProfile)
414+
await updateGlobalState(context, "awsBedrockCustomSelected", awsBedrockCustomSelected)
415+
await updateGlobalState(context, "awsBedrockCustomModelBaseId", awsBedrockCustomModelBaseId)
400416
await updateGlobalState(context, "vertexProjectId", vertexProjectId)
401417
await updateGlobalState(context, "vertexRegion", vertexRegion)
402418
await updateGlobalState(context, "openAiBaseUrl", openAiBaseUrl)

src/shared/api.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ export interface ApiHandlerOptions {
4747
awsUseProfile?: boolean
4848
awsProfile?: string
4949
awsBedrockEndpoint?: string
50+
awsBedrockCustomSelected?: boolean
51+
awsBedrockCustomModelBaseId?: BedrockModelId
5052
vertexProjectId?: string
5153
vertexRegion?: string
5254
openAiBaseUrl?: string

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,101 @@ const ApiOptions = ({
769769
</>
770770
)}
771771
</p>
772+
<label htmlFor="bedrock-model-dropdown">
773+
<span style={{ fontWeight: 500 }}>Model</span>
774+
</label>
775+
<DropdownContainer zIndex={DROPDOWN_Z_INDEX - 2} className="dropdown-container">
776+
<VSCodeDropdown
777+
id="bedrock-model-dropdown"
778+
value={apiConfiguration?.awsBedrockCustomSelected ? "custom" : selectedModelId}
779+
onChange={(e: any) => {
780+
const isCustom = e.target.value === "custom"
781+
setApiConfiguration({
782+
...apiConfiguration,
783+
apiModelId: isCustom ? "" : e.target.value,
784+
awsBedrockCustomSelected: isCustom,
785+
awsBedrockCustomModelBaseId: bedrockDefaultModelId,
786+
})
787+
}}
788+
style={{ width: "100%" }}>
789+
<VSCodeOption value="">Select a model...</VSCodeOption>
790+
{Object.keys(bedrockModels).map((modelId) => (
791+
<VSCodeOption
792+
key={modelId}
793+
value={modelId}
794+
style={{
795+
whiteSpace: "normal",
796+
wordWrap: "break-word",
797+
maxWidth: "100%",
798+
}}>
799+
{modelId}
800+
</VSCodeOption>
801+
))}
802+
<VSCodeOption value="custom">Custom</VSCodeOption>
803+
</VSCodeDropdown>
804+
</DropdownContainer>
805+
{apiConfiguration?.awsBedrockCustomSelected && (
806+
<div>
807+
<p
808+
style={{
809+
fontSize: "12px",
810+
marginTop: "5px",
811+
color: "var(--vscode-descriptionForeground)",
812+
}}>
813+
Select "Custom" when using the Application Inference Profile in Bedrock. Enter the Application
814+
Inference Profile ID in the Model ID field. However, be sure to encode the / in the ARN as %2F.
815+
<br />
816+
Example: arn:aws:bedrock:us-west-2:&lt;AWS Account
817+
ID&gt;:application-inference-profile%2Fxxxxxxxxxxxx
818+
</p>
819+
<label htmlFor="bedrock-model-input">
820+
<span style={{ fontWeight: 500 }}>Model ID</span>
821+
</label>
822+
<VSCodeTextField
823+
id="bedrock-model-input"
824+
value={apiConfiguration?.apiModelId || ""}
825+
style={{ width: "100%", marginTop: 3 }}
826+
onInput={handleInputChange("apiModelId")}
827+
placeholder="Enter custom model ID..."
828+
/>
829+
<label htmlFor="bedrock-base-model-dropdown">
830+
<span style={{ fontWeight: 500 }}>Base Inference Model</span>
831+
</label>
832+
<DropdownContainer zIndex={DROPDOWN_Z_INDEX - 3} className="dropdown-container">
833+
<VSCodeDropdown
834+
id="bedrock-base-model-dropdown"
835+
value={apiConfiguration?.awsBedrockCustomModelBaseId || bedrockDefaultModelId}
836+
onChange={handleInputChange("awsBedrockCustomModelBaseId")}
837+
style={{ width: "100%" }}>
838+
<VSCodeOption value="">Select a model...</VSCodeOption>
839+
{Object.keys(bedrockModels).map((modelId) => (
840+
<VSCodeOption
841+
key={modelId}
842+
value={modelId}
843+
style={{
844+
whiteSpace: "normal",
845+
wordWrap: "break-word",
846+
maxWidth: "100%",
847+
}}>
848+
{modelId}
849+
</VSCodeOption>
850+
))}
851+
</VSCodeDropdown>
852+
</DropdownContainer>
853+
</div>
854+
)}
855+
{(selectedModelId === "anthropic.claude-3-7-sonnet-20250219-v1:0" ||
856+
(apiConfiguration?.awsBedrockCustomSelected &&
857+
apiConfiguration?.awsBedrockCustomModelBaseId === "anthropic.claude-3-7-sonnet-20250219-v1:0")) && (
858+
<ThinkingBudgetSlider apiConfiguration={apiConfiguration} setApiConfiguration={setApiConfiguration} />
859+
)}
860+
<ModelInfoView
861+
selectedModelId={selectedModelId}
862+
modelInfo={selectedModelInfo}
863+
isDescriptionExpanded={isDescriptionExpanded}
864+
setIsDescriptionExpanded={setIsDescriptionExpanded}
865+
isPopup={isPopup}
866+
/>
772867
</div>
773868
)}
774869

@@ -1702,14 +1797,14 @@ const ApiOptions = ({
17021797
selectedProvider !== "vscode-lm" &&
17031798
selectedProvider !== "litellm" &&
17041799
selectedProvider !== "requesty" &&
1800+
selectedProvider !== "bedrock" &&
17051801
showModelOptions && (
17061802
<>
17071803
<DropdownContainer zIndex={DROPDOWN_Z_INDEX - 2} className="dropdown-container">
17081804
<label htmlFor="model-id">
17091805
<span style={{ fontWeight: 500 }}>Model</span>
17101806
</label>
17111807
{selectedProvider === "anthropic" && createDropdown(anthropicModels)}
1712-
{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
17131808
{selectedProvider === "vertex" && createDropdown(vertexModels)}
17141809
{selectedProvider === "gemini" && createDropdown(geminiModels)}
17151810
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
@@ -1726,7 +1821,6 @@ const ApiOptions = ({
17261821
</DropdownContainer>
17271822

17281823
{((selectedProvider === "anthropic" && selectedModelId === "claude-3-7-sonnet-20250219") ||
1729-
(selectedProvider === "bedrock" && selectedModelId === "anthropic.claude-3-7-sonnet-20250219-v1:0") ||
17301824
(selectedProvider === "vertex" && selectedModelId === "claude-3-7-sonnet@20250219")) && (
17311825
<ThinkingBudgetSlider apiConfiguration={apiConfiguration} setApiConfiguration={setApiConfiguration} />
17321826
)}
@@ -2048,6 +2142,14 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration):
20482142
case "anthropic":
20492143
return getProviderData(anthropicModels, anthropicDefaultModelId)
20502144
case "bedrock":
2145+
if (apiConfiguration?.awsBedrockCustomSelected) {
2146+
const baseModelId = apiConfiguration.awsBedrockCustomModelBaseId
2147+
return {
2148+
selectedProvider: provider,
2149+
selectedModelId: modelId || bedrockDefaultModelId,
2150+
selectedModelInfo: (baseModelId && bedrockModels[baseModelId]) || bedrockModels[bedrockDefaultModelId],
2151+
}
2152+
}
20512153
return getProviderData(bedrockModels, bedrockDefaultModelId)
20522154
case "vertex":
20532155
return getProviderData(vertexModels, vertexDefaultModelId)

0 commit comments

Comments
 (0)