diff --git a/src/api/providers/__tests__/bedrock-global-inference.spec.ts b/src/api/providers/__tests__/bedrock-global-inference.spec.ts new file mode 100644 index 00000000000..991ff62d53a --- /dev/null +++ b/src/api/providers/__tests__/bedrock-global-inference.spec.ts @@ -0,0 +1,229 @@ +// npx vitest run src/api/providers/__tests__/bedrock-global-inference.spec.ts + +import { AwsBedrockHandler } from "../bedrock" +import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime" +import { logger } from "../../../utils/logging" +import type { ProviderSettings } from "@roo-code/types" + +// Mock AWS SDK modules +vitest.mock("@aws-sdk/client-bedrock-runtime", () => { + const mockSend = vi.fn().mockResolvedValue({ + stream: (async function* () { + yield { + contentBlockStart: { + start: { text: "Test response" }, + }, + } + yield { + contentBlockDelta: { + delta: { text: " from Claude" }, + }, + } + yield { + messageStop: {}, + } + })(), + }) + + return { + BedrockRuntimeClient: vi.fn().mockImplementation(() => ({ + send: mockSend, + })), + ConverseStreamCommand: vi.fn(), + ConverseCommand: vi.fn(), + } +}) + +vitest.mock("../../../utils/logging") + +describe("AwsBedrockHandler - Global Inference Profile Support", () => { + let handler: AwsBedrockHandler + let mockSend: any + + beforeEach(() => { + vi.clearAllMocks() + mockSend = vi.fn().mockResolvedValue({ + stream: (async function* () { + yield { + contentBlockStart: { + start: { text: "Test response" }, + }, + } + yield { + contentBlockDelta: { + delta: { text: " from Claude" }, + }, + } + yield { + messageStop: {}, + } + })(), + }) + ;(BedrockRuntimeClient as any).mockImplementation(() => ({ + send: mockSend, + })) + }) + + describe("Global Inference Profile ARN Support", () => { + it("should detect Claude Sonnet 4.5 global inference profile ARN", () => { + const options: ProviderSettings = { + apiProvider: "bedrock", + awsRegion: "us-east-1", + awsCustomArn: + "arn:aws:bedrock:us-east-1:148761681080:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", + awsAccessKey: "test-key", + awsSecretKey: "test-secret", + } + + handler = new AwsBedrockHandler(options) + const model = handler.getModel() + + // Should recognize the ARN and provide appropriate model info + expect(model.id).toBe( + "arn:aws:bedrock:us-east-1:148761681080:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", + ) + expect(model.info).toBeDefined() + expect(model.info.supportsReasoningBudget).toBe(true) + expect(model.info.supportsPromptCache).toBe(true) + expect(model.info.supportsImages).toBe(true) + }) + + it("should enable 1M context for global inference profile when awsBedrock1MContext is true", async () => { + const options: ProviderSettings = { + apiProvider: "bedrock", + awsRegion: "us-east-1", + awsCustomArn: + "arn:aws:bedrock:us-east-1:148761681080:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", + awsBedrock1MContext: true, + awsAccessKey: "test-key", + awsSecretKey: "test-secret", + } + + handler = new AwsBedrockHandler(options) + + const messages = [{ role: "user" as const, content: "Test message" }] + const stream = handler.createMessage("System prompt", messages) + + // Consume the stream + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Check that the command was called + expect(mockSend).toHaveBeenCalled() + expect(ConverseStreamCommand).toHaveBeenCalled() + + // Get the payload from the ConverseStreamCommand constructor + const commandPayload = (ConverseStreamCommand as any).mock.calls[0][0] + expect(commandPayload).toBeDefined() + expect(commandPayload.additionalModelRequestFields).toBeDefined() + expect(commandPayload.additionalModelRequestFields.anthropic_beta).toContain("context-1m-2025-08-07") + }) + + it("should enable thinking/reasoning for global inference profile", async () => { + const options: ProviderSettings = { + apiProvider: "bedrock", + awsRegion: "us-east-1", + awsCustomArn: + "arn:aws:bedrock:us-east-1:148761681080:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", + enableReasoningEffort: true, + awsAccessKey: "test-key", + awsSecretKey: "test-secret", + } + + handler = new AwsBedrockHandler(options) + + const messages = [{ role: "user" as const, content: "Test message" }] + const metadata = { + taskId: "test-task-id", + thinking: { + enabled: true, + maxThinkingTokens: 8192, + }, + } + + const stream = handler.createMessage("System prompt", messages, metadata) + + // Consume the stream + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Check that thinking was enabled + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining("Extended thinking enabled"), + expect.objectContaining({ + ctx: "bedrock", + thinking: expect.objectContaining({ + type: "enabled", + budget_tokens: 8192, + }), + }), + ) + }) + + it("should handle various Claude 4.5 ARN patterns", () => { + const testCases = [ + "arn:aws:bedrock:us-east-1:148761681080:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "arn:aws:bedrock:eu-west-1:123456789012:inference-profile/anthropic.claude-sonnet-4-5-20250929-v1:0", + "arn:aws:bedrock:ap-southeast-1:987654321098:foundation-model/anthropic.claude-sonnet-4.5-v1:0", + ] + + testCases.forEach((arn) => { + const options: ProviderSettings = { + apiProvider: "bedrock", + awsRegion: "us-east-1", + awsCustomArn: arn, + awsAccessKey: "test-key", + awsSecretKey: "test-secret", + } + + handler = new AwsBedrockHandler(options) + const model = handler.getModel() + + expect(model.info.supportsReasoningBudget).toBe(true) + expect(model.info.supportsPromptCache).toBe(true) + }) + }) + + it("should not enable thinking for non-Claude-4.5 custom ARNs", () => { + const options: ProviderSettings = { + apiProvider: "bedrock", + awsRegion: "us-east-1", + awsCustomArn: + "arn:aws:bedrock:us-east-1:123456789012:foundation-model/anthropic.claude-3-haiku-20240307-v1:0", + awsAccessKey: "test-key", + awsSecretKey: "test-secret", + } + + handler = new AwsBedrockHandler(options) + const model = handler.getModel() + + // Should not have reasoning budget support for non-Claude-4.5 models + expect(model.info.supportsReasoningBudget).toBeFalsy() + }) + }) + + describe("ARN Parsing with Global Inference Profile", () => { + it("should correctly parse global inference profile ARN", () => { + const handler = new AwsBedrockHandler({ + apiProvider: "bedrock", + awsRegion: "us-east-1", + awsAccessKey: "test-key", + awsSecretKey: "test-secret", + }) + + const parseArn = (handler as any).parseArn.bind(handler) + const result = parseArn( + "arn:aws:bedrock:us-east-1:148761681080:inference-profile/global.anthropic.claude-sonnet-4-5-20250929-v1:0", + ) + + expect(result.isValid).toBe(true) + expect(result.region).toBe("us-east-1") + expect(result.modelType).toBe("inference-profile") + expect(result.modelId).toContain("anthropic.claude-sonnet-4-5-20250929") + }) + }) +}) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 493c02483f1..4388d2179fc 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -252,9 +252,34 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } // Helper to guess model info from custom modelId string if not in bedrockModels - private guessModelInfoFromId(modelId: string): Partial { + private guessModelInfoFromId(modelId: string | undefined): Partial { + // Handle undefined or empty modelId + if (!modelId) { + return { + maxTokens: BEDROCK_MAX_TOKENS, + contextWindow: BEDROCK_DEFAULT_CONTEXT, + supportsImages: false, + supportsPromptCache: false, + } + } // Define a mapping for model ID patterns and their configurations const modelConfigMap: Record> = { + // Claude 4.5 Sonnet models (including global inference profile) + "claude-sonnet-4-5": { + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsPromptCache: true, + supportsReasoningBudget: true, + }, + // Claude 4 Sonnet models + "claude-sonnet-4": { + maxTokens: 8192, + contextWindow: 200_000, + supportsImages: true, + supportsPromptCache: true, + supportsReasoningBudget: true, + }, "claude-4": { maxTokens: 8192, contextWindow: 200_000, @@ -266,6 +291,7 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH contextWindow: 200_000, supportsImages: true, supportsPromptCache: true, + supportsReasoningBudget: true, }, "claude-3-5": { maxTokens: 8192, @@ -376,8 +402,10 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH // Check if 1M context is enabled for Claude Sonnet 4 // Use parseBaseModelId to handle cross-region inference prefixes const baseModelId = this.parseBaseModelId(modelConfig.id) - const is1MContextEnabled = - BEDROCK_1M_CONTEXT_MODEL_IDS.includes(baseModelId as any) && this.options.awsBedrock1MContext + // Check if it's a known model ID or if it's a custom ARN that matches Claude 4.5 pattern + const isEligibleFor1MContext = + BEDROCK_1M_CONTEXT_MODEL_IDS.includes(baseModelId as any) || this.isClaudeSonnet45Model(modelConfig.id) + const is1MContextEnabled = isEligibleFor1MContext && this.options.awsBedrock1MContext // Add anthropic_beta for 1M context to additionalModelRequestFields if (is1MContextEnabled) { @@ -889,6 +917,18 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH return modelId } + // Helper method to check if a model ID represents a Claude Sonnet 4.5 model + private isClaudeSonnet45Model(modelId: string): boolean { + const id = modelId.toLowerCase() + // Check for various Claude 4.5 patterns including global inference profile + return ( + id.includes("claude-sonnet-4-5") || + id.includes("claude-sonnet-4.5") || + // Specific check for the global inference profile ARN mentioned in the issue + id.includes("global.anthropic.claude-sonnet-4-5-20250929") + ) + } + //Prompt Router responses come back in a different sequence and the model used is in the response and must be fetched by name getModelById(modelId: string, modelType?: string): { id: BedrockModelId | string; info: ModelInfo } { // Try to find the model in bedrockModels diff --git a/webview-ui/src/components/settings/providers/Bedrock.tsx b/webview-ui/src/components/settings/providers/Bedrock.tsx index 1b3143fa083..3efd9f62261 100644 --- a/webview-ui/src/components/settings/providers/Bedrock.tsx +++ b/webview-ui/src/components/settings/providers/Bedrock.tsx @@ -19,9 +19,25 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo const { t } = useAppTranslation() const [awsEndpointSelected, setAwsEndpointSelected] = useState(!!apiConfiguration?.awsBedrockEndpointEnabled) + // Helper function to check if a model ID or ARN represents a Claude Sonnet 4.5 model + const isClaudeSonnet45Model = (modelId: string): boolean => { + if (!modelId) return false + const id = modelId.toLowerCase() + return ( + id.includes("claude-sonnet-4-5") || + id.includes("claude-sonnet-4.5") || + // Specific check for the global inference profile ARN mentioned in the issue + id.includes("global.anthropic.claude-sonnet-4-5-20250929") + ) + } + // Check if the selected model supports 1M context (Claude Sonnet 4 / 4.5) + // This includes both known model IDs and custom ARNs that match Claude 4.5 patterns const supports1MContextBeta = - !!apiConfiguration?.apiModelId && BEDROCK_1M_CONTEXT_MODEL_IDS.includes(apiConfiguration.apiModelId as any) + (!!apiConfiguration?.apiModelId && BEDROCK_1M_CONTEXT_MODEL_IDS.includes(apiConfiguration.apiModelId as any)) || + (apiConfiguration?.apiModelId === "custom-arn" && + apiConfiguration?.awsCustomArn && + isClaudeSonnet45Model(apiConfiguration.awsCustomArn)) // Update the endpoint enabled state when the configuration changes useEffect(() => { diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 0d0514b4d66..37fd78c564c 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -192,12 +192,44 @@ function getSelectedModel({ const id = apiConfiguration.apiModelId ?? bedrockDefaultModelId const baseInfo = bedrockModels[id as keyof typeof bedrockModels] + // Helper function to check if a model ID or ARN represents a Claude Sonnet 4.5 model + const isClaudeSonnet45Model = (modelId: string): boolean => { + if (!modelId) return false + const lowerId = modelId.toLowerCase() + return ( + lowerId.includes("claude-sonnet-4-5") || + lowerId.includes("claude-sonnet-4.5") || + // Specific check for the global inference profile ARN + lowerId.includes("global.anthropic.claude-sonnet-4-5-20250929") + ) + } + // Special case for custom ARN. if (id === "custom-arn") { - return { - id, - info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true }, + const customArn = apiConfiguration.awsCustomArn || "" + const isClaudeSonnet45 = isClaudeSonnet45Model(customArn) + + // Base info for custom ARNs + let info: ModelInfo = { + maxTokens: 5000, + contextWindow: 128_000, + supportsPromptCache: false, + supportsImages: true, } + + // If it's a Claude Sonnet 4.5 model, add thinking support and better defaults + if (isClaudeSonnet45) { + info = { + maxTokens: 8192, + contextWindow: apiConfiguration.awsBedrock1MContext ? 1_000_000 : 200_000, + supportsImages: true, + supportsPromptCache: true, + supportsReasoningBudget: true, + supportsComputerUse: true, + } + } + + return { id, info } } // Apply 1M context for Claude Sonnet 4 / 4.5 when enabled