Commit 7f173f2

feat: Add Chutes DeepSeek support (fixes #4506)
1 parent: c03869b

File tree

4 files changed: +217 −143 lines

4 files changed

+217
-143
lines changed
Lines changed: 144 additions & 0 deletions (new file)
@@ -0,0 +1,144 @@

import { Anthropic } from "@anthropic-ai/sdk"
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
import OpenAI from "openai"

import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"

import { ChutesHandler } from "../chutes"
import * as chutesModule from "../chutes"

// Mock the entire module
vi.mock("../chutes", async () => {
	const actual = await vi.importActual<typeof chutesModule>("../chutes")
	return {
		...actual,
		ChutesHandler: class extends actual.ChutesHandler {
			constructor(options: any) {
				super(options)
				this.client = {
					chat: {
						completions: {
							create: vi.fn(),
						},
					},
				} as any
			}
		},
	}
})

describe("ChutesHandler", () => {
	let handler: ChutesHandler
	let mockCreate: any

	beforeEach(() => {
		handler = new ChutesHandler({ chutesApiKey: "test-key" })
		mockCreate = vi.spyOn((handler as any).client.chat.completions, "create")
	})

	afterEach(() => {
		vi.restoreAllMocks()
	})

	it("should handle DeepSeek R1 reasoning format", async () => {
		const mockStream = (async function* () {
			yield { choices: [{ delta: { reasoning: "Thinking..." } }] }
			yield { choices: [{ delta: { content: "Hello" } }] }
			yield { usage: { prompt_tokens: 10, completion_tokens: 5 } }
		})()

		mockCreate.mockResolvedValue(mockStream)

		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
		vi.spyOn(handler, "getModel").mockReturnValue({
			id: "deepseek-ai/DeepSeek-R1-0528",
			info: { maxTokens: 1024, temperature: 0.7 },
		} as any)

		const stream = handler.createMessage(systemPrompt, messages)
		const chunks = []
		for await (const chunk of stream) {
			chunks.push(chunk)
		}

		expect(chunks).toEqual([
			{ type: "reasoning", text: "Thinking..." },
			{ type: "text", text: "Hello" },
			{ type: "usage", inputTokens: 10, outputTokens: 5 },
		])
	})

	it("should fall back to base provider for non-DeepSeek models", async () => {
		const mockStream = (async function* () {
			yield { choices: [{ delta: { content: "Hello" } }] }
			yield { usage: { prompt_tokens: 10, completion_tokens: 5 } }
		})()

		mockCreate.mockResolvedValue(mockStream)

		const systemPrompt = "You are a helpful assistant."
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
		vi.spyOn(handler, "getModel").mockReturnValue({
			id: "some-other-model",
			info: { maxTokens: 1024, temperature: 0.7 },
		} as any)

		const stream = handler.createMessage(systemPrompt, messages)
		const chunks = []
		for await (const chunk of stream) {
			chunks.push(chunk)
		}

		expect(chunks).toEqual([
			{ type: "text", text: "Hello" },
			{ type: "usage", inputTokens: 10, outputTokens: 5 },
		])
	})

	it("should return default model when no model is specified", () => {
		const model = handler.getModel()
		expect(model.id).toBe(chutesDefaultModelId)
		expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
	})

	it("should return specified model when valid model is provided", () => {
		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
		const handlerWithModel = new ChutesHandler({
			apiModelId: testModelId,
			chutesApiKey: "test-chutes-api-key",
		})
		const model = handlerWithModel.getModel()
		expect(model.id).toBe(testModelId)
		expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
	})

	it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
		const handlerWithModel = new ChutesHandler({
			apiModelId: modelId,
			chutesApiKey: "test-chutes-api-key",
		})

		const mockStream = (async function* () {})()
		mockCreate.mockResolvedValue(mockStream)

		const systemPrompt = "Test system prompt for Chutes"
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]

		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
		await messageGenerator.next()

		expect(mockCreate).toHaveBeenCalledWith(
			expect.objectContaining({
				model: modelId,
				messages: [
					{
						role: "user",
						content: `${systemPrompt}\n${messages[0].content}`,
					},
				],
			}),
		)
	})
})
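The final assertion above encodes the R1 quirk this commit works around: the conversation is sent with the system prompt folded into the first user message rather than as a separate system role. A minimal sketch of the merging behavior the test asserts, assuming adjacent same-role messages are joined with a newline (the production helper is convertToR1Format in ../transform/r1-format; everything else here is illustrative):

// Sketch only — mirrors the behavior asserted in the last test above,
// not the actual implementation in src/api/transform/r1-format.ts.
type Turn = { role: "user" | "assistant"; content: string }

function mergeConsecutiveTurns(turns: Turn[]): Turn[] {
	const merged: Turn[] = []
	for (const turn of turns) {
		const last = merged[merged.length - 1]
		if (last && last.role === turn.role) {
			// Adjacent same-role messages are joined with a newline, which is
			// why the test expects `${systemPrompt}\n${messages[0].content}`.
			last.content += `\n${turn.content}`
		} else {
			merged.push({ ...turn })
		}
	}
	return merged
}

// mergeConsecutiveTurns([
// 	{ role: "user", content: "You are a helpful assistant." },
// 	{ role: "user", content: "Hi" },
// ]) → [{ role: "user", content: "You are a helpful assistant.\nHi" }]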

src/api/providers/__tests__/chutes.test.ts

Lines changed: 0 additions & 141 deletions
This file was deleted.

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>

 	protected readonly options: ApiHandlerOptions

-	private client: OpenAI
+	protected client: OpenAI

 	constructor({
 		providerName,

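This one-word visibility change is what the new Chutes code depends on: with client now protected, a subclass can issue its own request through the shared OpenAI client instead of routing everything through the base createMessage. A condensed sketch of the pattern (constructor arguments simplified; the real base class takes a richer options object):

import OpenAI from "openai"

abstract class BaseProvider {
	// Was `private`; `protected` lets subclasses reach the client directly.
	protected client: OpenAI

	constructor(apiKey: string, baseURL: string) {
		this.client = new OpenAI({ apiKey, baseURL })
	}
}

class CustomProvider extends BaseProvider {
	async *stream(model: string, prompt: string) {
		// The subclass can now craft a provider-specific streaming request itself.
		const stream = await this.client.chat.completions.create({
			model,
			messages: [{ role: "user", content: prompt }],
			stream: true,
		})
		for await (const chunk of stream) {
			const text = chunk.choices[0]?.delta?.content
			if (text) yield text
		}
	}
}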
src/api/providers/chutes.ts

Lines changed: 72 additions & 1 deletion
@@ -1,6 +1,11 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"

 import type { ApiHandlerOptions } from "../../shared/api"
+import { convertToR1Format } from "../transform/r1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"

 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
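The three ../transform imports carry the weight of the change: convertToR1Format rewrites the conversation for R1 models, convertToOpenAiMessages handles the standard path, and ApiStream types the generator that createMessage returns. Judging from the yield statements in the next hunk, its chunks look roughly like this (a sketch inferred from this diff; the authoritative definition lives in ../transform/stream):

// Shape inferred from the yields in createMessage below — not the exact
// declaration from src/api/transform/stream.ts.
type ApiStreamChunk =
	| { type: "reasoning"; text: string }
	| { type: "text"; text: string }
	| { type: "usage"; inputTokens: number; outputTokens: number }

type ApiStream = AsyncGenerator<ApiStreamChunk>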
@@ -16,4 +21,70 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
 			defaultTemperature: 0.5,
 		})
 	}
+
+	private getCompletionParams(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
+		const {
+			id: model,
+			info: { maxTokens: max_tokens },
+		} = this.getModel()
+
+		const temperature = this.options.modelTemperature ?? this.defaultTemperature
+
+		return {
+			model,
+			max_tokens,
+			temperature,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+	}
+
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.getModel()
+
+		if (model.id.startsWith("deepseek-ai/DeepSeek-R1")) {
+			const stream = await this.client.chat.completions.create({
+				...this.getCompletionParams(systemPrompt, messages),
+				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+			})
+
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+
+				if ("reasoning" in delta && delta.reasoning && typeof delta.reasoning === "string") {
+					yield { type: "reasoning", text: delta.reasoning }
+				}
+
+				if (delta?.content) {
+					yield { type: "text", text: delta.content }
+				}
+
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
+				}
+			}
+		} else {
+			yield* super.createMessage(systemPrompt, messages)
+		}
+	}
+
+	override getModel() {
+		const model = super.getModel()
+		const isDeepSeekR1 = model.id.startsWith("deepseek-ai/DeepSeek-R1")
+		return {
+			...model,
+			info: {
+				...model.info,
+				temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
+			},
+		}
+	}
 }
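Taken together with the getModel override (which swaps in DEEP_SEEK_DEFAULT_TEMPERATURE for R1 variants), the handler is consumed as a plain async generator. A hedged usage sketch — the option names come from the tests above, while the environment variable and logging are illustrative:

import { ChutesHandler } from "./chutes"

const handler = new ChutesHandler({
	apiModelId: "deepseek-ai/DeepSeek-R1",
	chutesApiKey: process.env.CHUTES_API_KEY ?? "", // illustrative; supply a real key
})

for await (const chunk of handler.createMessage("You are a helpful assistant.", [
	{ role: "user", content: "Hi" },
])) {
	// R1 models interleave reasoning chunks with regular text; usage arrives last.
	if (chunk.type === "reasoning") console.error(`[reasoning] ${chunk.text}`)
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") console.error(`[tokens] in=${chunk.inputTokens} out=${chunk.outputTokens}`)
}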
