diff --git a/packages/types/src/providers/bedrock.ts b/packages/types/src/providers/bedrock.ts index ce5ea28e95..a15f041252 100644 --- a/packages/types/src/providers/bedrock.ts +++ b/packages/types/src/providers/bedrock.ts @@ -73,6 +73,7 @@ export const bedrockModels = { supportsImages: true, supportsComputerUse: true, supportsPromptCache: true, + supportsReasoningBudget: true, inputPrice: 3.0, outputPrice: 15.0, cacheWritesPrice: 3.75, @@ -87,6 +88,7 @@ export const bedrockModels = { supportsImages: true, supportsComputerUse: true, supportsPromptCache: true, + supportsReasoningBudget: true, inputPrice: 15.0, outputPrice: 75.0, cacheWritesPrice: 18.75, @@ -101,6 +103,7 @@ export const bedrockModels = { supportsImages: true, supportsComputerUse: true, supportsPromptCache: true, + supportsReasoningBudget: true, inputPrice: 3.0, outputPrice: 15.0, cacheWritesPrice: 3.75, diff --git a/src/api/providers/__tests__/bedrock-reasoning.test.ts b/src/api/providers/__tests__/bedrock-reasoning.test.ts new file mode 100644 index 0000000000..4a45c25701 --- /dev/null +++ b/src/api/providers/__tests__/bedrock-reasoning.test.ts @@ -0,0 +1,280 @@ +import { AwsBedrockHandler } from "../bedrock" +import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime" +import { logger } from "../../../utils/logging" + +// Mock the AWS SDK +jest.mock("@aws-sdk/client-bedrock-runtime") +jest.mock("../../../utils/logging") + +// Store the command payload for verification +let capturedPayload: any = null + +describe("AwsBedrockHandler - Extended Thinking", () => { + let handler: AwsBedrockHandler + let mockSend: jest.Mock + + beforeEach(() => { + capturedPayload = null + mockSend = jest.fn() + + // Mock ConverseStreamCommand to capture the payload + ;(ConverseStreamCommand as unknown as jest.Mock).mockImplementation((payload) => { + capturedPayload = payload + return { + input: payload, + } + }) + ;(BedrockRuntimeClient as jest.Mock).mockImplementation(() => ({ + send: mockSend, + config: { region: "us-east-1" }, + })) + ;(logger.info as jest.Mock).mockImplementation(() => {}) + ;(logger.error as jest.Mock).mockImplementation(() => {}) + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + describe("Extended Thinking Support", () => { + it("should include thinking parameter for Claude Sonnet 4 when reasoning is enabled", async () => { + handler = new AwsBedrockHandler({ + apiProvider: "bedrock", + apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0", + awsRegion: "us-east-1", + enableReasoningEffort: true, + modelMaxTokens: 8192, + modelMaxThinkingTokens: 4096, + }) + + // Mock the stream response + mockSend.mockResolvedValue({ + stream: (async function* () { + yield { + messageStart: { role: "assistant" }, + } + yield { + contentBlockStart: { + content_block: { type: "thinking", thinking: "Let me think..." }, + contentBlockIndex: 0, + }, + } + yield { + contentBlockDelta: { + delta: { type: "thinking_delta", thinking: " about this problem." }, + }, + } + yield { + contentBlockStart: { + start: { text: "Here's the answer:" }, + contentBlockIndex: 1, + }, + } + yield { + metadata: { + usage: { inputTokens: 100, outputTokens: 50 }, + }, + } + })(), + }) + + const messages = [{ role: "user" as const, content: "Test message" }] + const stream = handler.createMessage("System prompt", messages) + + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Verify the command was called with the correct payload + expect(mockSend).toHaveBeenCalledTimes(1) + expect(capturedPayload).toBeDefined() + expect(capturedPayload.additionalModelRequestFields).toBeDefined() + expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({ + type: "enabled", + budget_tokens: 4096, // Uses the full modelMaxThinkingTokens value + }) + + // Verify reasoning chunks were yielded + const reasoningChunks = chunks.filter((c) => c.type === "reasoning") + expect(reasoningChunks).toHaveLength(2) + expect(reasoningChunks[0].text).toBe("Let me think...") + expect(reasoningChunks[1].text).toBe(" about this problem.") + + // Verify that topP is NOT present when thinking is enabled + expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") + }) + + it("should pass thinking parameters from metadata", async () => { + handler = new AwsBedrockHandler({ + apiProvider: "bedrock", + apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", + awsRegion: "us-east-1", + }) + + mockSend.mockResolvedValue({ + stream: (async function* () { + yield { messageStart: { role: "assistant" } } + yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } + })(), + }) + + const messages = [{ role: "user" as const, content: "Test message" }] + const metadata = { + taskId: "test-task", + thinking: { + enabled: true, + maxTokens: 16384, + maxThinkingTokens: 8192, + }, + } + + const stream = handler.createMessage("System prompt", messages, metadata) + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Verify the thinking parameter was passed correctly + expect(mockSend).toHaveBeenCalledTimes(1) + expect(capturedPayload).toBeDefined() + expect(capturedPayload.additionalModelRequestFields).toBeDefined() + expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({ + type: "enabled", + budget_tokens: 8192, + }) + + // Verify that topP is NOT present when thinking is enabled via metadata + expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") + }) + + it("should log when extended thinking is enabled", async () => { + handler = new AwsBedrockHandler({ + apiProvider: "bedrock", + apiModelId: "anthropic.claude-opus-4-20250514-v1:0", + awsRegion: "us-east-1", + enableReasoningEffort: true, + modelMaxThinkingTokens: 5000, + }) + + mockSend.mockResolvedValue({ + stream: (async function* () { + yield { messageStart: { role: "assistant" } } + })(), + }) + + const messages = [{ role: "user" as const, content: "Test" }] + const stream = handler.createMessage("System prompt", messages) + + for await (const chunk of stream) { + // consume stream + } + + // Verify logging + expect(logger.info).toHaveBeenCalledWith( + expect.stringContaining("Extended thinking enabled"), + expect.objectContaining({ + ctx: "bedrock", + modelId: "anthropic.claude-opus-4-20250514-v1:0", + }), + ) + }) + + it("should include topP when thinking is disabled", async () => { + handler = new AwsBedrockHandler({ + apiProvider: "bedrock", + apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0", + awsRegion: "us-east-1", + // Note: no enableReasoningEffort = true, so thinking is disabled + }) + + mockSend.mockResolvedValue({ + stream: (async function* () { + yield { messageStart: { role: "assistant" } } + yield { + contentBlockStart: { + start: { text: "Hello" }, + contentBlockIndex: 0, + }, + } + yield { + contentBlockDelta: { + delta: { text: " world" }, + }, + } + yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } + })(), + }) + + const messages = [{ role: "user" as const, content: "Test message" }] + const stream = handler.createMessage("System prompt", messages) + + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Verify that topP IS present when thinking is disabled + expect(mockSend).toHaveBeenCalledTimes(1) + expect(capturedPayload).toBeDefined() + expect(capturedPayload.inferenceConfig).toHaveProperty("topP", 0.1) + + // Verify that additionalModelRequestFields is not present or empty + expect(capturedPayload.additionalModelRequestFields).toBeUndefined() + }) + + it("should enable reasoning when enableReasoningEffort is true in settings", async () => { + handler = new AwsBedrockHandler({ + apiProvider: "bedrock", + apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0", + awsRegion: "us-east-1", + enableReasoningEffort: true, // This should trigger reasoning + modelMaxThinkingTokens: 4096, + }) + + mockSend.mockResolvedValue({ + stream: (async function* () { + yield { messageStart: { role: "assistant" } } + yield { + contentBlockStart: { + content_block: { type: "thinking", thinking: "Let me think..." }, + contentBlockIndex: 0, + }, + } + yield { + contentBlockDelta: { + delta: { type: "thinking_delta", thinking: " about this problem." }, + }, + } + yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } } + })(), + }) + + const messages = [{ role: "user" as const, content: "Test message" }] + const stream = handler.createMessage("System prompt", messages) + + const chunks = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + // Verify thinking was enabled via settings + expect(mockSend).toHaveBeenCalledTimes(1) + expect(capturedPayload).toBeDefined() + expect(capturedPayload.additionalModelRequestFields).toBeDefined() + expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({ + type: "enabled", + budget_tokens: 4096, + }) + + // Verify that topP is NOT present when thinking is enabled via settings + expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP") + + // Verify reasoning chunks were yielded + const reasoningChunks = chunks.filter((c) => c.type === "reasoning") + expect(reasoningChunks).toHaveLength(2) + expect(reasoningChunks[0].text).toBe("Let me think...") + expect(reasoningChunks[1].text).toBe(" about this problem.") + }) + }) +}) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 16ce3289aa..b5474cce50 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -29,6 +29,8 @@ import { logger } from "../../utils/logging" import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-strategy" import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types" import { convertToBedrockConverseMessages as sharedConverter } from "../transform/bedrock-converse-format" +import { getModelParams } from "../transform/model-params" +import { shouldUseReasoningBudget } from "../../shared/api" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" /************************************************************************************ @@ -40,8 +42,63 @@ import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from ". // Define interface for Bedrock inference config interface BedrockInferenceConfig { maxTokens: number - temperature: number - topP: number + temperature?: number + topP?: number +} + +// Define interface for Bedrock thinking configuration +interface BedrockThinkingConfig { + thinking: { + type: "enabled" + budget_tokens: number + } + [key: string]: any // Add index signature to be compatible with DocumentType +} + +// Define interface for Bedrock payload +interface BedrockPayload { + modelId: BedrockModelId | string + messages: Message[] + system?: SystemContentBlock[] + inferenceConfig: BedrockInferenceConfig + anthropic_version?: string + additionalModelRequestFields?: BedrockThinkingConfig +} + +// Define specific types for content block events to avoid 'as any' usage +// These handle the multiple possible structures returned by AWS SDK +interface ContentBlockStartEvent { + start?: { + text?: string + thinking?: string + } + contentBlockIndex?: number + // Alternative structure used by some AWS SDK versions + content_block?: { + type?: string + thinking?: string + } + // Official AWS SDK structure for reasoning (as documented) + contentBlock?: { + type?: string + thinking?: string + reasoningContent?: { + text?: string + } + } +} + +interface ContentBlockDeltaEvent { + delta?: { + text?: string + thinking?: string + type?: string + // AWS SDK structure for reasoning content deltas + reasoningContent?: { + text?: string + } + } + contentBlockIndex?: number } // Define types for stream events based on AWS SDK @@ -53,18 +110,8 @@ export interface StreamEvent { stopReason?: "end_turn" | "tool_use" | "max_tokens" | "stop_sequence" additionalModelResponseFields?: Record } - contentBlockStart?: { - start?: { - text?: string - } - contentBlockIndex?: number - } - contentBlockDelta?: { - delta?: { - text?: string - } - contentBlockIndex?: number - } + contentBlockStart?: ContentBlockStartEvent + contentBlockDelta?: ContentBlockDeltaEvent metadata?: { usage?: { inputTokens: number @@ -255,13 +302,17 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH override async *createMessage( systemPrompt: string, messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, + metadata?: ApiHandlerCreateMessageMetadata & { + thinking?: { + enabled: boolean + maxTokens?: number + maxThinkingTokens?: number + } + }, ): ApiStream { - let modelConfig = this.getModel() - // Handle cross-region inference + const modelConfig = this.getModel() const usePromptCache = Boolean(this.options.awsUsePromptCache && this.supportsAwsPromptCache(modelConfig)) - // Generate a conversation ID based on the first few messages to maintain cache consistency const conversationId = messages.length > 0 ? `conv_${messages[0].role}_${ @@ -271,7 +322,6 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH }` : "default_conversation" - // Convert messages to Bedrock format, passing the model info and conversation ID const formatted = this.convertToBedrockConverseMessages( messages, systemPrompt, @@ -280,18 +330,50 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH conversationId, ) - // Construct the payload + let additionalModelRequestFields: BedrockThinkingConfig | undefined + let thinkingEnabled = false + + // Determine if thinking should be enabled + // metadata?.thinking?.enabled: Explicitly enabled through API metadata (direct request) + // shouldUseReasoningBudget(): Enabled through user settings (enableReasoningEffort = true) + const isThinkingExplicitlyEnabled = metadata?.thinking?.enabled + const isThinkingEnabledBySettings = + shouldUseReasoningBudget({ model: modelConfig.info, settings: this.options }) && + modelConfig.reasoning && + modelConfig.reasoningBudget + + if ((isThinkingExplicitlyEnabled || isThinkingEnabledBySettings) && modelConfig.info.supportsReasoningBudget) { + thinkingEnabled = true + additionalModelRequestFields = { + thinking: { + type: "enabled", + budget_tokens: metadata?.thinking?.maxThinkingTokens || modelConfig.reasoningBudget || 4096, + }, + } + logger.info("Extended thinking enabled for Bedrock request", { + ctx: "bedrock", + modelId: modelConfig.id, + thinking: additionalModelRequestFields.thinking, + }) + } + const inferenceConfig: BedrockInferenceConfig = { - maxTokens: modelConfig.info.maxTokens as number, - temperature: this.options.modelTemperature as number, - topP: 0.1, + maxTokens: modelConfig.maxTokens || (modelConfig.info.maxTokens as number), + temperature: modelConfig.temperature ?? (this.options.modelTemperature as number), } - const payload = { + if (!thinkingEnabled) { + inferenceConfig.topP = 0.1 + } + + const payload: BedrockPayload = { modelId: modelConfig.id, messages: formatted.messages, system: formatted.system, inferenceConfig, + ...(additionalModelRequestFields && { additionalModelRequestFields }), + // Add anthropic_version when using thinking features + ...(thinkingEnabled && { anthropic_version: "bedrock-2023-05-31" }), } // Create AbortController with 10 minute timeout @@ -397,19 +479,74 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } // Handle content blocks - if (streamEvent.contentBlockStart?.start?.text) { - yield { - type: "text", - text: streamEvent.contentBlockStart.start.text, + if (streamEvent.contentBlockStart) { + const cbStart = streamEvent.contentBlockStart + + // Check if this is a reasoning block (official AWS SDK structure) + if (cbStart.contentBlock?.reasoningContent) { + if (cbStart.contentBlockIndex && cbStart.contentBlockIndex > 0) { + yield { type: "reasoning", text: "\n" } + } + yield { + type: "reasoning", + text: cbStart.contentBlock.reasoningContent.text || "", + } + } + // Check for thinking block - handle both possible AWS SDK structures + // cbStart.contentBlock: newer/official structure + // cbStart.content_block: alternative structure seen in some AWS SDK versions + else if (cbStart.contentBlock?.type === "thinking" || cbStart.content_block?.type === "thinking") { + const contentBlock = cbStart.contentBlock || cbStart.content_block + if (cbStart.contentBlockIndex && cbStart.contentBlockIndex > 0) { + yield { type: "reasoning", text: "\n" } + } + if (contentBlock?.thinking) { + yield { + type: "reasoning", + text: contentBlock.thinking, + } + } + } else if (cbStart.start?.text) { + yield { + type: "text", + text: cbStart.start.text, + } } continue } // Handle content deltas - if (streamEvent.contentBlockDelta?.delta?.text) { - yield { - type: "text", - text: streamEvent.contentBlockDelta.delta.text, + if (streamEvent.contentBlockDelta) { + const cbDelta = streamEvent.contentBlockDelta + const delta = cbDelta.delta + + // Process reasoning and text content deltas + // Multiple structures are supported for AWS SDK compatibility: + // - delta.reasoningContent.text: official AWS docs structure for reasoning + // - delta.thinking: alternative structure for thinking content + // - delta.text: standard text content + if (delta) { + // Check for reasoningContent property (official AWS SDK structure) + if (delta.reasoningContent?.text) { + yield { + type: "reasoning", + text: delta.reasoningContent.text, + } + continue + } + + // Handle alternative thinking structure (fallback for older SDK versions) + if (delta.type === "thinking_delta" && delta.thinking) { + yield { + type: "reasoning", + text: delta.thinking, + } + } else if (delta.text) { + yield { + type: "text", + text: delta.text, + } + } } continue } @@ -444,10 +581,17 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH try { const modelConfig = this.getModel() + // For completePrompt, thinking is typically not used, but we should still check + // if thinking was somehow enabled in the model config + const thinkingEnabled = + shouldUseReasoningBudget({ model: modelConfig.info, settings: this.options }) && + modelConfig.reasoning && + modelConfig.reasoningBudget + const inferenceConfig: BedrockInferenceConfig = { - maxTokens: modelConfig.info.maxTokens as number, - temperature: this.options.modelTemperature as number, - topP: 0.1, + maxTokens: modelConfig.maxTokens || (modelConfig.info.maxTokens as number), + temperature: modelConfig.temperature ?? (this.options.modelTemperature as number), + ...(thinkingEnabled ? {} : { topP: 0.1 }), // Only set topP when thinking is NOT enabled } // For completePrompt, use a unique conversation ID based on the prompt @@ -722,9 +866,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH return model } - override getModel(): { id: BedrockModelId | string; info: ModelInfo } { + override getModel(): { + id: BedrockModelId | string + info: ModelInfo + maxTokens?: number + temperature?: number + reasoning?: any + reasoningBudget?: number + } { if (this.costModelConfig?.id?.trim().length > 0) { - return this.costModelConfig + // Get model params for cost model config + const params = getModelParams({ + format: "anthropic", + modelId: this.costModelConfig.id, + model: this.costModelConfig.info, + settings: this.options, + defaultTemperature: BEDROCK_DEFAULT_TEMPERATURE, + }) + return { ...this.costModelConfig, ...params } } let modelConfig = undefined @@ -752,8 +911,24 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH } } + // Get model params including reasoning configuration + const params = getModelParams({ + format: "anthropic", + modelId: modelConfig.id, + model: modelConfig.info, + settings: this.options, + defaultTemperature: BEDROCK_DEFAULT_TEMPERATURE, + }) + // Don't override maxTokens/contextWindow here; handled in getModelById (and includes user overrides) - return modelConfig as { id: BedrockModelId | string; info: ModelInfo } + return { ...modelConfig, ...params } as { + id: BedrockModelId | string + info: ModelInfo + maxTokens?: number + temperature?: number + reasoning?: any + reasoningBudget?: number + } } /************************************************************************************ @@ -905,10 +1080,33 @@ Suggestions: messageTemplate: `Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name`, logLevel: "error", }, + VALIDATION_ERROR: { + patterns: [ + "input tag", + "does not match any of the expected tags", + "field required", + "validation", + "invalid parameter", + ], + messageTemplate: `Parameter validation error: {errorMessage} + +This error indicates that the request parameters don't match AWS Bedrock's expected format. + +Common causes: +1. Extended thinking parameter format is incorrect +2. Model-specific parameters are not supported by this model +3. API parameter structure has changed + +Please check: +- Model supports the requested features (extended thinking, etc.) +- Parameter format matches AWS Bedrock specification +- Model ID is correct for the requested features`, + logLevel: "error", + }, // Default/generic error GENERIC: { patterns: [], // Empty patterns array means this is the default - messageTemplate: `Unknown Error`, + messageTemplate: `Unknown Error: {errorMessage}`, logLevel: "error", }, } diff --git a/webview-ui/src/components/settings/ThinkingBudget.tsx b/webview-ui/src/components/settings/ThinkingBudget.tsx index 456e0be17a..0adb62f2a0 100644 --- a/webview-ui/src/components/settings/ThinkingBudget.tsx +++ b/webview-ui/src/components/settings/ThinkingBudget.tsx @@ -65,7 +65,11 @@ export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, mod
setApiConfigurationField("modelMaxTokens", value)} diff --git a/webview-ui/src/components/settings/providers/Bedrock.tsx b/webview-ui/src/components/settings/providers/Bedrock.tsx index eb8ca94258..a0ebafd88e 100644 --- a/webview-ui/src/components/settings/providers/Bedrock.tsx +++ b/webview-ui/src/components/settings/providers/Bedrock.tsx @@ -108,24 +108,24 @@ export const Bedrock = ({ apiConfiguration, setApiConfigurationField, selectedMo {t("settings:providers.awsCrossRegion")} {selectedModelInfo?.supportsPromptCache && ( - -
- {t("settings:providers.enablePromptCaching")} - + <> + +
+ {t("settings:providers.enablePromptCaching")} + +
+
+
+ {t("settings:providers.cacheUsageNote")}
- + )} -
-
- {t("settings:providers.cacheUsageNote")} -
-
{