-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat: enable prompt caching detection for AWS Bedrock custom ARNs #6666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| import { renderHook, act } from "@testing-library/react" | ||
| import { vi, describe, it, expect, beforeEach } from "vitest" | ||
| import { useBedrockModelCapabilities } from "../useBedrockModelCapabilities" | ||
| import { vscode } from "../../../../utils/vscode" | ||
|
|
||
| // Mock vscode | ||
| vi.mock("../../../../utils/vscode", () => ({ | ||
| vscode: { | ||
| postMessage: vi.fn(), | ||
| }, | ||
| })) | ||
|
|
||
| describe("useBedrockModelCapabilities", () => { | ||
| beforeEach(() => { | ||
| vi.clearAllMocks() | ||
| }) | ||
|
|
||
| it("should return undefined when no customArn is provided", () => { | ||
| const { result } = renderHook(() => useBedrockModelCapabilities()) | ||
| expect(result.current).toBeUndefined() | ||
| expect(vscode.postMessage).not.toHaveBeenCalled() | ||
| }) | ||
|
|
||
| it("should request capabilities when customArn is provided", () => { | ||
| const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model" | ||
| renderHook(() => useBedrockModelCapabilities(customArn)) | ||
|
|
||
| expect(vscode.postMessage).toHaveBeenCalledWith({ | ||
| type: "requestBedrockModelCapabilities", | ||
| values: { customArn }, | ||
| }) | ||
| }) | ||
|
|
||
| it("should update capabilities when receiving a successful response", () => { | ||
| const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model" | ||
| const { result } = renderHook(() => useBedrockModelCapabilities(customArn)) | ||
|
|
||
| const mockCapabilities = { | ||
| maxTokens: 8192, | ||
| contextWindow: 200000, | ||
| supportsPromptCache: true, | ||
| supportsImages: true, | ||
| } | ||
|
|
||
| act(() => { | ||
| const event = new MessageEvent("message", { | ||
| data: { | ||
| type: "bedrockModelCapabilities", | ||
| values: { | ||
| customArn, | ||
| modelInfo: mockCapabilities, | ||
| }, | ||
| }, | ||
| }) | ||
| window.dispatchEvent(event) | ||
| }) | ||
|
|
||
| expect(result.current).toEqual(mockCapabilities) | ||
| }) | ||
|
|
||
| it("should handle error responses gracefully", () => { | ||
| const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model" | ||
| const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {}) | ||
| const { result } = renderHook(() => useBedrockModelCapabilities(customArn)) | ||
|
|
||
| act(() => { | ||
| const event = new MessageEvent("message", { | ||
| data: { | ||
| type: "bedrockModelCapabilities", | ||
| values: { | ||
| customArn, | ||
| error: "Failed to parse ARN", | ||
| }, | ||
| }, | ||
| }) | ||
| window.dispatchEvent(event) | ||
| }) | ||
|
|
||
| expect(result.current).toBeUndefined() | ||
| expect(consoleSpy).toHaveBeenCalledWith("Error fetching Bedrock model capabilities:", "Failed to parse ARN") | ||
|
|
||
| consoleSpy.mockRestore() | ||
| }) | ||
|
|
||
| it("should ignore responses for different ARNs", () => { | ||
| const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model" | ||
| const { result } = renderHook(() => useBedrockModelCapabilities(customArn)) | ||
|
|
||
| act(() => { | ||
| const event = new MessageEvent("message", { | ||
| data: { | ||
| type: "bedrockModelCapabilities", | ||
| values: { | ||
| customArn: "different-arn", | ||
| modelInfo: { maxTokens: 1000 }, | ||
| }, | ||
| }, | ||
| }) | ||
| window.dispatchEvent(event) | ||
| }) | ||
|
|
||
| expect(result.current).toBeUndefined() | ||
| }) | ||
|
|
||
| it("should clean up event listener on unmount", () => { | ||
| const customArn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/test-model" | ||
| const removeEventListenerSpy = vi.spyOn(window, "removeEventListener") | ||
| const { unmount } = renderHook(() => useBedrockModelCapabilities(customArn)) | ||
|
|
||
| unmount() | ||
|
|
||
| expect(removeEventListenerSpy).toHaveBeenCalledWith("message", expect.any(Function)) | ||
| removeEventListenerSpy.mockRestore() | ||
| }) | ||
|
|
||
| it("should request new capabilities when customArn changes", () => { | ||
| const { rerender } = renderHook(({ arn }) => useBedrockModelCapabilities(arn), { | ||
| initialProps: { arn: "arn1" }, | ||
| }) | ||
|
|
||
| expect(vscode.postMessage).toHaveBeenCalledWith({ | ||
| type: "requestBedrockModelCapabilities", | ||
| values: { customArn: "arn1" }, | ||
| }) | ||
|
|
||
| rerender({ arn: "arn2" }) | ||
|
|
||
| expect(vscode.postMessage).toHaveBeenCalledWith({ | ||
| type: "requestBedrockModelCapabilities", | ||
| values: { customArn: "arn2" }, | ||
| }) | ||
|
|
||
| expect(vscode.postMessage).toHaveBeenCalledTimes(2) | ||
| }) | ||
| }) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import { useEffect, useState } from "react" | ||
| import { vscode } from "../../../utils/vscode" | ||
| import type { ModelInfo } from "@roo-code/types" | ||
|
|
||
| export function useBedrockModelCapabilities(customArn?: string): ModelInfo | undefined { | ||
| const [capabilities, setCapabilities] = useState<ModelInfo | undefined>(undefined) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider exposing a loading state from this hook so the UI can show appropriate feedback while capabilities are being fetched. This would improve the user experience during the async operation. |
||
|
|
||
| useEffect(() => { | ||
| if (!customArn) { | ||
| setCapabilities(undefined) | ||
| return | ||
| } | ||
|
|
||
| // Request capabilities from backend | ||
| vscode.postMessage({ | ||
| type: "requestBedrockModelCapabilities", | ||
| values: { customArn }, | ||
| }) | ||
|
|
||
| // Listen for response | ||
| const handler = (event: MessageEvent) => { | ||
| const message = event.data | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For better type safety, consider defining a specific type for the message event data instead of relying on implicit any: interface BedrockCapabilitiesMessage {
type: string;
values?: {
customArn: string;
modelInfo?: ModelInfo;
error?: string;
};
} |
||
| if (message.type === "bedrockModelCapabilities" && message.values?.customArn === customArn) { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The hook doesn't handle potential race conditions if the ARN changes rapidly. Consider adding a cleanup mechanism to ignore responses from outdated requests. You could use a ref to track the current request or implement request cancellation. |
||
| if (message.values.modelInfo) { | ||
| setCapabilities(message.values.modelInfo) | ||
| } else if (message.values.error) { | ||
| console.error("Error fetching Bedrock model capabilities:", message.values.error) | ||
| // Keep undefined to fall back to defaults | ||
| setCapabilities(undefined) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| window.addEventListener("message", handler) | ||
|
|
||
| return () => { | ||
| window.removeEventListener("message", handler) | ||
| } | ||
| }, [customArn]) | ||
|
|
||
| return capabilities | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider enhancing the error logging to include the ARN for better debugging: