Commit 358a19d
hannesrudolph authored and daniel-lxs committed
feat: Add DeepSeek R1 support to Chutes provider (#4523) (#4525)

* feat: Add DeepSeek R1 support to Chutes provider (#4523)
  - Modified BaseOpenAiCompatibleProvider to expose client as protected
  - Enhanced ChutesHandler to detect DeepSeek R1 models and parse reasoning chunks
  - Applied R1 format conversion for message formatting
  - Set appropriate temperature (0.6) for DeepSeek models
  - Migrated tests from Jest to Vitest format
  - Added comprehensive tests for DeepSeek R1 functionality

  This ensures reasoning chunks are properly separated from regular content when using DeepSeek R1 models via Chutes provider.

* feat: Enhance DeepSeek R1 support with <think> tag handling in Chutes provider
* fix: Correct temperature retrieval in ChutesHandler to use model's info
* fix: Update condition for DeepSeek-R1 model identification in createMessage method

---------

Co-authored-by: Daniel Riccio <[email protected]>
1 parent c4b97c4 commit 358a19d

File tree

3 files changed: +272 -24 lines changed
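
Before the diffs, the core idea in miniature: DeepSeek R1 models emit their chain of thought inside <think>…</think> tags, interleaved with the final answer, and the handler has to split that out of the stream as separate "reasoning" chunks. Below is a minimal, self-contained TypeScript sketch of that splitting — not the repo's XmlMatcher utility, and simplified by assuming each tag arrives whole within a single streamed delta:

    type Chunk = { type: "reasoning" | "text"; text: string }

    // Walk one streamed delta, toggling between "inside <think>" and "outside",
    // yielding enclosed text as reasoning and everything else as plain text.
    function* splitThink(delta: string, state: { inThink: boolean }): Generator<Chunk> {
        let rest = delta
        while (rest.length > 0) {
            if (!state.inThink) {
                const open = rest.indexOf("<think>")
                if (open === -1) {
                    yield { type: "text", text: rest }
                    return
                }
                if (open > 0) yield { type: "text", text: rest.slice(0, open) }
                state.inThink = true
                rest = rest.slice(open + "<think>".length)
            } else {
                const close = rest.indexOf("</think>")
                if (close === -1) {
                    yield { type: "reasoning", text: rest }
                    return
                }
                if (close > 0) yield { type: "reasoning", text: rest.slice(0, close) }
                state.inThink = false
                rest = rest.slice(close + "</think>".length)
            }
        }
    }

    // Fed the deltas from the test below, this yields
    // { type: "reasoning", text: "Thinking..." } then { type: "text", text: "Hello" }.
    const state = { inThink: false }
    for (const delta of ["<think>Thinking...", "</think>Hello"]) {
        for (const chunk of splitThink(delta, state)) {
            console.log(chunk)
        }
    }

The repo's XmlMatcher also copes with tags split across deltas; that is why the handler drains matcher.final() after the stream ends, to flush anything still buffered.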

src/api/providers/__tests__/chutes.spec.ts

Lines changed: 186 additions & 22 deletions
@@ -1,33 +1,64 @@
 // npx vitest run api/providers/__tests__/chutes.spec.ts
 
-import { vitest, describe, it, expect, beforeEach } from "vitest"
-import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
+import OpenAI from "openai"
 
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { type ChutesModelId, chutesDefaultModelId, chutesModels, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"
 
 import { ChutesHandler } from "../chutes"
 
-const mockCreate = vitest.fn()
+// Create mock functions
+const mockCreate = vi.fn()
 
-vitest.mock("openai", () => {
-    return {
-        default: vitest.fn().mockImplementation(() => ({
-            chat: {
-                completions: {
-                    create: mockCreate,
-                },
+// Mock OpenAI module
+vi.mock("openai", () => ({
+    default: vi.fn(() => ({
+        chat: {
+            completions: {
+                create: mockCreate,
             },
-        })),
-    }
-})
+        },
+    })),
+}))
 
 describe("ChutesHandler", () => {
     let handler: ChutesHandler
 
     beforeEach(() => {
-        vitest.clearAllMocks()
-        handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
+        vi.clearAllMocks()
+        // Set up default mock implementation
+        mockCreate.mockImplementation(async () => ({
+            [Symbol.asyncIterator]: async function* () {
+                yield {
+                    choices: [
+                        {
+                            delta: { content: "Test response" },
+                            index: 0,
+                        },
+                    ],
+                    usage: null,
+                }
+                yield {
+                    choices: [
+                        {
+                            delta: {},
+                            index: 0,
+                        },
+                    ],
+                    usage: {
+                        prompt_tokens: 10,
+                        completion_tokens: 5,
+                        total_tokens: 15,
+                    },
+                }
+            },
+        }))
+        handler = new ChutesHandler({ chutesApiKey: "test-key" })
+    })
+
+    afterEach(() => {
+        vi.restoreAllMocks()
     })
 
     it("should use the correct Chutes base URL", () => {

@@ -41,18 +72,96 @@ describe("ChutesHandler", () => {
         expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
     })
 
+    it("should handle DeepSeek R1 reasoning format", async () => {
+        // Override the mock for this specific test
+        mockCreate.mockImplementationOnce(async () => ({
+            [Symbol.asyncIterator]: async function* () {
+                yield {
+                    choices: [
+                        {
+                            delta: { content: "<think>Thinking..." },
+                            index: 0,
+                        },
+                    ],
+                    usage: null,
+                }
+                yield {
+                    choices: [
+                        {
+                            delta: { content: "</think>Hello" },
+                            index: 0,
+                        },
+                    ],
+                    usage: null,
+                }
+                yield {
+                    choices: [
+                        {
+                            delta: {},
+                            index: 0,
+                        },
+                    ],
+                    usage: { prompt_tokens: 10, completion_tokens: 5 },
+                }
+            },
+        }))
+
+        const systemPrompt = "You are a helpful assistant."
+        const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+        vi.spyOn(handler, "getModel").mockReturnValue({
+            id: "deepseek-ai/DeepSeek-R1-0528",
+            info: { maxTokens: 1024, temperature: 0.7 },
+        } as any)
+
+        const stream = handler.createMessage(systemPrompt, messages)
+        const chunks = []
+        for await (const chunk of stream) {
+            chunks.push(chunk)
+        }
+
+        expect(chunks).toEqual([
+            { type: "reasoning", text: "Thinking..." },
+            { type: "text", text: "Hello" },
+            { type: "usage", inputTokens: 10, outputTokens: 5 },
+        ])
+    })
+
+    it("should fall back to base provider for non-DeepSeek models", async () => {
+        // Use default mock implementation which returns text content
+        const systemPrompt = "You are a helpful assistant."
+        const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+        vi.spyOn(handler, "getModel").mockReturnValue({
+            id: "some-other-model",
+            info: { maxTokens: 1024, temperature: 0.7 },
+        } as any)
+
+        const stream = handler.createMessage(systemPrompt, messages)
+        const chunks = []
+        for await (const chunk of stream) {
+            chunks.push(chunk)
+        }
+
+        expect(chunks).toEqual([
+            { type: "text", text: "Test response" },
+            { type: "usage", inputTokens: 10, outputTokens: 5 },
+        ])
+    })
+
     it("should return default model when no model is specified", () => {
         const model = handler.getModel()
         expect(model.id).toBe(chutesDefaultModelId)
-        expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
+        expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
     })
 
     it("should return specified model when valid model is provided", () => {
         const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
-        const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
+        const handlerWithModel = new ChutesHandler({
+            apiModelId: testModelId,
+            chutesApiKey: "test-chutes-api-key",
+        })
         const model = handlerWithModel.getModel()
         expect(model.id).toBe(testModelId)
-        expect(model.info).toEqual(chutesModels[testModelId])
+        expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
     })
 
     it("completePrompt method should return text from Chutes API", async () => {

@@ -74,7 +183,7 @@ describe("ChutesHandler", () => {
         mockCreate.mockImplementationOnce(() => {
             return {
                 [Symbol.asyncIterator]: () => ({
-                    next: vitest
+                    next: vi
                         .fn()
                         .mockResolvedValueOnce({
                             done: false,

@@ -96,7 +205,7 @@ describe("ChutesHandler", () => {
         mockCreate.mockImplementationOnce(() => {
             return {
                 [Symbol.asyncIterator]: () => ({
-                    next: vitest
+                    next: vi
                         .fn()
                         .mockResolvedValueOnce({
                             done: false,

@@ -114,8 +223,43 @@ describe("ChutesHandler", () => {
         expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
     })
 
-    it("createMessage should pass correct parameters to Chutes client", async () => {
+    it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
         const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+
+        // Clear previous mocks and set up new implementation
+        mockCreate.mockClear()
+        mockCreate.mockImplementationOnce(async () => ({
+            [Symbol.asyncIterator]: async function* () {
+                // Empty stream for this test
+            },
+        }))
+
+        const handlerWithModel = new ChutesHandler({
+            apiModelId: modelId,
+            chutesApiKey: "test-chutes-api-key",
+        })
+
+        const systemPrompt = "Test system prompt for Chutes"
+        const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+        const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+        await messageGenerator.next()
+
+        expect(mockCreate).toHaveBeenCalledWith(
+            expect.objectContaining({
+                model: modelId,
+                messages: [
+                    {
+                        role: "user",
+                        content: `${systemPrompt}\n${messages[0].content}`,
+                    },
+                ],
+            }),
+        )
+    })
+
+    it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
+        const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
         const modelInfo = chutesModels[modelId]
         const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })
 

@@ -146,4 +290,24 @@ describe("ChutesHandler", () => {
             }),
         )
     })
+
+    it("should apply DeepSeek default temperature for R1 models", () => {
+        const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+        const handlerWithModel = new ChutesHandler({
+            apiModelId: testModelId,
+            chutesApiKey: "test-chutes-api-key",
+        })
+        const model = handlerWithModel.getModel()
+        expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE)
+    })
+
+    it("should use default temperature for non-DeepSeek models", () => {
+        const testModelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
+        const handlerWithModel = new ChutesHandler({
+            apiModelId: testModelId,
+            chutesApiKey: "test-chutes-api-key",
+        })
+        const model = handlerWithModel.getModel()
+        expect(model.info.temperature).toBe(0.5)
+    })
 })
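
A note on the mocking pattern used throughout this spec: for await...of only requires the async-iterable protocol, so any plain object carrying a [Symbol.asyncIterator] generator can stand in for the stream returned by chat.completions.create. In isolation (a runnable sketch, with the chunk shapes copied from the defaults above):

    // Any object with a [Symbol.asyncIterator] generator satisfies `for await`,
    // so it can impersonate an OpenAI streaming response in tests.
    const fakeStream = {
        [Symbol.asyncIterator]: async function* () {
            yield { choices: [{ delta: { content: "Test response" }, index: 0 }], usage: null }
            yield { choices: [{ delta: {}, index: 0 }], usage: { prompt_tokens: 10, completion_tokens: 5 } }
        },
    }

    async function main() {
        for await (const chunk of fakeStream) {
            // First iteration carries a content delta; the second carries only usage.
            console.log(chunk.choices[0]?.delta?.content, chunk.usage)
        }
    }
    void main()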

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 
     protected readonly options: ApiHandlerOptions
 
-    private client: OpenAI
+    protected client: OpenAI
 
     constructor({
         providerName,
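
The entire change to the base class is the one-word visibility shift above; it exists so ChutesHandler can drive its own streaming request through this.client in the next file. The TypeScript rule it leans on, in a generic sketch unrelated to the repo's actual classes:

    class Base {
        private hidden = 1
        protected shared = 2
    }

    class Sub extends Base {
        peek() {
            // return this.hidden  // compile error: private members are invisible to subclasses
            return this.shared // fine: protected members are part of the inherited API
        }
    }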

src/api/providers/chutes.ts

Lines changed: 85 additions & 1 deletion
@@ -1,6 +1,12 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
+import { XmlMatcher } from "../../utils/xml-matcher"
+import { convertToR1Format } from "../transform/r1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 

@@ -16,4 +22,82 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
             defaultTemperature: 0.5,
         })
     }
+
+    private getCompletionParams(
+        systemPrompt: string,
+        messages: Anthropic.Messages.MessageParam[],
+    ): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
+        const {
+            id: model,
+            info: { maxTokens: max_tokens },
+        } = this.getModel()
+
+        const temperature = this.options.modelTemperature ?? this.getModel().info.temperature
+
+        return {
+            model,
+            max_tokens,
+            temperature,
+            messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+            stream: true,
+            stream_options: { include_usage: true },
+        }
+    }
+
+    override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+        const model = this.getModel()
+
+        if (model.id.includes("DeepSeek-R1")) {
+            const stream = await this.client.chat.completions.create({
+                ...this.getCompletionParams(systemPrompt, messages),
+                messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+            })
+
+            const matcher = new XmlMatcher(
+                "think",
+                (chunk) =>
+                    ({
+                        type: chunk.matched ? "reasoning" : "text",
+                        text: chunk.data,
+                    }) as const,
+            )
+
+            for await (const chunk of stream) {
+                const delta = chunk.choices[0]?.delta
+
+                if (delta?.content) {
+                    for (const processedChunk of matcher.update(delta.content)) {
+                        yield processedChunk
+                    }
+                }
+
+                if (chunk.usage) {
+                    yield {
+                        type: "usage",
+                        inputTokens: chunk.usage.prompt_tokens || 0,
+                        outputTokens: chunk.usage.completion_tokens || 0,
+                    }
+                }
+            }
+
+            // Process any remaining content
+            for (const processedChunk of matcher.final()) {
+                yield processedChunk
+            }
+        } else {
+            yield* super.createMessage(systemPrompt, messages)
+        }
+    }
+
+    override getModel() {
+        const model = super.getModel()
+        const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
+        return {
+            ...model,
+            info: {
+                ...model.info,
+                temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
+            },
+        }
+    }
 }
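
Putting the pieces together, a hedged usage sketch of the new behavior. The import path, option names, and chunk shapes are inferred from the tests above; treat this as illustration, not a documented public API:

    import { ChutesHandler } from "./src/api/providers/chutes" // path assumed, relative to the repo root

    const handler = new ChutesHandler({
        apiModelId: "deepseek-ai/DeepSeek-R1",
        chutesApiKey: process.env.CHUTES_API_KEY!,
    })

    async function run() {
        for await (const chunk of handler.createMessage("You are a helpful assistant.", [
            { role: "user", content: "Hi" },
        ])) {
            if (chunk.type === "reasoning") {
                process.stdout.write(`[thinking] ${chunk.text}`) // content from inside <think>…</think>
            } else if (chunk.type === "text") {
                process.stdout.write(chunk.text) // the actual answer
            } else if (chunk.type === "usage") {
                console.log(`\n(${chunk.inputTokens} in / ${chunk.outputTokens} out)`)
            }
        }
    }
    void run()

Note that on the R1 path the system prompt is folded into the first user message via convertToR1Format, which is exactly what the parameter test above asserts.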
