diff --git a/src/api/providers/__tests__/chutes.spec.ts b/src/api/providers/__tests__/chutes.spec.ts
index c67515cb7f..e8b3e53688 100644
--- a/src/api/providers/__tests__/chutes.spec.ts
+++ b/src/api/providers/__tests__/chutes.spec.ts
@@ -1,33 +1,64 @@
 // npx vitest run api/providers/__tests__/chutes.spec.ts

-import { vitest, describe, it, expect, beforeEach } from "vitest"
-import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
+import OpenAI from "openai"

-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { type ChutesModelId, chutesDefaultModelId, chutesModels, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"

 import { ChutesHandler } from "../chutes"

-const mockCreate = vitest.fn()
+// Create mock functions
+const mockCreate = vi.fn()

-vitest.mock("openai", () => {
-	return {
-		default: vitest.fn().mockImplementation(() => ({
-			chat: {
-				completions: {
-					create: mockCreate,
-				},
+// Mock OpenAI module
+vi.mock("openai", () => ({
+	default: vi.fn(() => ({
+		chat: {
+			completions: {
+				create: mockCreate,
 			},
-		})),
-	}
-})
+		},
+	})),
+}))

 describe("ChutesHandler", () => {
 	let handler: ChutesHandler

 	beforeEach(() => {
-		vitest.clearAllMocks()
-		handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
+		vi.clearAllMocks()
+		// Set up default mock implementation
+		mockCreate.mockImplementation(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "Test response" },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+							index: 0,
+						},
+					],
+					usage: {
+						prompt_tokens: 10,
+						completion_tokens: 5,
+						total_tokens: 15,
+					},
+				}
+			},
+		}))
+		handler = new ChutesHandler({ chutesApiKey: "test-key" })
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
 	})

 	it("should use the correct Chutes base URL", () => {
@@ -41,18 +72,96 @@ describe("ChutesHandler", () => {
 		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
 	})

+	it("should handle DeepSeek R1 reasoning format", async () => {
+		// Override the mock for this specific test
+		mockCreate.mockImplementationOnce(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				yield {
+					choices: [
+						{
+							delta: { content: "<think>Thinking..." },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: { content: "</think>Hello" },
+							index: 0,
+						},
+					],
+					usage: null,
+				}
+				yield {
+					choices: [
+						{
+							delta: {},
+							index: 0,
+						},
+					],
+					usage: { prompt_tokens: 10, completion_tokens: 5 },
+				}
+			},
+		}))
+
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "deepseek-ai/DeepSeek-R1-0528",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "reasoning", text: "Thinking..." },
+			{ type: "text", text: "Hello" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
+	it("should fall back to base provider for non-DeepSeek models", async () => {
+		// Use default mock implementation which returns text content
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "some-other-model",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "text", text: "Test response" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
 	it("should return default model when no model is specified", () => {
 		const model = handler.getModel()
 		expect(model.id).toBe(chutesDefaultModelId)
-		expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
+		expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
 	})

 	it("should return specified model when valid model is provided", () => {
 		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
-		const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
 		const model = handlerWithModel.getModel()
 		expect(model.id).toBe(testModelId)
-		expect(model.info).toEqual(chutesModels[testModelId])
+		expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
 	})

 	it("completePrompt method should return text from Chutes API", async () => {
@@ -74,7 +183,7 @@ describe("ChutesHandler", () => {
 		mockCreate.mockImplementationOnce(() => {
 			return {
 				[Symbol.asyncIterator]: () => ({
-					next: vitest
+					next: vi
 						.fn()
 						.mockResolvedValueOnce({
 							done: false,
@@ -96,7 +205,7 @@ describe("ChutesHandler", () => {
 		mockCreate.mockImplementationOnce(() => {
 			return {
 				[Symbol.asyncIterator]: () => ({
-					next: vitest
+					next: vi
 						.fn()
 						.mockResolvedValueOnce({
 							done: false,
@@ -114,8 +223,43 @@ describe("ChutesHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})

-	it("createMessage should pass correct parameters to Chutes client", async () => {
+	it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
 		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+
+		// Clear previous mocks and set up new implementation
+		mockCreate.mockClear()
+		mockCreate.mockImplementationOnce(async () => ({
+			[Symbol.asyncIterator]: async function* () {
+				// Empty stream for this test
+			},
+		}))
+
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: modelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+
+		const systemPrompt = "Test system prompt for Chutes"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				messages: [
+					{
+						role: "user",
+						content: `${systemPrompt}\n${messages[0].content}`,
+					},
+				],
+			}),
+		)
+	})
+
+	it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
+		const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
 		const modelInfo = chutesModels[modelId]
 		const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })

@@ -146,4 +290,24 @@ describe("ChutesHandler", () => {
 			}),
 		)
 	})
+
+	it("should apply DeepSeek default temperature for R1 models", () => {
+		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE)
+	})
+
+	it("should use default temperature for non-DeepSeek models", () => {
+		const testModelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.info.temperature).toBe(0.5)
+	})
 })
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index bf1f3c35a8..f196b5f309 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider

 	protected readonly options: ApiHandlerOptions

-	private client: OpenAI
+	protected client: OpenAI

 	constructor({
 		providerName,
diff --git a/src/api/providers/chutes.ts b/src/api/providers/chutes.ts
index 0fa8741fa3..62121bd19d 100644
--- a/src/api/providers/chutes.ts
+++ b/src/api/providers/chutes.ts
@@ -1,6 +1,12 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"

 import type { ApiHandlerOptions } from "../../shared/api"
+import { XmlMatcher } from "../../utils/xml-matcher"
+import { convertToR1Format } from "../transform/r1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"

 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

@@ -16,4 +22,82 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider {
 			defaultTemperature: 0.5,
 		})
 	}
+
+	private getCompletionParams(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
+		const {
+			id: model,
+			info: { maxTokens: max_tokens },
+		} = this.getModel()
+
+		const temperature = this.options.modelTemperature ?? this.getModel().info.temperature
+
+		return {
+			model,
+			max_tokens,
+			temperature,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+	}
+
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.getModel()
+
+		if (model.id.includes("DeepSeek-R1")) {
+			const stream = await this.client.chat.completions.create({
+				...this.getCompletionParams(systemPrompt, messages),
+				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+			})
+
+			const matcher = new XmlMatcher(
+				"think",
+				(chunk) =>
+					({
+						type: chunk.matched ? "reasoning" : "text",
+						text: chunk.data,
+					}) as const,
+			)
+
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+
+				if (delta?.content) {
+					for (const processedChunk of matcher.update(delta.content)) {
+						yield processedChunk
+					}
+				}
+
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
+				}
+			}
+
+			// Process any remaining content
+			for (const processedChunk of matcher.final()) {
+				yield processedChunk
+			}
+		} else {
+			yield* super.createMessage(systemPrompt, messages)
+		}
+	}
+
+	override getModel() {
+		const model = super.getModel()
+		const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
+		return {
+			...model,
+			info: {
+				...model.info,
+				temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
+			},
+		}
+	}
 }