diff --git a/packages/types/src/providers/bedrock.ts b/packages/types/src/providers/bedrock.ts
index ce5ea28e95..a15f041252 100644
--- a/packages/types/src/providers/bedrock.ts
+++ b/packages/types/src/providers/bedrock.ts
@@ -73,6 +73,7 @@ export const bedrockModels = {
 		supportsImages: true,
 		supportsComputerUse: true,
 		supportsPromptCache: true,
+		supportsReasoningBudget: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
@@ -87,6 +88,7 @@ export const bedrockModels = {
 		supportsImages: true,
 		supportsComputerUse: true,
 		supportsPromptCache: true,
+		supportsReasoningBudget: true,
 		inputPrice: 15.0,
 		outputPrice: 75.0,
 		cacheWritesPrice: 18.75,
@@ -101,6 +103,7 @@ export const bedrockModels = {
 		supportsImages: true,
 		supportsComputerUse: true,
 		supportsPromptCache: true,
+		supportsReasoningBudget: true,
 		inputPrice: 3.0,
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
diff --git a/src/api/providers/__tests__/bedrock-reasoning.spec.ts b/src/api/providers/__tests__/bedrock-reasoning.spec.ts
new file mode 100644
index 0000000000..6652398459
--- /dev/null
+++ b/src/api/providers/__tests__/bedrock-reasoning.spec.ts
@@ -0,0 +1,180 @@
+import { vi, describe, it, expect, beforeEach } from "vitest"
+
+// AWS SDK modules are mocked before the handler is imported
+vi.mock("@aws-sdk/credential-providers", () => ({
+	fromIni: vi.fn(),
+}))
+
+// One send mock shared by every BedrockRuntimeClient instance
+const sharedMockSend = vi.fn()
+
+vi.mock("@aws-sdk/client-bedrock-runtime", () => ({
+	BedrockRuntimeClient: vi.fn().mockImplementation(() => ({
+		send: sharedMockSend,
+		config: { region: "us-east-1" },
+	})),
+	ConverseStreamCommand: vi.fn(), // mock constructor inspected by the tests
+	ConverseCommand: vi.fn(),
+}))
+
+// Imported only after the mocks above are registered
+import { AwsBedrockHandler } from "../bedrock"
+// Resolves to the vi.fn() created in the mock factory
+import { ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime"
+
+const mockOptions = {
+	awsRegion: "us-east-1",
+	apiModelId: "anthropic.claude-3-7-sonnet-20241029-v1:0",
+	enableReasoningEffort: false, // reasoning stays off unless a test opts in
+	modelTemperature: 0.7,
+}
+
+const reasoningOptions = {
+	...mockOptions,
+	enableReasoningEffort: true,
+	modelMaxThinkingTokens: 5000,
+}
+
+/** Makes the mocked client resolve with a stream of `events` framed by message start/stop. */
+function respondWith(...events: object[]) {
+	sharedMockSend.mockResolvedValue({
+		stream: (async function* () {
+			yield { messageStart: { role: "assistant" } }
+			yield* events
+			yield { messageStop: { stopReason: "end_turn" } }
+		})(),
+	})
+}
+
+/** Sends one user message through `handler` and gathers every chunk it streams back. */
+async function collectChunks(handler: AwsBedrockHandler) {
+	const messages = [{ role: "user" as const, content: "Test message" }]
+	const chunks = []
+	for await (const chunk of handler.createMessage("", messages)) {
+		chunks.push(chunk)
+	}
+	return chunks
+}
+
+/** Returns the payload the mocked ConverseStreamCommand constructor received. */
+function sentPayload() {
+	expect(ConverseStreamCommand).toHaveBeenCalled()
+	return (ConverseStreamCommand as any).mock.calls[0][0]
+}
+
+/** Asserts that a payload carries no extended-thinking fields and keeps its sampling params. */
+function expectThinkingDisabled(payload: any) {
+	expect(payload.anthropic_version).toBeUndefined()
+	expect(payload.additionalModelRequestFields).toBeUndefined()
+	expect(payload.inferenceConfig.temperature).toBeDefined()
+	expect(payload.inferenceConfig.topP).toBeDefined()
+}
+
+describe("AwsBedrockHandler - Extended Thinking", () => {
+	let handler: AwsBedrockHandler
+
+	beforeEach(() => {
+		// Resets the shared send mock and the ConverseStreamCommand constructor mock
+		vi.clearAllMocks()
+		// The handler instantiates BedrockRuntimeClient, which picks up sharedMockSend
+		handler = new AwsBedrockHandler(mockOptions)
+	})
+
+	describe("Extended Thinking Configuration", () => {
+		it("should NOT enable extended thinking by default", async () => {
+			respondWith({ contentBlockStart: { start: { text: "Hello" }, contentBlockIndex: 0 } })
+
+			await collectChunks(handler)
+
+			expectThinkingDisabled(sentPayload())
+		})
+
+		it("should enable extended thinking when explicitly enabled with reasoning budget", async () => {
+			handler = new AwsBedrockHandler(reasoningOptions)
+			respondWith(
+				{
+					contentBlockStart: {
+						contentBlock: { type: "thinking", thinking: "Let me think about this..." },
+						contentBlockIndex: 0,
+					},
+				},
+				{ contentBlockStart: { start: { text: "Here is my response" }, contentBlockIndex: 1 } },
+			)
+
+			const chunks = await collectChunks(handler)
+			const payload = sentPayload()
+
+			// Thinking is requested with the configured budget
+			expect(payload.anthropic_version).toBe("bedrock-20250514")
+			expect(payload.additionalModelRequestFields).toEqual({
+				thinking: { type: "enabled", budget_tokens: 5000 },
+			})
+			// Sampling params are dropped while thinking is on
+			expect(payload.inferenceConfig.temperature).toBeUndefined()
+			expect(payload.inferenceConfig.topP).toBeUndefined()
+
+			// The thinking block surfaces as a reasoning chunk
+			const reasoningChunk = chunks.find((c) => c.type === "reasoning")
+			expect(reasoningChunk).toBeDefined()
+			expect(reasoningChunk?.text).toBe("Let me think about this...")
+		})
+
+		it("should NOT enable extended thinking for unsupported models", async () => {
+			// A model without reasoning support, even though the user asked for it
+			handler = new AwsBedrockHandler({
+				...reasoningOptions,
+				apiModelId: "anthropic.claude-3-haiku-20240307-v1:0",
+			})
+			respondWith({ contentBlockStart: { start: { text: "Hello" }, contentBlockIndex: 0 } })
+
+			await collectChunks(handler)
+
+			expectThinkingDisabled(sentPayload())
+		})
+	})
+
+	describe("Stream Processing", () => {
+		beforeEach(() => {
+			handler = new AwsBedrockHandler(reasoningOptions)
+		})
+
+		it("should handle thinking delta events", async () => {
+			respondWith(
+				{
+					contentBlockDelta: {
+						delta: { type: "thinking_delta", thinking: "First part of thinking..." },
+						contentBlockIndex: 0,
+					},
+				},
+				{
+					contentBlockDelta: {
+						delta: { type: "thinking_delta", thinking: " Second part of thinking." },
+						contentBlockIndex: 0,
+					},
+				},
+			)
+
+			const chunks = await collectChunks(handler)
+
+			const reasoningChunks = chunks.filter((c) => c.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(2)
+			expect(reasoningChunks[0].text).toBe("First part of thinking...")
+			expect(reasoningChunks[1].text).toBe(" Second part of thinking.")
+		})
+
+		it("should handle signature delta events as reasoning", async () => {
+			respondWith({
+				contentBlockDelta: {
+					delta: { type: "signature_delta", signature: "[Signature content]" },
+					contentBlockIndex: 0,
+				},
+			})
+
+			const chunks = await collectChunks(handler)
+
+			const reasoningChunk = chunks.find((c) => c.type === "reasoning")
+			expect(reasoningChunk).toBeDefined()
+			expect(reasoningChunk?.text).toBe("[Signature content]")
+		})
+	})
+})
diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts
index 16ce3289aa..a111ba6367 100644
--- a/src/api/providers/bedrock.ts
+++ b/src/api/providers/bedrock.ts
@@ -30,6 +30,11 @@ import { MultiPointStrategy } from "../transform/cache-strategy/multi-point-stra
 import { ModelInfo as CacheModelInfo } from "../transform/cache-strategy/types"
 import { convertToBedrockConverseMessages as sharedConverter } from "../transform/bedrock-converse-format"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { getModelParams } from "../transform/model-params" +import { shouldUseReasoningBudget } from "../../shared/api" + +// Constants for Bedrock Extended Thinking +const BEDROCK_ANTHROPIC_VERSION = "bedrock-20250514" /************************************************************************************ * @@ -40,10 +45,28 @@ import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from ". // Define interface for Bedrock inference config interface BedrockInferenceConfig { maxTokens: number - temperature: number - topP: number + temperature?: number + topP?: number +} + +// Define interface for Bedrock payload +interface BedrockPayload { + modelId: BedrockModelId | string + messages: Message[] + system?: SystemContentBlock[] + inferenceConfig: BedrockInferenceConfig + anthropic_version?: string + additionalModelRequestFields?: { + thinking?: { + type: "enabled" + budget_tokens: number + } + } } +// Import the proper stream chunk types +import { ApiStreamChunk } from "../transform/stream" + // Define types for stream events based on AWS SDK export interface StreamEvent { messageStart?: { @@ -58,10 +81,20 @@ export interface StreamEvent { text?: string } contentBlockIndex?: number + // Extended thinking support + contentBlock?: { + type?: "thinking" | "text" + thinking?: string + text?: string + } } contentBlockDelta?: { delta?: { text?: string + // Extended thinking support + type?: "thinking_delta" | "text_delta" | "signature_delta" + thinking?: string + signature?: string } contentBlockIndex?: number } @@ -119,6 +152,47 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH private client: BedrockRuntimeClient private arnInfo: any + /** + * Determines if extended thinking should be enabled based on model support and user settings + */ + private shouldEnableExtendedThinking(modelInfo: ModelInfo, params: any): boolean { + return !!( + 
this.options.enableReasoningEffort && + shouldUseReasoningBudget({ model: modelInfo, settings: this.options }) && + params.reasoning && + params.reasoningBudget + ) + } + + /** + * Unified stream event processor for better maintainability + */ + private *processStreamEvent(streamEvent: StreamEvent): Generator { + // Handle content blocks + if (streamEvent.contentBlockStart) { + const { contentBlock, start } = streamEvent.contentBlockStart + + if (contentBlock?.type === "thinking" && contentBlock.thinking !== undefined) { + yield { type: "reasoning", text: contentBlock.thinking || "" } + } else if (start?.text || contentBlock?.text) { + yield { type: "text", text: start?.text || contentBlock?.text || "" } + } + } + + // Handle content deltas + if (streamEvent.contentBlockDelta?.delta) { + const { delta } = streamEvent.contentBlockDelta + + if (delta.type === "thinking_delta" && delta.thinking) { + yield { type: "reasoning", text: delta.thinking } + } else if (delta.type === "signature_delta" && delta.signature) { + yield { type: "reasoning", text: delta.signature } + } else if (delta.text) { + yield { type: "text", text: delta.text } + } + } + } + constructor(options: ProviderSettings) { super() this.options = options @@ -256,6 +330,49 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH systemPrompt: string, messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const maxRetries = 3 + let retryCount = 0 + let lastError: unknown + + while (retryCount < maxRetries) { + try { + yield* this.createMessageInternal(systemPrompt, messages, metadata) + return + } catch (error) { + lastError = error + retryCount++ + + // Check if error is retryable + const errorType = this.getErrorType(error) + const retryableErrors = ["THROTTLING", "ABORT", "GENERIC"] + + if (!retryableErrors.includes(errorType) || retryCount >= maxRetries) { + // Not retryable or max retries reached + throw error + } + + // 
Log retry attempt + logger.info(`Retrying Bedrock request (attempt ${retryCount}/${maxRetries})`, { + ctx: "bedrock", + errorType, + errorMessage: error instanceof Error ? error.message : String(error), + }) + + // Exponential backoff: 1s, 2s, 4s + const delay = Math.pow(2, retryCount - 1) * 1000 + await new Promise((resolve) => setTimeout(resolve, delay)) + } + } + + // If we get here, all retries failed + throw lastError + } + + private async *createMessageInternal( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { let modelConfig = this.getModel() // Handle cross-region inference @@ -280,20 +397,56 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH conversationId, ) + // Get model parameters including reasoning configuration + const params = getModelParams({ + format: "anthropic", + modelId: modelConfig.id as string, + model: modelConfig.info, + settings: this.options, + }) + // Construct the payload const inferenceConfig: BedrockInferenceConfig = { - maxTokens: modelConfig.info.maxTokens as number, - temperature: this.options.modelTemperature as number, + maxTokens: params.maxTokens || (modelConfig.info.maxTokens as number), + temperature: params.temperature || (this.options.modelTemperature as number), topP: 0.1, } - const payload = { + // Build the base payload + const payload: BedrockPayload = { modelId: modelConfig.id, messages: formatted.messages, system: formatted.system, inferenceConfig, } + // Add extended thinking support ONLY if explicitly enabled by the user + // Reasoning is disabled by default as per requirements + if (this.shouldEnableExtendedThinking(modelConfig.info, params) && params.reasoningBudget) { + // Add the anthropic_version field required for extended thinking + payload.anthropic_version = BEDROCK_ANTHROPIC_VERSION + + // Add additionalModelRequestFields with thinking configuration + payload.additionalModelRequestFields = 
{ + thinking: { + type: "enabled", + budget_tokens: params.reasoningBudget, + }, + } + + // Remove temperature, topP, and top_k when thinking is enabled as they are incompatible + // AWS Bedrock requires these parameters to be undefined when using extended thinking + // as the thinking process uses its own internal temperature and sampling parameters + delete inferenceConfig.temperature + delete inferenceConfig.topP + + logger.info("Extended thinking enabled for Bedrock request", { + ctx: "bedrock", + modelId: modelConfig.id, + budgetTokens: params.reasoningBudget, + }) + } + // Create AbortController with 10 minute timeout const controller = new AbortController() let timeoutId: NodeJS.Timeout | undefined @@ -396,23 +549,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH continue } - // Handle content blocks - if (streamEvent.contentBlockStart?.start?.text) { - yield { - type: "text", - text: streamEvent.contentBlockStart.start.text, - } - continue - } - - // Handle content deltas - if (streamEvent.contentBlockDelta?.delta?.text) { - yield { - type: "text", - text: streamEvent.contentBlockDelta.delta.text, - } - continue - } + // Use unified stream event processor + yield* this.processStreamEvent(streamEvent) // Handle message stop if (streamEvent.messageStop) { continue @@ -509,7 +647,8 @@ export class AwsBedrockHandler extends BaseProvider implements SingleCompletionH conversationId?: string, // Optional conversation ID to track cache points across messages ): { system: SystemContentBlock[]; messages: Message[] } { // First convert messages using shared converter for proper image handling - const convertedMessages = sharedConverter(anthropicMessages as Anthropic.Messages.MessageParam[]) + let convertedMessages = sharedConverter(anthropicMessages as Anthropic.Messages.MessageParam[]) + // If prompt caching is disabled, return the converted messages directly if (!usePromptCache) { @@ -905,10 +1044,22 @@ Suggestions: 
messageTemplate: `Invalid ARN format. ARN should follow the pattern: arn:aws:bedrock:region:account-id:resource-type/resource-name`, logLevel: "error", }, + THINKING_NOT_SUPPORTED: { + patterns: ["thinking", "reasoning", "additionalmodelrequestfields"], + messageTemplate: `Extended thinking/reasoning is not supported for this model or configuration. + +Please verify: +1. You're using a supported model (Claude 3.7 Sonnet, Claude 4 Sonnet, or Claude 4 Opus) +2. Your AWS region supports extended thinking +3. You have the necessary permissions to use this feature + +If the issue persists, try disabling "Enable Reasoning Mode" in the settings.`, + logLevel: "error", + }, // Default/generic error GENERIC: { patterns: [], // Empty patterns array means this is the default - messageTemplate: `Unknown Error`, + messageTemplate: `Bedrock is unable to process your request. Please check your configuration and try again.`, logLevel: "error", }, }