From 626fd45079273ac3a58f0266b573d3f3286d6018 Mon Sep 17 00:00:00 2001
From: Roo Code
Date: Mon, 4 Aug 2025 16:42:32 +0000
Subject: [PATCH] feat: enable prompt caching detection for AWS Bedrock custom ARNs

- Add new message types for requesting Bedrock model capabilities
- Implement backend handler to parse ARN and return model capabilities
- Create useBedrockModelCapabilities hook to fetch capabilities from backend
- Update useSelectedModel to use dynamic capabilities instead of hardcoded values
- Add comprehensive tests for the new functionality

Fixes #6429
---
 src/core/webview/webviewMessageHandler.ts     |  42 +++++
 src/shared/ExtensionMessage.ts                |   1 +
 src/shared/WebviewMessage.ts                  |   1 +
 .../useBedrockModelCapabilities.spec.ts       | 161 ++++++++++++++++++
 .../ui/hooks/useBedrockModelCapabilities.ts   |  57 +++++++
 .../components/ui/hooks/useSelectedModel.ts   |  16 ++
 6 files changed, 278 insertions(+)
 create mode 100644 webview-ui/src/components/ui/hooks/__tests__/useBedrockModelCapabilities.spec.ts
 create mode 100644 webview-ui/src/components/ui/hooks/useBedrockModelCapabilities.ts

diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index fdb7e904257..8c830bb55a6 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -694,6 +694,48 @@ export const webviewMessageHandler = async (
 				})
 			}
 			break
+		case "requestBedrockModelCapabilities":
+			// Handle request for Bedrock model capabilities
+			if (message.values?.customArn) {
+				try {
+					const { apiConfiguration } = await provider.getState()
+
+					// Only process if using bedrock provider
+					if (apiConfiguration.apiProvider === "bedrock") {
+						// Import the bedrock handler dynamically
+						const { AwsBedrockHandler } = await import("../../api/providers/bedrock")
+
+						// Create a temporary handler instance to get model info
+						const tempHandler = new AwsBedrockHandler({
+							...apiConfiguration,
+							awsCustomArn: message.values.customArn,
+						})
+
+						// Get the model info which includes capabilities
+						const modelInfo = tempHandler.getModel()
+
+						// Send the capabilities back to the webview
+						await provider.postMessageToWebview({
+							type: "bedrockModelCapabilities",
+							values: {
+								customArn: message.values.customArn,
+								modelInfo: modelInfo.info,
+							},
+						})
+					}
+				} catch (error) {
+					provider.log(`Error getting Bedrock model capabilities: ${error}`)
+					// Send error response
+					await provider.postMessageToWebview({
+						type: "bedrockModelCapabilities",
+						values: {
+							customArn: message.values.customArn,
+							error: error instanceof Error ? error.message : String(error),
+						},
+					})
+				}
+			}
+			break
 		case "openImage":
 			openImage(message.text!, { values: message.values })
 			break
diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts
index 930edeac732..5ae32ac17d2 100644
--- a/src/shared/ExtensionMessage.ts
+++ b/src/shared/ExtensionMessage.ts
@@ -120,6 +120,7 @@ export interface ExtensionMessage {
 		| "showEditMessageDialog"
 		| "commands"
 		| "insertTextIntoTextarea"
+		| "bedrockModelCapabilities"
 	text?: string
 	payload?: any // Add a generic payload for now, can refine later
 	action?:
diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts
index cb8759d851c..b38ba6c7e22 100644
--- a/src/shared/WebviewMessage.ts
+++ b/src/shared/WebviewMessage.ts
@@ -68,6 +68,7 @@ export interface WebviewMessage {
 		| "requestLmStudioModels"
 		| "requestVsCodeLmModels"
 		| "requestHuggingFaceModels"
+		| "requestBedrockModelCapabilities"
 		| "openImage"
 		| "saveImage"
 		| "openFile"
diff --git a/webview-ui/src/components/ui/hooks/__tests__/useBedrockModelCapabilities.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useBedrockModelCapabilities.spec.ts
new file mode 100644
index 00000000000..7bcb8ebe19e
--- /dev/null
+++ b/webview-ui/src/components/ui/hooks/__tests__/useBedrockModelCapabilities.spec.ts
@@ -0,0 +1,161 @@
+import { renderHook, act } from "@testing-library/react"
+import { vi, describe, it, expect, beforeEach } from "vitest"
+import { useBedrockModelCapabilities } from "../useBedrockModelCapabilities"
+import { vscode } from "../../../../utils/vscode"
+
+// Mock vscode
+vi.mock("../../../../utils/vscode", () => ({
+	vscode: {
+		postMessage: vi.fn(),
+	},
+}))
+
+describe("useBedrockModelCapabilities", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+	})
+
+	it("should return undefined when no customArn is provided", () => {
+		const { result } = renderHook(() => useBedrockModelCapabilities())
+		expect(result.current).toBeUndefined()
+		expect(vscode.postMessage).not.toHaveBeenCalled()
+	})
+
+	it("should request capabilities when customArn is provided", () => {
+		const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
+		renderHook(() => useBedrockModelCapabilities(customArn))
+
+		expect(vscode.postMessage).toHaveBeenCalledWith({
+			type: "requestBedrockModelCapabilities",
+			values: { customArn },
+		})
+	})
+
+	it("should update capabilities when receiving a successful response", () => {
+		const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
+		const { result } = renderHook(() => useBedrockModelCapabilities(customArn))
+
+		const mockCapabilities = {
+			maxTokens: 8192,
+			contextWindow: 200000,
+			supportsPromptCache: true,
+			supportsImages: true,
+		}
+
+		act(() => {
+			const event = new MessageEvent("message", {
+				data: {
+					type: "bedrockModelCapabilities",
+					values: {
+						customArn,
+						modelInfo: mockCapabilities,
+					},
+				},
+			})
+			window.dispatchEvent(event)
+		})
+
+		expect(result.current).toEqual(mockCapabilities)
+	})
+
+	it("should handle error responses gracefully", () => {
+		const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
+		const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {})
+		const { result } = renderHook(() => useBedrockModelCapabilities(customArn))
+
+		act(() => {
+			const event = new MessageEvent("message", {
+				data: {
+					type: "bedrockModelCapabilities",
+					values: {
+						customArn,
+						error: "Failed to parse ARN",
+					},
+				},
+			})
+			window.dispatchEvent(event)
+		})
+
+		expect(result.current).toBeUndefined()
+		expect(consoleSpy).toHaveBeenCalledWith("Error fetching Bedrock model capabilities:", "Failed to parse ARN")
+
+		consoleSpy.mockRestore()
+	})
+
+	it("should ignore responses for different ARNs", () => {
+		const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
+		const { result } = renderHook(() => useBedrockModelCapabilities(customArn))
+
+		act(() => {
+			const event = new MessageEvent("message", {
+				data: {
+					type: "bedrockModelCapabilities",
+					values: {
+						customArn: "different-arn",
+						modelInfo: { maxTokens: 1000 },
+					},
+				},
+			})
+			window.dispatchEvent(event)
+		})
+
+		expect(result.current).toBeUndefined()
+	})
+
+	it("should clean up event listener on unmount", () => {
+		const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
+		const removeEventListenerSpy = vi.spyOn(window, "removeEventListener")
+		const { unmount } = renderHook(() => useBedrockModelCapabilities(customArn))
+
+		unmount()
+
+		expect(removeEventListenerSpy).toHaveBeenCalledWith("message", expect.any(Function))
+		removeEventListenerSpy.mockRestore()
+	})
+
+	it("should request new capabilities when customArn changes", () => {
+		const { rerender } = renderHook(({ arn }) => useBedrockModelCapabilities(arn), {
+			initialProps: { arn: "arn1" },
+		})
+
+		expect(vscode.postMessage).toHaveBeenCalledWith({
+			type: "requestBedrockModelCapabilities",
+			values: { customArn: "arn1" },
+		})
+
+		rerender({ arn: "arn2" })
+
+		expect(vscode.postMessage).toHaveBeenCalledWith({
+			type: "requestBedrockModelCapabilities",
+			values: { customArn: "arn2" },
+		})
+
+		expect(vscode.postMessage).toHaveBeenCalledTimes(2)
+	})
+
+	it("should clear stale capabilities when customArn changes", () => {
+		const { result, rerender } = renderHook(({ arn }) => useBedrockModelCapabilities(arn), {
+			initialProps: { arn: "arn1" },
+		})
+
+		act(() => {
+			const event = new MessageEvent("message", {
+				data: {
+					type: "bedrockModelCapabilities",
+					values: {
+						customArn: "arn1",
+						modelInfo: { maxTokens: 1000 },
+					},
+				},
+			})
+			window.dispatchEvent(event)
+		})
+
+		expect(result.current).toEqual({ maxTokens: 1000 })
+
+		rerender({ arn: "arn2" })
+
+		// Capabilities resolved for arn1 must not be reported for arn2
+		expect(result.current).toBeUndefined()
+	})
+})
diff --git a/webview-ui/src/components/ui/hooks/useBedrockModelCapabilities.ts b/webview-ui/src/components/ui/hooks/useBedrockModelCapabilities.ts
new file mode 100644
index 00000000000..d9f33abeeea
--- /dev/null
+++ b/webview-ui/src/components/ui/hooks/useBedrockModelCapabilities.ts
@@ -0,0 +1,57 @@
+import { useEffect, useState } from "react"
+import { vscode } from "../../../utils/vscode"
+import type { ModelInfo } from "@roo-code/types"
+
+/**
+ * Asks the extension host for the capabilities of a Bedrock custom ARN.
+ *
+ * A `requestBedrockModelCapabilities` message is posted whenever `customArn` changes, and the
+ * matching `bedrockModelCapabilities` reply is stored as the hook's result.
+ *
+ * @param customArn - ARN to look up. No request is sent when it is omitted or empty.
+ * @returns The `modelInfo` received for the current `customArn`, or `undefined` until a
+ * successful reply for it has arrived.
+ */
+export function useBedrockModelCapabilities(customArn?: string): ModelInfo | undefined {
+	const [capabilities, setCapabilities] = useState<ModelInfo | undefined>(undefined)
+
+	useEffect(() => {
+		// Forget whatever was resolved for the previous ARN so it is never reported for the
+		// new one, even if the backend answers with an error or not at all.
+		setCapabilities(undefined)
+
+		if (!customArn) {
+			return
+		}
+
+		const handler = (event: MessageEvent) => {
+			const message = event.data
+			if (message?.type !== "bedrockModelCapabilities" || message.values?.customArn !== customArn) {
+				return
+			}
+
+			if (message.values.modelInfo) {
+				setCapabilities(message.values.modelInfo)
+			} else if (message.values.error) {
+				console.error("Error fetching Bedrock model capabilities:", message.values.error)
+				// Keep undefined to fall back to defaults
+				setCapabilities(undefined)
+			}
+		}
+
+		// Subscribe before sending the request so the reply cannot be missed.
+		window.addEventListener("message", handler)
+
+		// Request capabilities from backend
+		vscode.postMessage({
+			type: "requestBedrockModelCapabilities",
+			values: { customArn },
+		})
+
+		return () => {
+			window.removeEventListener("message", handler)
+		}
+	}, [customArn])
+
+	return capabilities
+}
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index a191014981e..62e485e5e0f 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -51,6 +51,7 @@ import type { ModelRecord, RouterModels } from "@roo/api"
 import { useRouterModels } from "./useRouterModels"
 import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
 import { useLmStudioModels } from "./useLmStudioModels"
+import { useBedrockModelCapabilities } from "./useBedrockModelCapabilities"
 
 export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const provider = apiConfiguration?.apiProvider || "anthropic"
@@ -61,6 +62,12 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 	const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
 	const lmStudioModels = useLmStudioModels(lmStudioModelId)
 
+	// Always call the hook, but only use it when needed
+	const isBedrockCustomArn = provider === "bedrock" && apiConfiguration?.apiModelId === "custom-arn"
+	const bedrockCapabilities = useBedrockModelCapabilities(
+		isBedrockCustomArn ? apiConfiguration?.awsCustomArn : undefined,
+	)
+
 	const { id, info } =
 		apiConfiguration &&
 		(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
@@ -72,6 +79,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
 					routerModels: routerModels.data,
 					openRouterModelProviders: openRouterModelProviders.data,
 					lmStudioModels: lmStudioModels.data,
+					bedrockCapabilities,
 				})
 			: { id: anthropicDefaultModelId, info: undefined }
 
@@ -96,12 +104,14 @@ function getSelectedModel({
 	routerModels,
 	openRouterModelProviders,
 	lmStudioModels,
+	bedrockCapabilities,
 }: {
 	provider: ProviderName
 	apiConfiguration: ProviderSettings
 	routerModels: RouterModels
 	openRouterModelProviders: Record<string, ModelInfo>
 	lmStudioModels: ModelRecord | undefined
+	bedrockCapabilities?: ModelInfo
 }): { id: string; info: ModelInfo | undefined } {
 	// the `undefined` case are used to show the invalid selection to prevent
 	// users from seeing the default model if their selection is invalid
@@ -174,6 +184,12 @@ function getSelectedModel({
 
 			// Special case for custom ARN.
 			if (id === "custom-arn") {
+				// If we have capabilities from backend, use them
+				if (bedrockCapabilities) {
+					return { id, info: bedrockCapabilities }
+				}
+
+				// Otherwise fall back to defaults (this ensures UI doesn't break while loading)
 				return {
 					id,
 					info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true },