Commit 6763824

Merge pull request #1167 from RooVetGit/cte/claude-3.7-thinking
Claude 3.7 thinking
2 parents 8ae4fb0 + 37cc993 · commit 6763824

24 files changed: +1157 / -1050 lines

package-lock.json

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -304,7 +304,7 @@
 	},
 	"dependencies": {
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
-		"@anthropic-ai/sdk": "^0.26.0",
+		"@anthropic-ai/sdk": "^0.37.0",
 		"@anthropic-ai/vertex-sdk": "^0.4.1",
 		"@aws-sdk/client-bedrock-runtime": "^3.706.0",
 		"@google/generative-ai": "^0.18.0",
```

src/api/providers/__tests__/anthropic.test.ts

Lines changed: 18 additions & 58 deletions
```diff
@@ -1,50 +1,13 @@
+// npx jest src/api/providers/__tests__/anthropic.test.ts
+
 import { AnthropicHandler } from "../anthropic"
 import { ApiHandlerOptions } from "../../../shared/api"
-import { ApiStream } from "../../transform/stream"
-import { Anthropic } from "@anthropic-ai/sdk"
 
-// Mock Anthropic client
-const mockBetaCreate = jest.fn()
 const mockCreate = jest.fn()
+
 jest.mock("@anthropic-ai/sdk", () => {
 	return {
 		Anthropic: jest.fn().mockImplementation(() => ({
-			beta: {
-				promptCaching: {
-					messages: {
-						create: mockBetaCreate.mockImplementation(async () => ({
-							async *[Symbol.asyncIterator]() {
-								yield {
-									type: "message_start",
-									message: {
-										usage: {
-											input_tokens: 100,
-											output_tokens: 50,
-											cache_creation_input_tokens: 20,
-											cache_read_input_tokens: 10,
-										},
-									},
-								}
-								yield {
-									type: "content_block_start",
-									index: 0,
-									content_block: {
-										type: "text",
-										text: "Hello",
-									},
-								}
-								yield {
-									type: "content_block_delta",
-									delta: {
-										type: "text_delta",
-										text: " world",
-									},
-								}
-							},
-						})),
-					},
-				},
-			},
 			messages: {
 				create: mockCreate.mockImplementation(async (options) => {
 					if (!options.stream) {
@@ -65,16 +28,26 @@ jest.mock("@anthropic-ai/sdk", () => {
 							type: "message_start",
 							message: {
 								usage: {
-									input_tokens: 10,
-									output_tokens: 5,
+									input_tokens: 100,
+									output_tokens: 50,
+									cache_creation_input_tokens: 20,
+									cache_read_input_tokens: 10,
 								},
 							},
 						}
 						yield {
 							type: "content_block_start",
+							index: 0,
 							content_block: {
 								type: "text",
-								text: "Test response",
+								text: "Hello",
+							},
+						}
+						yield {
+							type: "content_block_delta",
+							delta: {
+								type: "text_delta",
+								text: " world",
 							},
 						}
 					},
@@ -95,7 +68,6 @@ describe("AnthropicHandler", () => {
 			apiModelId: "claude-3-5-sonnet-20241022",
 		}
 		handler = new AnthropicHandler(mockOptions)
-		mockBetaCreate.mockClear()
 		mockCreate.mockClear()
 	})
 
@@ -126,17 +98,6 @@ describe("AnthropicHandler", () => {
 
 	describe("createMessage", () => {
 		const systemPrompt = "You are a helpful assistant."
-		const messages: Anthropic.Messages.MessageParam[] = [
-			{
-				role: "user",
-				content: [
-					{
-						type: "text" as const,
-						text: "Hello!",
-					},
-				],
-			},
-		]
 
 		it("should handle prompt caching for supported models", async () => {
 			const stream = handler.createMessage(systemPrompt, [
@@ -173,9 +134,8 @@ describe("AnthropicHandler", () => {
 			expect(textChunks[0].text).toBe("Hello")
 			expect(textChunks[1].text).toBe(" world")
 
-			// Verify beta API was used
-			expect(mockBetaCreate).toHaveBeenCalled()
-			expect(mockCreate).not.toHaveBeenCalled()
+			// Verify API
+			expect(mockCreate).toHaveBeenCalled()
 		})
 	})
 
```
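
Condensed, the rewritten prompt-caching test now drives everything through the stable messages mock. A rough sketch of the flow inside the test body, using the option names visible above (the apiKey value is just a placeholder for the mocked client):

```ts
// Sketch of the test flow, not verbatim from the PR.
const handler = new AnthropicHandler({
	apiKey: "test-api-key", // placeholder; the SDK is mocked, so any string works
	apiModelId: "claude-3-5-sonnet-20241022",
})

const stream = handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: [{ type: "text", text: "Hello!" }] },
])

const chunks: Array<{ type: string; text?: string }> = []
for await (const chunk of stream) {
	chunks.push(chunk)
}

// The mock emits message_start, content_block_start and content_block_delta,
// so the text chunks collected here are "Hello" followed by " world".
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(mockCreate).toHaveBeenCalled()
```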

src/api/providers/anthropic.ts

Lines changed: 63 additions & 34 deletions
```diff
@@ -1,5 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
+import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources"
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 import {
 	anthropicDefaultModelId,
 	AnthropicModelId,
@@ -12,60 +14,73 @@ import { ApiStream } from "../transform/stream"
 
 const ANTHROPIC_DEFAULT_TEMPERATURE = 0
 
+const THINKING_MODELS = ["claude-3-7-sonnet-20250219"]
+
 export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: Anthropic
 
 	constructor(options: ApiHandlerOptions) {
 		this.options = options
+
 		this.client = new Anthropic({
 			apiKey: this.options.apiKey,
 			baseURL: this.options.anthropicBaseUrl || undefined,
 		})
 	}
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		let stream: AnthropicStream<Anthropic.Beta.PromptCaching.Messages.RawPromptCachingBetaMessageStreamEvent>
+		let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
+		const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
 		const modelId = this.getModel().id
+		const maxTokens = this.getModel().info.maxTokens || 8192
+		let temperature = this.options.modelTemperature ?? ANTHROPIC_DEFAULT_TEMPERATURE
+		let thinking: BetaThinkingConfigParam | undefined = undefined
+
+		if (THINKING_MODELS.includes(modelId)) {
+			thinking = this.options.anthropicThinking
+				? { type: "enabled", budget_tokens: this.options.anthropicThinking }
+				: { type: "disabled" }
+
+			temperature = 1.0
+		}
 
 		switch (modelId) {
-			// 'latest' alias does not support cache_control
 			case "claude-3-7-sonnet-20250219":
 			case "claude-3-5-sonnet-20241022":
 			case "claude-3-5-haiku-20241022":
 			case "claude-3-opus-20240229":
 			case "claude-3-haiku-20240307": {
-				/*
-				The latest message will be the new user message, one before will be the assistant message from a previous request, and the user message before that will be a previously cached user message. So we need to mark the latest user message as ephemeral to cache it for the next request, and mark the second to last user message as ephemeral to let the server know the last message to retrieve from the cache for the current request..
-				*/
+				/**
+				 * The latest message will be the new user message, one before will
+				 * be the assistant message from a previous request, and the user message before that will be a previously cached user message. So we need to mark the latest user message as ephemeral to cache it for the next request, and mark the second to last user message as ephemeral to let the server know the last message to retrieve from the cache for the current request..
+				 */
 				const userMsgIndices = messages.reduce(
 					(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
 					[] as number[],
 				)
+
 				const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
 				const secondLastMsgUserIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
-				stream = await this.client.beta.promptCaching.messages.create(
+
+				stream = await this.client.messages.create(
 					{
 						model: modelId,
-						max_tokens: this.getModel().info.maxTokens || 8192,
-						temperature: this.options.modelTemperature ?? ANTHROPIC_DEFAULT_TEMPERATURE,
-						system: [{ text: systemPrompt, type: "text", cache_control: { type: "ephemeral" } }], // setting cache breakpoint for system prompt so new tasks can reuse it
+						max_tokens: maxTokens,
+						temperature,
+						thinking,
+						// Setting cache breakpoint for system prompt so new tasks can reuse it.
+						system: [{ text: systemPrompt, type: "text", cache_control: cacheControl }],
 						messages: messages.map((message, index) => {
 							if (index === lastUserMsgIndex || index === secondLastMsgUserIndex) {
 								return {
 									...message,
 									content:
 										typeof message.content === "string"
-											? [
-													{
-														type: "text",
-														text: message.content,
-														cache_control: { type: "ephemeral" },
-													},
-												]
+											? [{ type: "text", text: message.content, cache_control: cacheControl }]
 											: message.content.map((content, contentIndex) =>
 													contentIndex === message.content.length - 1
-														? { ...content, cache_control: { type: "ephemeral" } }
+														? { ...content, cache_control: cacheControl }
 														: content,
 												),
 								}
@@ -114,54 +129,63 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 		for await (const chunk of stream) {
 			switch (chunk.type) {
 				case "message_start":
-					// tells us cache reads/writes/input/output
+					// Tells us cache reads/writes/input/output.
 					const usage = chunk.message.usage
+
 					yield {
 						type: "usage",
 						inputTokens: usage.input_tokens || 0,
 						outputTokens: usage.output_tokens || 0,
 						cacheWriteTokens: usage.cache_creation_input_tokens || undefined,
 						cacheReadTokens: usage.cache_read_input_tokens || undefined,
 					}
+
 					break
 				case "message_delta":
-					// tells us stop_reason, stop_sequence, and output tokens along the way and at the end of the message
-
+					// Tells us stop_reason, stop_sequence, and output tokens
+					// along the way and at the end of the message.
 					yield {
 						type: "usage",
 						inputTokens: 0,
 						outputTokens: chunk.usage.output_tokens || 0,
 					}
+
 					break
 				case "message_stop":
-					// no usage data, just an indicator that the message is done
+					// No usage data, just an indicator that the message is done.
 					break
 				case "content_block_start":
 					switch (chunk.content_block.type) {
-						case "text":
-							// we may receive multiple text blocks, in which case just insert a line break between them
+						case "thinking":
+							// We may receive multiple text blocks, in which
+							// case just insert a line break between them.
 							if (chunk.index > 0) {
-								yield {
-									type: "text",
-									text: "\n",
-								}
+								yield { type: "reasoning", text: "\n" }
 							}
-							yield {
-								type: "text",
-								text: chunk.content_block.text,
+
+							yield { type: "reasoning", text: chunk.content_block.thinking }
+							break
+						case "text":
+							// We may receive multiple text blocks, in which
+							// case just insert a line break between them.
+							if (chunk.index > 0) {
+								yield { type: "text", text: "\n" }
 							}
+
+							yield { type: "text", text: chunk.content_block.text }
 							break
 					}
 					break
 				case "content_block_delta":
 					switch (chunk.delta.type) {
+						case "thinking_delta":
+							yield { type: "reasoning", text: chunk.delta.thinking }
+							break
 						case "text_delta":
-							yield {
-								type: "text",
-								text: chunk.delta.text,
-							}
+							yield { type: "text", text: chunk.delta.text }
 							break
 					}
+
 					break
 				case "content_block_stop":
 					break
@@ -171,10 +195,12 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 
 	getModel(): { id: AnthropicModelId; info: ModelInfo } {
 		const modelId = this.options.apiModelId
+
 		if (modelId && modelId in anthropicModels) {
 			const id = modelId as AnthropicModelId
 			return { id, info: anthropicModels[id] }
 		}
+
 		return { id: anthropicDefaultModelId, info: anthropicModels[anthropicDefaultModelId] }
 	}
 
@@ -189,14 +215,17 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 			})
 
 			const content = response.content[0]
+
 			if (content.type === "text") {
 				return content.text
 			}
+
 			return ""
 		} catch (error) {
 			if (error instanceof Error) {
 				throw new Error(`Anthropic completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}
```
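
To make the new control flow easier to follow, here is a small standalone sketch (the helper name is hypothetical, not part of the PR) of the thinking/temperature selection added to createMessage above: only models listed in THINKING_MODELS get a thinking block, a configured anthropicThinking budget enables extended thinking, and thinking-capable models are pinned to temperature 1.0.

```ts
import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"

const THINKING_MODELS = ["claude-3-7-sonnet-20250219"]
const ANTHROPIC_DEFAULT_TEMPERATURE = 0

// Hypothetical helper isolating the logic added to createMessage();
// budgetTokens stands in for options.anthropicThinking.
function resolveThinkingConfig(
	modelId: string,
	budgetTokens?: number,
	modelTemperature?: number,
): { thinking: BetaThinkingConfigParam | undefined; temperature: number } {
	if (!THINKING_MODELS.includes(modelId)) {
		// Non-thinking models: no thinking block, caller-provided or default temperature.
		return { thinking: undefined, temperature: modelTemperature ?? ANTHROPIC_DEFAULT_TEMPERATURE }
	}

	// Thinking-capable models: enable extended thinking only when a budget is
	// configured, and use temperature 1.0 either way, as in the handler above.
	return {
		thinking: budgetTokens ? { type: "enabled", budget_tokens: budgetTokens } : { type: "disabled" },
		temperature: 1.0,
	}
}

// Example: resolveThinkingConfig("claude-3-7-sonnet-20250219", 8192)
// -> { thinking: { type: "enabled", budget_tokens: 8192 }, temperature: 1 }
```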

src/api/providers/bedrock.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, BedrockModelId, ModelInfo, bedrockDefaultModelId, bedrockModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
-import { convertToBedrockConverseMessages, convertToAnthropicMessage } from "../transform/bedrock-converse-format"
+import { convertToBedrockConverseMessages } from "../transform/bedrock-converse-format"
 
 const BEDROCK_DEFAULT_TEMPERATURE = 0.3
 
```