144 changes: 144 additions & 0 deletions src/api/providers/__tests__/chutes.spec.ts
@@ -0,0 +1,144 @@
import { Anthropic } from "@anthropic-ai/sdk"
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
import OpenAI from "openai"

import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"

import { ChutesHandler } from "../chutes"
import * as chutesModule from "../chutes"

// Mock the entire module
vi.mock("../chutes", async () => {
const actual = await vi.importActual<typeof chutesModule>("../chutes")
return {
...actual,
ChutesHandler: class extends actual.ChutesHandler {
constructor(options: any) {
super(options)
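// Replace the base provider's protected OpenAI client with a stub so tests can spy on chat.completions.create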
this.client = {
chat: {
completions: {
create: vi.fn(),
},
},
} as any
}
},
}
})

describe("ChutesHandler", () => {
let handler: ChutesHandler
let mockCreate: any

beforeEach(() => {
handler = new ChutesHandler({ chutesApiKey: "test-key" })
mockCreate = vi.spyOn((handler as any).client.chat.completions, "create")
})

afterEach(() => {
vi.restoreAllMocks()
})

it("should handle DeepSeek R1 reasoning format", async () => {
const mockStream = (async function* () {
yield { choices: [{ delta: { reasoning: "Thinking..." } }] }
yield { choices: [{ delta: { content: "Hello" } }] }
yield { usage: { prompt_tokens: 10, completion_tokens: 5 } }
})()

mockCreate.mockResolvedValue(mockStream)

const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
vi.spyOn(handler, "getModel").mockReturnValue({
id: "deepseek-ai/DeepSeek-R1-0528",
info: { maxTokens: 1024, temperature: 0.7 },
} as any)

const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "reasoning", text: "Thinking..." },
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})

it("should fall back to base provider for non-DeepSeek models", async () => {
const mockStream = (async function* () {
yield { choices: [{ delta: { content: "Hello" } }] }
yield { usage: { prompt_tokens: 10, completion_tokens: 5 } }
})()

mockCreate.mockResolvedValue(mockStream)

const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
vi.spyOn(handler, "getModel").mockReturnValue({
id: "some-other-model",
info: { maxTokens: 1024, temperature: 0.7 },
} as any)

const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})

it("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(chutesDefaultModelId)
expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
})

it("should return specified model when valid model is provided", () => {
const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
const handlerWithModel = new ChutesHandler({
apiModelId: testModelId,
chutesApiKey: "test-chutes-api-key",
})
const model = handlerWithModel.getModel()
expect(model.id).toBe(testModelId)
expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
})

it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
const handlerWithModel = new ChutesHandler({
apiModelId: modelId,
chutesApiKey: "test-chutes-api-key",
})

const mockStream = (async function* () {})()
mockCreate.mockResolvedValue(mockStream)

const systemPrompt = "Test system prompt for Chutes"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]

const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
messages: [
{
role: "user",
content: `${systemPrompt}\n${messages[0].content}`,
},
],
}),
)
})
})
141 changes: 0 additions & 141 deletions src/api/providers/__tests__/chutes.test.ts

This file was deleted.

2 changes: 1 addition & 1 deletion src/api/providers/base-openai-compatible-provider.ts
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>

protected readonly options: ApiHandlerOptions

-private client: OpenAI
+protected client: OpenAI

constructor({
providerName,
73 changes: 72 additions & 1 deletion src/api/providers/chutes.ts
@@ -1,6 +1,11 @@
-import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import { convertToR1Format } from "../transform/r1-format"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

@@ -16,4 +21,70 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
defaultTemperature: 0.5,
})
}

private getCompletionParams(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
const {
id: model,
info: { maxTokens: max_tokens },
} = this.getModel()

const temperature = this.options.modelTemperature ?? this.defaultTemperature

return {
model,
max_tokens,
temperature,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
}
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const model = this.getModel()

if (model.id.startsWith("deepseek-ai/DeepSeek-R1")) {
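// R1 models take no system role; convertToR1Format folds the system prompt into the first user message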
const stream = await this.client.chat.completions.create({
...this.getCompletionParams(systemPrompt, messages),
messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
})

for await (const chunk of stream) {
const delta = chunk.choices?.[0]?.delta

if (delta && "reasoning" in delta && delta.reasoning && typeof delta.reasoning === "string") {
yield { type: "reasoning", text: delta.reasoning }
}

if (delta?.content) {
yield { type: "text", text: delta.content }
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
}
}
} else {
yield* super.createMessage(systemPrompt, messages)
}
}

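// Surface the DeepSeek-recommended default temperature for R1 models through model info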
override getModel() {
const model = super.getModel()
const isDeepSeekR1 = model.id.startsWith("deepseek-ai/DeepSeek-R1")
return {
...model,
info: {
...model.info,
temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
},
}
}
}
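
For reviewers, a minimal sketch of how the new streaming path is consumed, mirroring the spec above; the API key, prompts, and import path are illustrative, not part of this PR:

import { ChutesHandler } from "../chutes"

// Hypothetical driver; exercises the same chunk shapes the spec asserts on.
async function demo() {
	const handler = new ChutesHandler({
		apiModelId: "deepseek-ai/DeepSeek-R1",
		chutesApiKey: process.env.CHUTES_API_KEY ?? "",
	})

	// For R1 models, createMessage yields reasoning, text, and usage chunks.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hi" },
	])) {
		if (chunk.type === "reasoning") console.log("[reasoning]", chunk.text)
		else if (chunk.type === "text") console.log("[text]", chunk.text)
		else if (chunk.type === "usage") console.log("[usage]", chunk.inputTokens, chunk.outputTokens)
	}
}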