diff --git a/src/api/providers/__tests__/chutes.spec.ts b/src/api/providers/__tests__/chutes.spec.ts
new file mode 100644
index 0000000000..2205f4234f
--- /dev/null
+++ b/src/api/providers/__tests__/chutes.spec.ts
@@ -0,0 +1,144 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
+import OpenAI from "openai"
+
+import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+
+import { ChutesHandler } from "../chutes"
+import * as chutesModule from "../chutes"
+
+// Mock the entire module
+vi.mock("../chutes", async () => {
+	const actual = await vi.importActual<typeof chutesModule>("../chutes")
+	return {
+		...actual,
+		ChutesHandler: class extends actual.ChutesHandler {
+			constructor(options: any) {
+				super(options)
+				this.client = {
+					chat: {
+						completions: {
+							create: vi.fn(),
+						},
+					},
+				} as any
+			}
+		},
+	}
+})
+
+describe("ChutesHandler", () => {
+	let handler: ChutesHandler
+	let mockCreate: any
+
+	beforeEach(() => {
+		handler = new ChutesHandler({ chutesApiKey: "test-key" })
+		mockCreate = vi.spyOn((handler as any).client.chat.completions, "create")
+	})
+
+	afterEach(() => {
+		vi.restoreAllMocks()
+	})
+
+	it("should handle DeepSeek R1 reasoning format", async () => {
+		const mockStream = (async function* () {
+			yield { choices: [{ delta: { reasoning: "Thinking..." } }] }
+			yield { choices: [{ delta: { content: "Hello" } }] }
+			yield { usage: { prompt_tokens: 10, completion_tokens: 5 } }
+		})()
+
+		mockCreate.mockResolvedValue(mockStream)
+
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "deepseek-ai/DeepSeek-R1-0528",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "reasoning", text: "Thinking..." },
+			{ type: "text", text: "Hello" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
+	it("should fall back to base provider for non-DeepSeek models", async () => {
+		const mockStream = (async function* () {
+			yield { choices: [{ delta: { content: "Hello" } }] }
+			yield { usage: { prompt_tokens: 10, completion_tokens: 5 } }
+		})()
+
+		mockCreate.mockResolvedValue(mockStream)
+
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
+		vi.spyOn(handler, "getModel").mockReturnValue({
+			id: "some-other-model",
+			info: { maxTokens: 1024, temperature: 0.7 },
+		} as any)
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		expect(chunks).toEqual([
+			{ type: "text", text: "Hello" },
+			{ type: "usage", inputTokens: 10, outputTokens: 5 },
+		])
+	})
+
+	it("should return default model when no model is specified", () => {
+		const model = handler.getModel()
+		expect(model.id).toBe(chutesDefaultModelId)
+		expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
+	})
+
+	it("should return specified model when valid model is provided", () => {
+		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: testModelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+		const model = handlerWithModel.getModel()
+		expect(model.id).toBe(testModelId)
+		expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
+	})
+
+	it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
+		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: modelId,
+			chutesApiKey: "test-chutes-api-key",
+		})
+
+		const mockStream = (async function* () {})()
+		mockCreate.mockResolvedValue(mockStream)
+
+		const systemPrompt = "Test system prompt for Chutes"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				messages: [
+					{
+						role: "user",
+						content: `${systemPrompt}\n${messages[0].content}`,
+					},
+				],
+			}),
+		)
+	})
+})
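The new Vitest suite subclasses `ChutesHandler` inside `vi.mock` so each test gets a stubbed OpenAI client without mocking the `openai` package itself. The R1 test pins the streaming contract: `createMessage` yields `reasoning`, `text`, and `usage` chunks in stream order. As a minimal consumer sketch (not part of the patch; `StreamChunk` and `splitStream` are hypothetical names, with the chunk shapes assumed from the test expectations above), the two text streams can be separated like so:

	type StreamChunk =
		| { type: "reasoning"; text: string }
		| { type: "text"; text: string }
		| { type: "usage"; inputTokens: number; outputTokens: number }

	// Collect a handler stream into final reasoning/answer strings.
	async function splitStream(stream: AsyncIterable<StreamChunk>) {
		const reasoning: string[] = []
		const text: string[] = []
		for await (const chunk of stream) {
			if (chunk.type === "reasoning") reasoning.push(chunk.text)
			else if (chunk.type === "text") text.push(chunk.text)
		}
		return { reasoning: reasoning.join(""), text: text.join("") }
	}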
diff --git a/src/api/providers/__tests__/chutes.test.ts b/src/api/providers/__tests__/chutes.test.ts
deleted file mode 100644
index 9ee8b8f995..0000000000
--- a/src/api/providers/__tests__/chutes.test.ts
+++ /dev/null
@@ -1,141 +0,0 @@
-// npx jest src/api/providers/__tests__/chutes.test.ts
-
-import OpenAI from "openai"
-import { Anthropic } from "@anthropic-ai/sdk"
-
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
-
-import { ChutesHandler } from "../chutes"
-
-jest.mock("openai", () => {
-	const createMock = jest.fn()
-	return jest.fn(() => ({ chat: { completions: { create: createMock } } }))
-})
-
-describe("ChutesHandler", () => {
-	let handler: ChutesHandler
-	let mockCreate: jest.Mock
-
-	beforeEach(() => {
-		jest.clearAllMocks()
-		mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create
-		handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
-	})
-
-	test("should use the correct Chutes base URL", () => {
-		new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
-		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://llm.chutes.ai/v1" }))
-	})
-
-	test("should use the provided API key", () => {
-		const chutesApiKey = "test-chutes-api-key"
-		new ChutesHandler({ chutesApiKey })
-		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
-	})
-
-	test("should return default model when no model is specified", () => {
-		const model = handler.getModel()
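The deletion below is the Jest half of a Jest-to-Vitest migration; coverage carries over to `chutes.spec.ts` above, which also adds the R1-specific cases. The mechanical mapping is one-to-one across standard APIs (`jest.mock` → `vi.mock`, `jest.fn` → `vi.fn`, `jest.spyOn` → `vi.spyOn`), while the mocking strategy shifts from intercepting the `openai` module to stubbing the handler's own client. A rough sketch of the two styles:

	import { vi } from "vitest"

	// Old (Jest): intercept the whole OpenAI constructor at the module boundary:
	//   jest.mock("openai", () => jest.fn(() => ({ chat: { completions: { create: jest.fn() } } })))
	// New (Vitest): hand the handler a stub client, then spy on it per test:
	const client = { chat: { completions: { create: vi.fn() } } }
	const mockCreate = vi.spyOn(client.chat.completions, "create")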
-		expect(model.id).toBe(chutesDefaultModelId)
-		expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
-	})
-
-	test("should return specified model when valid model is provided", () => {
-		const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
-		const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
-		const model = handlerWithModel.getModel()
-		expect(model.id).toBe(testModelId)
-		expect(model.info).toEqual(chutesModels[testModelId])
-	})
-
-	test("completePrompt method should return text from Chutes API", async () => {
-		const expectedResponse = "This is a test response from Chutes"
-		mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
-		const result = await handler.completePrompt("test prompt")
-		expect(result).toBe(expectedResponse)
-	})
-
-	test("should handle errors in completePrompt", async () => {
-		const errorMessage = "Chutes API error"
-		mockCreate.mockRejectedValueOnce(new Error(errorMessage))
-		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`Chutes completion error: ${errorMessage}`)
-	})
-
-	test("createMessage should yield text content from stream", async () => {
-		const testContent = "This is test content from Chutes stream"
-
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: jest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: { content: testContent } }] },
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
-			}
-		})
-
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
-
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "text", text: testContent })
-	})
-
-	test("createMessage should yield usage data from stream", async () => {
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					next: jest
-						.fn()
-						.mockResolvedValueOnce({
-							done: false,
-							value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } },
-						})
-						.mockResolvedValueOnce({ done: true }),
-				}),
-			}
-		})
-
-		const stream = handler.createMessage("system prompt", [])
-		const firstChunk = await stream.next()
-
-		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
-	})
-
-	test("createMessage should pass correct parameters to Chutes client", async () => {
-		const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
-		const modelInfo = chutesModels[modelId]
-		const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })
-
-		mockCreate.mockImplementationOnce(() => {
-			return {
-				[Symbol.asyncIterator]: () => ({
-					async next() {
-						return { done: true }
-					},
-				}),
-			}
-		})
-
-		const systemPrompt = "Test system prompt for Chutes"
-		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
-
-		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
-		await messageGenerator.next()
-
-		expect(mockCreate).toHaveBeenCalledWith(
-			expect.objectContaining({
-				model: modelId,
-				max_tokens: modelInfo.maxTokens,
-				temperature: 0.5,
-				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
-				stream: true,
-				stream_options: { include_usage: true },
-			}),
-		)
-	})
-})
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index bf1f3c35a8..f196b5f309 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider
 
 	protected readonly options: ApiHandlerOptions
 
-	private client: OpenAI
+	protected client: OpenAI
 
 	constructor({
 		providerName,
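The one-line visibility change above is load-bearing: `ChutesHandler.createMessage` (next file) calls `this.client.chat.completions.create(...)` directly, and a `private` member is invisible to subclasses. A minimal illustration of the TypeScript rule, with throwaway class names:

	class Base {
		private hidden = 0
		protected client = { create: () => "ok" }
	}

	class Sub extends Base {
		use() {
			return this.client.create() // allowed: `protected` is visible to subclasses
			// return this.hidden       // error TS2341: `private` is not
		}
	}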
diff --git a/src/api/providers/chutes.ts b/src/api/providers/chutes.ts
index 0fa8741fa3..f520b944e6 100644
--- a/src/api/providers/chutes.ts
+++ b/src/api/providers/chutes.ts
@@ -1,6 +1,11 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
 import type { ApiHandlerOptions } from "../../shared/api"
+import { convertToR1Format } from "../transform/r1-format"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import { ApiStream } from "../transform/stream"
 
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 
@@ -16,4 +21,70 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
 			defaultTemperature: 0.5,
 		})
 	}
+
+	private getCompletionParams(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
+		const {
+			id: model,
+			info: { maxTokens: max_tokens },
+		} = this.getModel()
+
+		const temperature = this.options.modelTemperature ?? this.defaultTemperature
+
+		return {
+			model,
+			max_tokens,
+			temperature,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+	}
+
+	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.getModel()
+
+		if (model.id.startsWith("deepseek-ai/DeepSeek-R1")) {
+			const stream = await this.client.chat.completions.create({
+				...this.getCompletionParams(systemPrompt, messages),
+				messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
+			})
+
+			for await (const chunk of stream) {
+				const delta = chunk.choices[0]?.delta
+
+				if (delta && "reasoning" in delta && delta.reasoning && typeof delta.reasoning === "string") {
+					yield { type: "reasoning", text: delta.reasoning }
+				}
+
+				if (delta?.content) {
+					yield { type: "text", text: delta.content }
+				}
+
+				if (chunk.usage) {
+					yield {
+						type: "usage",
+						inputTokens: chunk.usage.prompt_tokens || 0,
+						outputTokens: chunk.usage.completion_tokens || 0,
+					}
+				}
+			}
+		} else {
+			yield* super.createMessage(systemPrompt, messages)
+		}
+	}
+
+	override getModel() {
+		const model = super.getModel()
+		const isDeepSeekR1 = model.id.startsWith("deepseek-ai/DeepSeek-R1")
+		return {
+			...model,
+			info: {
+				...model.info,
+				temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
+			},
+		}
+	}
 }
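Net effect: `getModel` and `createMessage` both branch on the `deepseek-ai/DeepSeek-R1` id prefix (which also matches dated variants such as `deepseek-ai/DeepSeek-R1-0528` from the spec), applying `DEEP_SEEK_DEFAULT_TEMPERATURE` and R1 message formatting to that family while every other model falls through to `BaseOpenAiCompatibleProvider`. The reshaping exists because R1-style models take no system role; per the `createMessage` parameter test in chutes.spec.ts, the system prompt is folded into the first user turn. A sketch of that shape (the helper below is illustrative only, not the real `convertToR1Format` implementation):

	type R1Message = { role: "user" | "assistant"; content: string }

	// Mirrors the expectation asserted in chutes.spec.ts for a single user message:
	// system prompt and user content merge into one user turn, newline-separated.
	function r1MessagesFor(systemPrompt: string, userContent: string): R1Message[] {
		return [{ role: "user", content: `${systemPrompt}\n${userContent}` }]
	}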