diff --git a/src/api/providers/__tests__/chutes.spec.ts b/src/api/providers/__tests__/chutes.spec.ts
index c67515cb7f..e8b3e53688 100644
--- a/src/api/providers/__tests__/chutes.spec.ts
+++ b/src/api/providers/__tests__/chutes.spec.ts
@@ -1,33 +1,64 @@
// npx vitest run api/providers/__tests__/chutes.spec.ts
-import { vitest, describe, it, expect, beforeEach } from "vitest"
-import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
+import OpenAI from "openai"
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { type ChutesModelId, chutesDefaultModelId, chutesModels, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"
import { ChutesHandler } from "../chutes"
-const mockCreate = vitest.fn()
+// Shared mock for chat.completions.create
+const mockCreate = vi.fn()
-vitest.mock("openai", () => {
- return {
- default: vitest.fn().mockImplementation(() => ({
- chat: {
- completions: {
- create: mockCreate,
- },
+// Mock OpenAI module
+vi.mock("openai", () => ({
+ default: vi.fn(() => ({
+ chat: {
+ completions: {
+ create: mockCreate,
},
- })),
- }
-})
+ },
+ })),
+}))
describe("ChutesHandler", () => {
let handler: ChutesHandler
beforeEach(() => {
- vitest.clearAllMocks()
- handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
+ vi.clearAllMocks()
+		// Default mock: one content chunk followed by a usage-only chunk
+ mockCreate.mockImplementation(async () => ({
+ [Symbol.asyncIterator]: async function* () {
+ yield {
+ choices: [
+ {
+ delta: { content: "Test response" },
+ index: 0,
+ },
+ ],
+ usage: null,
+ }
+ yield {
+ choices: [
+ {
+ delta: {},
+ index: 0,
+ },
+ ],
+ usage: {
+ prompt_tokens: 10,
+ completion_tokens: 5,
+ total_tokens: 15,
+ },
+ }
+ },
+ }))
+ handler = new ChutesHandler({ chutesApiKey: "test-key" })
+ })
+
+ afterEach(() => {
+ vi.restoreAllMocks()
})
it("should use the correct Chutes base URL", () => {
@@ -41,18 +72,96 @@ describe("ChutesHandler", () => {
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
})
+ it("should handle DeepSeek R1 reasoning format", async () => {
+ // Override the mock for this specific test
+ mockCreate.mockImplementationOnce(async () => ({
+ [Symbol.asyncIterator]: async function* () {
+ yield {
+ choices: [
+ {
+ delta: { content: "Thinking..." },
+ index: 0,
+ },
+ ],
+ usage: null,
+ }
+ yield {
+ choices: [
+ {
+ delta: { content: "Hello" },
+ index: 0,
+ },
+ ],
+ usage: null,
+ }
+ yield {
+ choices: [
+ {
+ delta: {},
+ index: 0,
+ },
+ ],
+ usage: { prompt_tokens: 10, completion_tokens: 5 },
+ }
+ },
+ }))
+
+ const systemPrompt = "You are a helpful assistant."
+ const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+ vi.spyOn(handler, "getModel").mockReturnValue({
+ id: "deepseek-ai/DeepSeek-R1-0528",
+ info: { maxTokens: 1024, temperature: 0.7 },
+ } as any)
+
+ const stream = handler.createMessage(systemPrompt, messages)
+ const chunks = []
+ for await (const chunk of stream) {
+ chunks.push(chunk)
+ }
+
+ expect(chunks).toEqual([
+ { type: "reasoning", text: "Thinking..." },
+ { type: "text", text: "Hello" },
+ { type: "usage", inputTokens: 10, outputTokens: 5 },
+ ])
+ })
+
+ it("should fall back to base provider for non-DeepSeek models", async () => {
+ // Use default mock implementation which returns text content
+ const systemPrompt = "You are a helpful assistant."
+ const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+ vi.spyOn(handler, "getModel").mockReturnValue({
+ id: "some-other-model",
+ info: { maxTokens: 1024, temperature: 0.7 },
+ } as any)
+
+ const stream = handler.createMessage(systemPrompt, messages)
+ const chunks = []
+ for await (const chunk of stream) {
+ chunks.push(chunk)
+ }
+
+ expect(chunks).toEqual([
+ { type: "text", text: "Test response" },
+ { type: "usage", inputTokens: 10, outputTokens: 5 },
+ ])
+ })
+
it("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(chutesDefaultModelId)
- expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
+ expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
})
it("should return specified model when valid model is provided", () => {
const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
- const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
+ const handlerWithModel = new ChutesHandler({
+ apiModelId: testModelId,
+ chutesApiKey: "test-chutes-api-key",
+ })
const model = handlerWithModel.getModel()
expect(model.id).toBe(testModelId)
- expect(model.info).toEqual(chutesModels[testModelId])
+ expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
})
it("completePrompt method should return text from Chutes API", async () => {
@@ -74,7 +183,7 @@ describe("ChutesHandler", () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
- next: vitest
+ next: vi
.fn()
.mockResolvedValueOnce({
done: false,
@@ -96,7 +205,7 @@ describe("ChutesHandler", () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
- next: vitest
+ next: vi
.fn()
.mockResolvedValueOnce({
done: false,
@@ -114,8 +223,43 @@ describe("ChutesHandler", () => {
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
})
- it("createMessage should pass correct parameters to Chutes client", async () => {
+ it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+
+ // Clear previous mocks and set up new implementation
+ mockCreate.mockClear()
+ mockCreate.mockImplementationOnce(async () => ({
+ [Symbol.asyncIterator]: async function* () {
+ // Empty stream for this test
+ },
+ }))
+
+ const handlerWithModel = new ChutesHandler({
+ apiModelId: modelId,
+ chutesApiKey: "test-chutes-api-key",
+ })
+
+ const systemPrompt = "Test system prompt for Chutes"
+ const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+ const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+ await messageGenerator.next()
+
+ expect(mockCreate).toHaveBeenCalledWith(
+ expect.objectContaining({
+ model: modelId,
+ messages: [
+ {
+ role: "user",
+ content: `${systemPrompt}\n${messages[0].content}`,
+ },
+ ],
+ }),
+ )
+ })
+
+ it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
+ const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
const modelInfo = chutesModels[modelId]
const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })
@@ -146,4 +290,24 @@ describe("ChutesHandler", () => {
}),
)
})
+
+ it("should apply DeepSeek default temperature for R1 models", () => {
+ const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+ const handlerWithModel = new ChutesHandler({
+ apiModelId: testModelId,
+ chutesApiKey: "test-chutes-api-key",
+ })
+ const model = handlerWithModel.getModel()
+ expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE)
+ })
+
+ it("should use default temperature for non-DeepSeek models", () => {
+ const testModelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
+ const handlerWithModel = new ChutesHandler({
+ apiModelId: testModelId,
+ chutesApiKey: "test-chutes-api-key",
+ })
+ const model = handlerWithModel.getModel()
+ expect(model.info.temperature).toBe(0.5)
+ })
})
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index bf1f3c35a8..f196b5f309 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider
protected readonly options: ApiHandlerOptions
- private client: OpenAI
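+	// Exposed to subclasses so providers such as ChutesHandler can issue their own requests through the shared client.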
+ protected client: OpenAI
constructor({
providerName,
diff --git a/src/api/providers/chutes.ts b/src/api/providers/chutes.ts
index 0fa8741fa3..62121bd19d 100644
--- a/src/api/providers/chutes.ts
+++ b/src/api/providers/chutes.ts
@@ -1,6 +1,12 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
import type { ApiHandlerOptions } from "../../shared/api"
+import { XmlMatcher } from "../../utils/xml-matcher"
+import { convertToR1Format } from "../transform/r1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
@@ -16,4 +22,82 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider {
defaultTemperature: 0.5,
})
}
+
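+	// Builds the streaming request body shared by both paths; stream_options.include_usage asks the API to report token usage on the final chunk.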
+ private getCompletionParams(
+ systemPrompt: string,
+ messages: Anthropic.Messages.MessageParam[],
+ ): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
+		const {
+			id: model,
+			info: { maxTokens: max_tokens, temperature: defaultTemperature },
+		} = this.getModel()
+
+		const temperature = this.options.modelTemperature ?? defaultTemperature
+
+ return {
+ model,
+ max_tokens,
+ temperature,
+ messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+ stream: true,
+ stream_options: { include_usage: true },
+ }
+ }
+
+ override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+ const model = this.getModel()
+
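+		// DeepSeek R1 models need special handling: R1-style message formatting and <think>-tag reasoning extraction.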
+ if (model.id.includes("DeepSeek-R1")) {
+ const stream = await this.client.chat.completions.create({
+ ...this.getCompletionParams(systemPrompt, messages),
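+				// Override the default messages: convertToR1Format folds the system prompt into the first user turn.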
+ messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+ })
+
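+			// Route <think> content to "reasoning" chunks and everything else to plain "text" chunks.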
+ const matcher = new XmlMatcher(
+ "think",
+ (chunk) =>
+ ({
+ type: chunk.matched ? "reasoning" : "text",
+ text: chunk.data,
+ }) as const,
+ )
+
+ for await (const chunk of stream) {
+ const delta = chunk.choices[0]?.delta
+
+ if (delta?.content) {
+ for (const processedChunk of matcher.update(delta.content)) {
+ yield processedChunk
+ }
+ }
+
+ if (chunk.usage) {
+ yield {
+ type: "usage",
+ inputTokens: chunk.usage.prompt_tokens || 0,
+ outputTokens: chunk.usage.completion_tokens || 0,
+ }
+ }
+ }
+
+ // Process any remaining content
+ for (const processedChunk of matcher.final()) {
+ yield processedChunk
+ }
+ } else {
+ yield* super.createMessage(systemPrompt, messages)
+ }
+ }
+
+ override getModel() {
+ const model = super.getModel()
+ const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
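+		// R1 models get the DeepSeek reasoning default temperature; other models keep the provider default (0.5).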
+ return {
+ ...model,
+ info: {
+ ...model.info,
+ temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
+ },
+ }
+ }
}