Skip to content

Commit fb3a728

Browse files
feat: Add reasoning budget support to Bedrock models for extended thinking (#4201) (#4481)
* Add reasoning budget support to Bedrock models and update related components - Introduced `supportsReasoningBudget` property in Bedrock models. - Enhanced `AwsBedrockHandler` to handle reasoning budget in payloads. - Updated `ThinkingBudget` component to dynamically set max tokens based on reasoning support. - Modified `ApiOptions` and `Bedrock` components to conditionally render `ThinkingBudget`. - Added tests for extended thinking functionality in `bedrock-reasoning.test.ts`. * Add BedrockThinkingConfig interface and update payload structure * fix: address PR review feedback (#4481) - Simplify ThinkingBudget ternary logic since component only renders when reasoning budget supported - Break down complex thinking enabled condition with clear documentation - Replace 'as any' usage with proper TypeScript interfaces for AWS SDK events - Add comprehensive documentation for multiple stream structures explaining AWS SDK compatibility
1 parent 7bed944 commit fb3a728

File tree

6 files changed

+549
-62
lines changed

6 files changed

+549
-62
lines changed

packages/types/src/providers/bedrock.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ export const bedrockModels = {
7373
supportsImages: true,
7474
supportsComputerUse: true,
7575
supportsPromptCache: true,
76+
supportsReasoningBudget: true,
7677
inputPrice: 3.0,
7778
outputPrice: 15.0,
7879
cacheWritesPrice: 3.75,
@@ -87,6 +88,7 @@ export const bedrockModels = {
8788
supportsImages: true,
8889
supportsComputerUse: true,
8990
supportsPromptCache: true,
91+
supportsReasoningBudget: true,
9092
inputPrice: 15.0,
9193
outputPrice: 75.0,
9294
cacheWritesPrice: 18.75,
@@ -101,6 +103,7 @@ export const bedrockModels = {
101103
supportsImages: true,
102104
supportsComputerUse: true,
103105
supportsPromptCache: true,
106+
supportsReasoningBudget: true,
104107
inputPrice: 3.0,
105108
outputPrice: 15.0,
106109
cacheWritesPrice: 3.75,
bedrock-reasoning.test.ts (filename inferred from the commit message above — confirm against the full file tree)

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import { AwsBedrockHandler } from "../bedrock"
2+
import { BedrockRuntimeClient, ConverseStreamCommand } from "@aws-sdk/client-bedrock-runtime"
3+
import { logger } from "../../../utils/logging"
4+
5+
// Mock the AWS SDK
6+
jest.mock("@aws-sdk/client-bedrock-runtime")
7+
jest.mock("../../../utils/logging")
8+
9+
// Store the command payload for verification
10+
let capturedPayload: any = null
11+
12+
describe("AwsBedrockHandler - Extended Thinking", () => {
13+
let handler: AwsBedrockHandler
14+
let mockSend: jest.Mock
15+
16+
beforeEach(() => {
17+
capturedPayload = null
18+
mockSend = jest.fn()
19+
20+
// Mock ConverseStreamCommand to capture the payload
21+
;(ConverseStreamCommand as unknown as jest.Mock).mockImplementation((payload) => {
22+
capturedPayload = payload
23+
return {
24+
input: payload,
25+
}
26+
})
27+
;(BedrockRuntimeClient as jest.Mock).mockImplementation(() => ({
28+
send: mockSend,
29+
config: { region: "us-east-1" },
30+
}))
31+
;(logger.info as jest.Mock).mockImplementation(() => {})
32+
;(logger.error as jest.Mock).mockImplementation(() => {})
33+
})
34+
35+
afterEach(() => {
36+
jest.clearAllMocks()
37+
})
38+
39+
describe("Extended Thinking Support", () => {
40+
it("should include thinking parameter for Claude Sonnet 4 when reasoning is enabled", async () => {
41+
handler = new AwsBedrockHandler({
42+
apiProvider: "bedrock",
43+
apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0",
44+
awsRegion: "us-east-1",
45+
enableReasoningEffort: true,
46+
modelMaxTokens: 8192,
47+
modelMaxThinkingTokens: 4096,
48+
})
49+
50+
// Mock the stream response
51+
mockSend.mockResolvedValue({
52+
stream: (async function* () {
53+
yield {
54+
messageStart: { role: "assistant" },
55+
}
56+
yield {
57+
contentBlockStart: {
58+
content_block: { type: "thinking", thinking: "Let me think..." },
59+
contentBlockIndex: 0,
60+
},
61+
}
62+
yield {
63+
contentBlockDelta: {
64+
delta: { type: "thinking_delta", thinking: " about this problem." },
65+
},
66+
}
67+
yield {
68+
contentBlockStart: {
69+
start: { text: "Here's the answer:" },
70+
contentBlockIndex: 1,
71+
},
72+
}
73+
yield {
74+
metadata: {
75+
usage: { inputTokens: 100, outputTokens: 50 },
76+
},
77+
}
78+
})(),
79+
})
80+
81+
const messages = [{ role: "user" as const, content: "Test message" }]
82+
const stream = handler.createMessage("System prompt", messages)
83+
84+
const chunks = []
85+
for await (const chunk of stream) {
86+
chunks.push(chunk)
87+
}
88+
89+
// Verify the command was called with the correct payload
90+
expect(mockSend).toHaveBeenCalledTimes(1)
91+
expect(capturedPayload).toBeDefined()
92+
expect(capturedPayload.additionalModelRequestFields).toBeDefined()
93+
expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({
94+
type: "enabled",
95+
budget_tokens: 4096, // Uses the full modelMaxThinkingTokens value
96+
})
97+
98+
// Verify reasoning chunks were yielded
99+
const reasoningChunks = chunks.filter((c) => c.type === "reasoning")
100+
expect(reasoningChunks).toHaveLength(2)
101+
expect(reasoningChunks[0].text).toBe("Let me think...")
102+
expect(reasoningChunks[1].text).toBe(" about this problem.")
103+
104+
// Verify that topP is NOT present when thinking is enabled
105+
expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP")
106+
})
107+
108+
it("should pass thinking parameters from metadata", async () => {
109+
handler = new AwsBedrockHandler({
110+
apiProvider: "bedrock",
111+
apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
112+
awsRegion: "us-east-1",
113+
})
114+
115+
mockSend.mockResolvedValue({
116+
stream: (async function* () {
117+
yield { messageStart: { role: "assistant" } }
118+
yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } }
119+
})(),
120+
})
121+
122+
const messages = [{ role: "user" as const, content: "Test message" }]
123+
const metadata = {
124+
taskId: "test-task",
125+
thinking: {
126+
enabled: true,
127+
maxTokens: 16384,
128+
maxThinkingTokens: 8192,
129+
},
130+
}
131+
132+
const stream = handler.createMessage("System prompt", messages, metadata)
133+
const chunks = []
134+
for await (const chunk of stream) {
135+
chunks.push(chunk)
136+
}
137+
138+
// Verify the thinking parameter was passed correctly
139+
expect(mockSend).toHaveBeenCalledTimes(1)
140+
expect(capturedPayload).toBeDefined()
141+
expect(capturedPayload.additionalModelRequestFields).toBeDefined()
142+
expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({
143+
type: "enabled",
144+
budget_tokens: 8192,
145+
})
146+
147+
// Verify that topP is NOT present when thinking is enabled via metadata
148+
expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP")
149+
})
150+
151+
it("should log when extended thinking is enabled", async () => {
152+
handler = new AwsBedrockHandler({
153+
apiProvider: "bedrock",
154+
apiModelId: "anthropic.claude-opus-4-20250514-v1:0",
155+
awsRegion: "us-east-1",
156+
enableReasoningEffort: true,
157+
modelMaxThinkingTokens: 5000,
158+
})
159+
160+
mockSend.mockResolvedValue({
161+
stream: (async function* () {
162+
yield { messageStart: { role: "assistant" } }
163+
})(),
164+
})
165+
166+
const messages = [{ role: "user" as const, content: "Test" }]
167+
const stream = handler.createMessage("System prompt", messages)
168+
169+
for await (const chunk of stream) {
170+
// consume stream
171+
}
172+
173+
// Verify logging
174+
expect(logger.info).toHaveBeenCalledWith(
175+
expect.stringContaining("Extended thinking enabled"),
176+
expect.objectContaining({
177+
ctx: "bedrock",
178+
modelId: "anthropic.claude-opus-4-20250514-v1:0",
179+
}),
180+
)
181+
})
182+
183+
it("should include topP when thinking is disabled", async () => {
184+
handler = new AwsBedrockHandler({
185+
apiProvider: "bedrock",
186+
apiModelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
187+
awsRegion: "us-east-1",
188+
// Note: no enableReasoningEffort = true, so thinking is disabled
189+
})
190+
191+
mockSend.mockResolvedValue({
192+
stream: (async function* () {
193+
yield { messageStart: { role: "assistant" } }
194+
yield {
195+
contentBlockStart: {
196+
start: { text: "Hello" },
197+
contentBlockIndex: 0,
198+
},
199+
}
200+
yield {
201+
contentBlockDelta: {
202+
delta: { text: " world" },
203+
},
204+
}
205+
yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } }
206+
})(),
207+
})
208+
209+
const messages = [{ role: "user" as const, content: "Test message" }]
210+
const stream = handler.createMessage("System prompt", messages)
211+
212+
const chunks = []
213+
for await (const chunk of stream) {
214+
chunks.push(chunk)
215+
}
216+
217+
// Verify that topP IS present when thinking is disabled
218+
expect(mockSend).toHaveBeenCalledTimes(1)
219+
expect(capturedPayload).toBeDefined()
220+
expect(capturedPayload.inferenceConfig).toHaveProperty("topP", 0.1)
221+
222+
// Verify that additionalModelRequestFields is not present or empty
223+
expect(capturedPayload.additionalModelRequestFields).toBeUndefined()
224+
})
225+
226+
it("should enable reasoning when enableReasoningEffort is true in settings", async () => {
227+
handler = new AwsBedrockHandler({
228+
apiProvider: "bedrock",
229+
apiModelId: "anthropic.claude-sonnet-4-20250514-v1:0",
230+
awsRegion: "us-east-1",
231+
enableReasoningEffort: true, // This should trigger reasoning
232+
modelMaxThinkingTokens: 4096,
233+
})
234+
235+
mockSend.mockResolvedValue({
236+
stream: (async function* () {
237+
yield { messageStart: { role: "assistant" } }
238+
yield {
239+
contentBlockStart: {
240+
content_block: { type: "thinking", thinking: "Let me think..." },
241+
contentBlockIndex: 0,
242+
},
243+
}
244+
yield {
245+
contentBlockDelta: {
246+
delta: { type: "thinking_delta", thinking: " about this problem." },
247+
},
248+
}
249+
yield { metadata: { usage: { inputTokens: 100, outputTokens: 50 } } }
250+
})(),
251+
})
252+
253+
const messages = [{ role: "user" as const, content: "Test message" }]
254+
const stream = handler.createMessage("System prompt", messages)
255+
256+
const chunks = []
257+
for await (const chunk of stream) {
258+
chunks.push(chunk)
259+
}
260+
261+
// Verify thinking was enabled via settings
262+
expect(mockSend).toHaveBeenCalledTimes(1)
263+
expect(capturedPayload).toBeDefined()
264+
expect(capturedPayload.additionalModelRequestFields).toBeDefined()
265+
expect(capturedPayload.additionalModelRequestFields.thinking).toEqual({
266+
type: "enabled",
267+
budget_tokens: 4096,
268+
})
269+
270+
// Verify that topP is NOT present when thinking is enabled via settings
271+
expect(capturedPayload.inferenceConfig).not.toHaveProperty("topP")
272+
273+
// Verify reasoning chunks were yielded
274+
const reasoningChunks = chunks.filter((c) => c.type === "reasoning")
275+
expect(reasoningChunks).toHaveLength(2)
276+
expect(reasoningChunks[0].text).toBe("Let me think...")
277+
expect(reasoningChunks[1].text).toBe(" about this problem.")
278+
})
279+
})
280+
})

0 commit comments

Comments (0)