Skip to content

Commit 626fd45

Browse files
committed
feat: enable prompt caching detection for AWS Bedrock custom ARNs
- Add new message types for requesting Bedrock model capabilities - Implement backend handler to parse ARN and return model capabilities - Create useBedrockModelCapabilities hook to fetch capabilities from backend - Update useSelectedModel to use dynamic capabilities instead of hardcoded values - Add comprehensive tests for the new functionality Fixes #6429
1 parent 4e8b174 commit 626fd45

File tree

6 files changed

+237
-0
lines changed

6 files changed

+237
-0
lines changed

src/core/webview/webviewMessageHandler.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,48 @@ export const webviewMessageHandler = async (
694694
})
695695
}
696696
break
697+
case "requestBedrockModelCapabilities":
698+
// Handle request for Bedrock model capabilities
699+
if (message.values?.customArn) {
700+
try {
701+
const { apiConfiguration } = await provider.getState()
702+
703+
// Only process if using bedrock provider
704+
if (apiConfiguration.apiProvider === "bedrock") {
705+
// Import the bedrock handler dynamically
706+
const { AwsBedrockHandler } = await import("../../api/providers/bedrock")
707+
708+
// Create a temporary handler instance to get model info
709+
const tempHandler = new AwsBedrockHandler({
710+
...apiConfiguration,
711+
awsCustomArn: message.values.customArn,
712+
})
713+
714+
// Get the model info which includes capabilities
715+
const modelInfo = tempHandler.getModel()
716+
717+
// Send the capabilities back to the webview
718+
await provider.postMessageToWebview({
719+
type: "bedrockModelCapabilities",
720+
values: {
721+
customArn: message.values.customArn,
722+
modelInfo: modelInfo.info,
723+
},
724+
})
725+
}
726+
} catch (error) {
727+
provider.log(`Error getting Bedrock model capabilities: ${error}`)
728+
// Send error response
729+
await provider.postMessageToWebview({
730+
type: "bedrockModelCapabilities",
731+
values: {
732+
customArn: message.values.customArn,
733+
error: error instanceof Error ? error.message : String(error),
734+
},
735+
})
736+
}
737+
}
738+
break
697739
case "openImage":
698740
openImage(message.text!, { values: message.values })
699741
break

src/shared/ExtensionMessage.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ export interface ExtensionMessage {
120120
| "showEditMessageDialog"
121121
| "commands"
122122
| "insertTextIntoTextarea"
123+
| "bedrockModelCapabilities"
123124
text?: string
124125
payload?: any // Add a generic payload for now, can refine later
125126
action?:

src/shared/WebviewMessage.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ export interface WebviewMessage {
6868
| "requestLmStudioModels"
6969
| "requestVsCodeLmModels"
7070
| "requestHuggingFaceModels"
71+
| "requestBedrockModelCapabilities"
7172
| "openImage"
7273
| "saveImage"
7374
| "openFile"
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import { renderHook, act } from "@testing-library/react"
2+
import { vi, describe, it, expect, beforeEach } from "vitest"
3+
import { useBedrockModelCapabilities } from "../useBedrockModelCapabilities"
4+
import { vscode } from "../../../../utils/vscode"
5+
6+
// Mock vscode
7+
vi.mock("../../../../utils/vscode", () => ({
8+
vscode: {
9+
postMessage: vi.fn(),
10+
},
11+
}))
12+
13+
describe("useBedrockModelCapabilities", () => {
14+
beforeEach(() => {
15+
vi.clearAllMocks()
16+
})
17+
18+
it("should return undefined when no customArn is provided", () => {
19+
const { result } = renderHook(() => useBedrockModelCapabilities())
20+
expect(result.current).toBeUndefined()
21+
expect(vscode.postMessage).not.toHaveBeenCalled()
22+
})
23+
24+
it("should request capabilities when customArn is provided", () => {
25+
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
26+
renderHook(() => useBedrockModelCapabilities(customArn))
27+
28+
expect(vscode.postMessage).toHaveBeenCalledWith({
29+
type: "requestBedrockModelCapabilities",
30+
values: { customArn },
31+
})
32+
})
33+
34+
it("should update capabilities when receiving a successful response", () => {
35+
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
36+
const { result } = renderHook(() => useBedrockModelCapabilities(customArn))
37+
38+
const mockCapabilities = {
39+
maxTokens: 8192,
40+
contextWindow: 200000,
41+
supportsPromptCache: true,
42+
supportsImages: true,
43+
}
44+
45+
act(() => {
46+
const event = new MessageEvent("message", {
47+
data: {
48+
type: "bedrockModelCapabilities",
49+
values: {
50+
customArn,
51+
modelInfo: mockCapabilities,
52+
},
53+
},
54+
})
55+
window.dispatchEvent(event)
56+
})
57+
58+
expect(result.current).toEqual(mockCapabilities)
59+
})
60+
61+
it("should handle error responses gracefully", () => {
62+
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
63+
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {})
64+
const { result } = renderHook(() => useBedrockModelCapabilities(customArn))
65+
66+
act(() => {
67+
const event = new MessageEvent("message", {
68+
data: {
69+
type: "bedrockModelCapabilities",
70+
values: {
71+
customArn,
72+
error: "Failed to parse ARN",
73+
},
74+
},
75+
})
76+
window.dispatchEvent(event)
77+
})
78+
79+
expect(result.current).toBeUndefined()
80+
expect(consoleSpy).toHaveBeenCalledWith("Error fetching Bedrock model capabilities:", "Failed to parse ARN")
81+
82+
consoleSpy.mockRestore()
83+
})
84+
85+
it("should ignore responses for different ARNs", () => {
86+
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
87+
const { result } = renderHook(() => useBedrockModelCapabilities(customArn))
88+
89+
act(() => {
90+
const event = new MessageEvent("message", {
91+
data: {
92+
type: "bedrockModelCapabilities",
93+
values: {
94+
customArn: "different-arn",
95+
modelInfo: { maxTokens: 1000 },
96+
},
97+
},
98+
})
99+
window.dispatchEvent(event)
100+
})
101+
102+
expect(result.current).toBeUndefined()
103+
})
104+
105+
it("should clean up event listener on unmount", () => {
106+
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
107+
const removeEventListenerSpy = vi.spyOn(window, "removeEventListener")
108+
const { unmount } = renderHook(() => useBedrockModelCapabilities(customArn))
109+
110+
unmount()
111+
112+
expect(removeEventListenerSpy).toHaveBeenCalledWith("message", expect.any(Function))
113+
removeEventListenerSpy.mockRestore()
114+
})
115+
116+
it("should request new capabilities when customArn changes", () => {
117+
const { rerender } = renderHook(({ arn }) => useBedrockModelCapabilities(arn), {
118+
initialProps: { arn: "arn1" },
119+
})
120+
121+
expect(vscode.postMessage).toHaveBeenCalledWith({
122+
type: "requestBedrockModelCapabilities",
123+
values: { customArn: "arn1" },
124+
})
125+
126+
rerender({ arn: "arn2" })
127+
128+
expect(vscode.postMessage).toHaveBeenCalledWith({
129+
type: "requestBedrockModelCapabilities",
130+
values: { customArn: "arn2" },
131+
})
132+
133+
expect(vscode.postMessage).toHaveBeenCalledTimes(2)
134+
})
135+
})
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import { useEffect, useState } from "react"
2+
import { vscode } from "../../../utils/vscode"
3+
import type { ModelInfo } from "@roo-code/types"
4+
5+
export function useBedrockModelCapabilities(customArn?: string): ModelInfo | undefined {
6+
const [capabilities, setCapabilities] = useState<ModelInfo | undefined>(undefined)
7+
8+
useEffect(() => {
9+
if (!customArn) {
10+
setCapabilities(undefined)
11+
return
12+
}
13+
14+
// Request capabilities from backend
15+
vscode.postMessage({
16+
type: "requestBedrockModelCapabilities",
17+
values: { customArn },
18+
})
19+
20+
// Listen for response
21+
const handler = (event: MessageEvent) => {
22+
const message = event.data
23+
if (message.type === "bedrockModelCapabilities" && message.values?.customArn === customArn) {
24+
if (message.values.modelInfo) {
25+
setCapabilities(message.values.modelInfo)
26+
} else if (message.values.error) {
27+
console.error("Error fetching Bedrock model capabilities:", message.values.error)
28+
// Keep undefined to fall back to defaults
29+
setCapabilities(undefined)
30+
}
31+
}
32+
}
33+
34+
window.addEventListener("message", handler)
35+
36+
return () => {
37+
window.removeEventListener("message", handler)
38+
}
39+
}, [customArn])
40+
41+
return capabilities
42+
}

webview-ui/src/components/ui/hooks/useSelectedModel.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import type { ModelRecord, RouterModels } from "@roo/api"
5151
import { useRouterModels } from "./useRouterModels"
5252
import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
5353
import { useLmStudioModels } from "./useLmStudioModels"
54+
import { useBedrockModelCapabilities } from "./useBedrockModelCapabilities"
5455

5556
export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
5657
const provider = apiConfiguration?.apiProvider || "anthropic"
@@ -61,6 +62,12 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
6162
const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
6263
const lmStudioModels = useLmStudioModels(lmStudioModelId)
6364

65+
// Always call the hook, but only use it when needed
66+
const isBedrockCustomArn = provider === "bedrock" && apiConfiguration?.apiModelId === "custom-arn"
67+
const bedrockCapabilities = useBedrockModelCapabilities(
68+
isBedrockCustomArn ? apiConfiguration?.awsCustomArn : undefined,
69+
)
70+
6471
const { id, info } =
6572
apiConfiguration &&
6673
(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
@@ -72,6 +79,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
7279
routerModels: routerModels.data,
7380
openRouterModelProviders: openRouterModelProviders.data,
7481
lmStudioModels: lmStudioModels.data,
82+
bedrockCapabilities,
7583
})
7684
: { id: anthropicDefaultModelId, info: undefined }
7785

@@ -96,12 +104,14 @@ function getSelectedModel({
96104
routerModels,
97105
openRouterModelProviders,
98106
lmStudioModels,
107+
bedrockCapabilities,
99108
}: {
100109
provider: ProviderName
101110
apiConfiguration: ProviderSettings
102111
routerModels: RouterModels
103112
openRouterModelProviders: Record<string, ModelInfo>
104113
lmStudioModels: ModelRecord | undefined
114+
bedrockCapabilities?: ModelInfo
105115
}): { id: string; info: ModelInfo | undefined } {
106116
// the `undefined` case are used to show the invalid selection to prevent
107117
// users from seeing the default model if their selection is invalid
@@ -174,6 +184,12 @@ function getSelectedModel({
174184

175185
// Special case for custom ARN.
176186
if (id === "custom-arn") {
187+
// If we have capabilities from backend, use them
188+
if (bedrockCapabilities) {
189+
return { id, info: bedrockCapabilities }
190+
}
191+
192+
// Otherwise fall back to defaults (this ensures UI doesn't break while loading)
177193
return {
178194
id,
179195
info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true },

0 commit comments

Comments
 (0)