208 changes: 186 additions & 22 deletions src/api/providers/__tests__/chutes.spec.ts
@@ -1,33 +1,64 @@
// npx vitest run api/providers/__tests__/chutes.spec.ts

import { vitest, describe, it, expect, beforeEach } from "vitest"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"
import OpenAI from "openai"

import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
import { type ChutesModelId, chutesDefaultModelId, chutesModels, DEEP_SEEK_DEFAULT_TEMPERATURE } from "@roo-code/types"

import { ChutesHandler } from "../chutes"

const mockCreate = vitest.fn()
// Create mock functions
const mockCreate = vi.fn()

vitest.mock("openai", () => {
return {
default: vitest.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
// Mock OpenAI module
vi.mock("openai", () => ({
default: vi.fn(() => ({
chat: {
completions: {
create: mockCreate,
},
})),
}
})
},
})),
}))

describe("ChutesHandler", () => {
let handler: ChutesHandler

beforeEach(() => {
vitest.clearAllMocks()
handler = new ChutesHandler({ chutesApiKey: "test-chutes-api-key" })
vi.clearAllMocks()
// Set up default mock implementation
mockCreate.mockImplementation(async () => ({
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
},
}))
handler = new ChutesHandler({ chutesApiKey: "test-key" })
})

afterEach(() => {
vi.restoreAllMocks()
})

it("should use the correct Chutes base URL", () => {
@@ -41,18 +72,96 @@ describe("ChutesHandler", () => {
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: chutesApiKey }))
})

it("should handle DeepSeek R1 reasoning format", async () => {
// Override the mock for this specific test
mockCreate.mockImplementationOnce(async () => ({
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "<think>Thinking..." },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: { content: "</think>Hello" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: { prompt_tokens: 10, completion_tokens: 5 },
}
},
}))

const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
vi.spyOn(handler, "getModel").mockReturnValue({
id: "deepseek-ai/DeepSeek-R1-0528",
info: { maxTokens: 1024, temperature: 0.7 },
} as any)

const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "reasoning", text: "Thinking..." },
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})

it("should fall back to base provider for non-DeepSeek models", async () => {
// Use default mock implementation which returns text content
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hi" }]
vi.spyOn(handler, "getModel").mockReturnValue({
id: "some-other-model",
info: { maxTokens: 1024, temperature: 0.7 },
} as any)

const stream = handler.createMessage(systemPrompt, messages)
const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toEqual([
{ type: "text", text: "Test response" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
})

it("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(chutesDefaultModelId)
expect(model.info).toEqual(chutesModels[chutesDefaultModelId])
expect(model.info).toEqual(expect.objectContaining(chutesModels[chutesDefaultModelId]))
})

it("should return specified model when valid model is provided", () => {
const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
const handlerWithModel = new ChutesHandler({ apiModelId: testModelId, chutesApiKey: "test-chutes-api-key" })
const handlerWithModel = new ChutesHandler({
apiModelId: testModelId,
chutesApiKey: "test-chutes-api-key",
})
const model = handlerWithModel.getModel()
expect(model.id).toBe(testModelId)
expect(model.info).toEqual(chutesModels[testModelId])
expect(model.info).toEqual(expect.objectContaining(chutesModels[testModelId]))
})

it("completePrompt method should return text from Chutes API", async () => {
@@ -74,7 +183,7 @@ describe("ChutesHandler", () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vitest
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
@@ -96,7 +205,7 @@ describe("ChutesHandler", () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vitest
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
@@ -114,8 +223,43 @@ describe("ChutesHandler", () => {
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should pass correct parameters to Chutes client", async () => {
it("createMessage should pass correct parameters to Chutes client for DeepSeek R1", async () => {
const modelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"

// Clear previous mocks and set up new implementation
mockCreate.mockClear()
mockCreate.mockImplementationOnce(async () => ({
[Symbol.asyncIterator]: async function* () {
// Empty stream for this test
},
}))

const handlerWithModel = new ChutesHandler({
apiModelId: modelId,
chutesApiKey: "test-chutes-api-key",
})

const systemPrompt = "Test system prompt for Chutes"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]

const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
messages: [
{
role: "user",
content: `${systemPrompt}\n${messages[0].content}`,
},
],
}),
)
})

it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
const modelInfo = chutesModels[modelId]
const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })

@@ -146,4 +290,24 @@ describe("ChutesHandler", () => {
}),
)
})

it("should apply DeepSeek default temperature for R1 models", () => {
const testModelId: ChutesModelId = "deepseek-ai/DeepSeek-R1"
const handlerWithModel = new ChutesHandler({
apiModelId: testModelId,
chutesApiKey: "test-chutes-api-key",
})
const model = handlerWithModel.getModel()
expect(model.info.temperature).toBe(DEEP_SEEK_DEFAULT_TEMPERATURE)
})

it("should use default temperature for non-DeepSeek models", () => {
const testModelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
const handlerWithModel = new ChutesHandler({
apiModelId: testModelId,
chutesApiKey: "test-chutes-api-key",
})
const model = handlerWithModel.getModel()
expect(model.info.temperature).toBe(0.5)
})
})
2 changes: 1 addition & 1 deletion src/api/providers/base-openai-compatible-provider.ts
@@ -31,7 +31,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>

protected readonly options: ApiHandlerOptions

private client: OpenAI
protected client: OpenAI
Collaborator: Is this one necessary to change?

Collaborator: Never mind, I see now


constructor({
providerName,
86 changes: 85 additions & 1 deletion src/api/providers/chutes.ts
@@ -1,6 +1,12 @@
import { type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
import { DEEP_SEEK_DEFAULT_TEMPERATURE, type ChutesModelId, chutesDefaultModelId, chutesModels } from "@roo-code/types"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
import { XmlMatcher } from "../../utils/xml-matcher"
import { convertToR1Format } from "../transform/r1-format"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

@@ -16,4 +22,82 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider<ChutesModelId> {
defaultTemperature: 0.5,
})
}

private getCompletionParams(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming {
const {
id: model,
info: { maxTokens: max_tokens },
} = this.getModel()

const temperature = this.options.modelTemperature ?? this.getModel().info.temperature

return {
model,
max_tokens,
temperature,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
}
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const model = this.getModel()

if (model.id.includes("DeepSeek-R1")) {
const stream = await this.client.chat.completions.create({
...this.getCompletionParams(systemPrompt, messages),
messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]),
})

const matcher = new XmlMatcher(
"think",
(chunk) =>
({
type: chunk.matched ? "reasoning" : "text",
text: chunk.data,
}) as const,
)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
yield processedChunk
}
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
}
}

// Process any remaining content
for (const processedChunk of matcher.final()) {
yield processedChunk
}
} else {
yield* super.createMessage(systemPrompt, messages)
}
}

override getModel() {
const model = super.getModel()
const isDeepSeekR1 = model.id.includes("DeepSeek-R1")
return {
...model,
info: {
...model.info,
temperature: isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : this.defaultTemperature,
},
}
}
}
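
A minimal consumption sketch (not part of the diff), mirroring the streaming assertions in chutes.spec.ts above. The handler options and chunk shapes come from this PR; the API key, prompts, and logging are illustrative placeholders.

import { ChutesHandler } from "../chutes"

async function demo() {
	// Illustrative options; a real key comes from provider settings.
	const handler = new ChutesHandler({
		apiModelId: "deepseek-ai/DeepSeek-R1",
		chutesApiKey: "example-key",
	})

	// createMessage is an async generator: for DeepSeek-R1 models the output is
	// split into "reasoning" chunks (extracted from <think>…</think>) and plain
	// "text" chunks, followed by a "usage" chunk carrying token counts.
	for await (const chunk of handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hi" },
	])) {
		if (chunk.type === "reasoning") console.log("[reasoning]", chunk.text)
		else if (chunk.type === "text") console.log(chunk.text)
		else if (chunk.type === "usage") console.log(`tokens in/out: ${chunk.inputTokens}/${chunk.outputTokens}`)
	}
}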