Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,48 @@ export const webviewMessageHandler = async (
})
}
break
case "requestBedrockModelCapabilities":
// Handle request for Bedrock model capabilities
if (message.values?.customArn) {
try {
const { apiConfiguration } = await provider.getState()

// Only process if using bedrock provider
if (apiConfiguration.apiProvider === "bedrock") {
// Import the bedrock handler dynamically
const { AwsBedrockHandler } = await import("../../api/providers/bedrock")

// Create a temporary handler instance to get model info
const tempHandler = new AwsBedrockHandler({
...apiConfiguration,
awsCustomArn: message.values.customArn,
})

// Get the model info which includes capabilities
const modelInfo = tempHandler.getModel()

// Send the capabilities back to the webview
await provider.postMessageToWebview({
type: "bedrockModelCapabilities",
values: {
customArn: message.values.customArn,
modelInfo: modelInfo.info,
},
})
}
} catch (error) {
provider.log(`Error getting Bedrock model capabilities: ${error}`)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider enhancing the error logging to include the ARN for better debugging:

Suggested change
provider.log(`Error getting Bedrock model capabilities: ${error}`)
provider.log(`Error getting Bedrock model capabilities for ARN ${message.values.customArn}: ${error}`)

// Send error response
await provider.postMessageToWebview({
type: "bedrockModelCapabilities",
values: {
customArn: message.values.customArn,
error: error instanceof Error ? error.message : String(error),
},
})
}
}
break
case "openImage":
openImage(message.text!, { values: message.values })
break
Expand Down
1 change: 1 addition & 0 deletions src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ export interface ExtensionMessage {
| "showEditMessageDialog"
| "commands"
| "insertTextIntoTextarea"
| "bedrockModelCapabilities"
text?: string
payload?: any // Add a generic payload for now, can refine later
action?:
Expand Down
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export interface WebviewMessage {
| "requestLmStudioModels"
| "requestVsCodeLmModels"
| "requestHuggingFaceModels"
| "requestBedrockModelCapabilities"
| "openImage"
| "saveImage"
| "openFile"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import { renderHook, act } from "@testing-library/react"
import { vi, describe, it, expect, beforeEach } from "vitest"
import { useBedrockModelCapabilities } from "../useBedrockModelCapabilities"
import { vscode } from "../../../../utils/vscode"

// Mock vscode
vi.mock("../../../../utils/vscode", () => ({
vscode: {
postMessage: vi.fn(),
},
}))

describe("useBedrockModelCapabilities", () => {
beforeEach(() => {
vi.clearAllMocks()
})

it("should return undefined when no customArn is provided", () => {
const { result } = renderHook(() => useBedrockModelCapabilities())
expect(result.current).toBeUndefined()
expect(vscode.postMessage).not.toHaveBeenCalled()
})

it("should request capabilities when customArn is provided", () => {
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
renderHook(() => useBedrockModelCapabilities(customArn))

expect(vscode.postMessage).toHaveBeenCalledWith({
type: "requestBedrockModelCapabilities",
values: { customArn },
})
})

it("should update capabilities when receiving a successful response", () => {
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
const { result } = renderHook(() => useBedrockModelCapabilities(customArn))

const mockCapabilities = {
maxTokens: 8192,
contextWindow: 200000,
supportsPromptCache: true,
supportsImages: true,
}

act(() => {
const event = new MessageEvent("message", {
data: {
type: "bedrockModelCapabilities",
values: {
customArn,
modelInfo: mockCapabilities,
},
},
})
window.dispatchEvent(event)
})

expect(result.current).toEqual(mockCapabilities)
})

it("should handle error responses gracefully", () => {
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {})
const { result } = renderHook(() => useBedrockModelCapabilities(customArn))

act(() => {
const event = new MessageEvent("message", {
data: {
type: "bedrockModelCapabilities",
values: {
customArn,
error: "Failed to parse ARN",
},
},
})
window.dispatchEvent(event)
})

expect(result.current).toBeUndefined()
expect(consoleSpy).toHaveBeenCalledWith("Error fetching Bedrock model capabilities:", "Failed to parse ARN")

consoleSpy.mockRestore()
})

it("should ignore responses for different ARNs", () => {
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
const { result } = renderHook(() => useBedrockModelCapabilities(customArn))

act(() => {
const event = new MessageEvent("message", {
data: {
type: "bedrockModelCapabilities",
values: {
customArn: "different-arn",
modelInfo: { maxTokens: 1000 },
},
},
})
window.dispatchEvent(event)
})

expect(result.current).toBeUndefined()
})

it("should clean up event listener on unmount", () => {
const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model"
const removeEventListenerSpy = vi.spyOn(window, "removeEventListener")
const { unmount } = renderHook(() => useBedrockModelCapabilities(customArn))

unmount()

expect(removeEventListenerSpy).toHaveBeenCalledWith("message", expect.any(Function))
removeEventListenerSpy.mockRestore()
})

it("should request new capabilities when customArn changes", () => {
const { rerender } = renderHook(({ arn }) => useBedrockModelCapabilities(arn), {
initialProps: { arn: "arn1" },
})

expect(vscode.postMessage).toHaveBeenCalledWith({
type: "requestBedrockModelCapabilities",
values: { customArn: "arn1" },
})

rerender({ arn: "arn2" })

expect(vscode.postMessage).toHaveBeenCalledWith({
type: "requestBedrockModelCapabilities",
values: { customArn: "arn2" },
})

expect(vscode.postMessage).toHaveBeenCalledTimes(2)
})
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import { useEffect, useState } from "react"
import { vscode } from "../../../utils/vscode"
import type { ModelInfo } from "@roo-code/types"

export function useBedrockModelCapabilities(customArn?: string): ModelInfo | undefined {
const [capabilities, setCapabilities] = useState<ModelInfo | undefined>(undefined)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider exposing a loading state from this hook so the UI can show appropriate feedback while capabilities are being fetched. This would improve the user experience during the async operation.


useEffect(() => {
if (!customArn) {
setCapabilities(undefined)
return
}

// Request capabilities from backend
vscode.postMessage({
type: "requestBedrockModelCapabilities",
values: { customArn },
})

// Listen for response
const handler = (event: MessageEvent) => {
const message = event.data
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For better type safety, consider defining a specific type for the message event data instead of relying on implicit any:

interface BedrockCapabilitiesMessage {
  type: string;
  values?: {
    customArn: string;
    modelInfo?: ModelInfo;
    error?: string;
  };
}

if (message.type === "bedrockModelCapabilities" && message.values?.customArn === customArn) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hook doesn't handle potential race conditions if the ARN changes rapidly. Consider adding a cleanup mechanism to ignore responses from outdated requests. You could use a ref to track the current request or implement request cancellation.

if (message.values.modelInfo) {
setCapabilities(message.values.modelInfo)
} else if (message.values.error) {
console.error("Error fetching Bedrock model capabilities:", message.values.error)
// Keep undefined to fall back to defaults
setCapabilities(undefined)
}
}
}

window.addEventListener("message", handler)

return () => {
window.removeEventListener("message", handler)
}
}, [customArn])

return capabilities
}
16 changes: 16 additions & 0 deletions webview-ui/src/components/ui/hooks/useSelectedModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import type { ModelRecord, RouterModels } from "@roo/api"
import { useRouterModels } from "./useRouterModels"
import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
import { useLmStudioModels } from "./useLmStudioModels"
import { useBedrockModelCapabilities } from "./useBedrockModelCapabilities"

export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
const provider = apiConfiguration?.apiProvider || "anthropic"
Expand All @@ -61,6 +62,12 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
const lmStudioModels = useLmStudioModels(lmStudioModelId)

// Always call the hook, but only use it when needed
const isBedrockCustomArn = provider === "bedrock" && apiConfiguration?.apiModelId === "custom-arn"
const bedrockCapabilities = useBedrockModelCapabilities(
isBedrockCustomArn ? apiConfiguration?.awsCustomArn : undefined,
)

const { id, info } =
apiConfiguration &&
(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
Expand All @@ -72,6 +79,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
routerModels: routerModels.data,
openRouterModelProviders: openRouterModelProviders.data,
lmStudioModels: lmStudioModels.data,
bedrockCapabilities,
})
: { id: anthropicDefaultModelId, info: undefined }

Expand All @@ -96,12 +104,14 @@ function getSelectedModel({
routerModels,
openRouterModelProviders,
lmStudioModels,
bedrockCapabilities,
}: {
provider: ProviderName
apiConfiguration: ProviderSettings
routerModels: RouterModels
openRouterModelProviders: Record<string, ModelInfo>
lmStudioModels: ModelRecord | undefined
bedrockCapabilities?: ModelInfo
}): { id: string; info: ModelInfo | undefined } {
// the `undefined` case are used to show the invalid selection to prevent
// users from seeing the default model if their selection is invalid
Expand Down Expand Up @@ -174,6 +184,12 @@ function getSelectedModel({

// Special case for custom ARN.
if (id === "custom-arn") {
// If we have capabilities from backend, use them
if (bedrockCapabilities) {
return { id, info: bedrockCapabilities }
}

// Otherwise fall back to defaults (this ensures UI doesn't break while loading)
return {
id,
info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true },
Expand Down
Loading