1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ coverage/
# Builds
bin/
roo-cline-*.vsix
+tsconfig.tsbuildinfo

# Local prompts and rules
/local-prompts
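Note: tsconfig.tsbuildinfo is the incremental-build cache that tsc emits when the incremental (or composite) compiler option is enabled; like the bin/ directory and the .vsix packages above, it is machine-generated output that should stay out of version control.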
5 changes: 5 additions & 0 deletions package.json
@@ -269,6 +269,11 @@
}
},
"description": "Settings for VSCode Language Model API"
},
"roo-cline.debug.mistral": {
"type": "boolean",
"default": false,
"description": "Enable debug output channel 'Roo Code Mistral' for Mistral API interactions"
}
}
}
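For orientation, here is a minimal sketch of how an extension could consume the new setting. Only the setting key, the channel name "Roo Code Mistral" from the description, and the vscode APIs mocked in the tests below come from this PR; the helper itself is hypothetical:

import * as vscode from "vscode"

// Hypothetical helper, not part of this PR: open the debug channel only when
// the user has opted in via the roo-cline.debug.mistral setting.
function createMistralDebugChannel(): vscode.OutputChannel | undefined {
	const enabled = vscode.workspace.getConfiguration("roo-cline").get<boolean>("debug.mistral", false)
	if (!enabled) {
		return undefined
	}
	const channel = vscode.window.createOutputChannel("Roo Code Mistral")
	channel.appendLine("Mistral debug logging enabled")
	return channel
}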
282 changes: 215 additions & 67 deletions src/api/providers/__tests__/mistral.test.ts
@@ -1,49 +1,68 @@
import { MistralHandler } from "../mistral"
import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
-import { ApiStreamTextChunk } from "../../transform/stream"
+import { ApiStream } from "../../transform/stream"

-// Mock Mistral client
-const mockCreate = jest.fn()
-jest.mock("@mistralai/mistralai", () => {
-	return {
-		Mistral: jest.fn().mockImplementation(() => ({
-			chat: {
-				stream: mockCreate.mockImplementation(async (options) => {
-					const stream = {
-						[Symbol.asyncIterator]: async function* () {
-							yield {
-								data: {
-									choices: [
-										{
-											delta: { content: "Test response" },
-											index: 0,
-										},
-									],
-								},
-							}
-						},
-					}
-					return stream
-				}),
-			},
-		})),
-	}
-})
+// Mock Mistral client first
+const mockCreate = jest.fn().mockImplementation(() => mockStreamResponse())
+
+// Create a mock stream response
+const mockStreamResponse = async function* () {
+	yield {
+		data: {
+			choices: [
+				{
+					delta: { content: "Test response" },
+					index: 0,
+				},
+			],
+		},
+	}
+}
+
+// Mock the entire module
+jest.mock("@mistralai/mistralai", () => ({
+	Mistral: jest.fn().mockImplementation(() => ({
+		chat: {
+			stream: mockCreate,
+		},
+	})),
+}))
+
+// Mock vscode
+jest.mock("vscode", () => ({
+	window: {
+		createOutputChannel: jest.fn().mockReturnValue({
+			appendLine: jest.fn(),
+			show: jest.fn(),
+			dispose: jest.fn(),
+		}),
+	},
+	workspace: {
+		getConfiguration: jest.fn().mockReturnValue({
+			get: jest.fn().mockReturnValue(false),
+		}),
+	},
+}))

describe("MistralHandler", () => {
	let handler: MistralHandler
	let mockOptions: ApiHandlerOptions

	beforeEach(() => {
+		// Clear all mocks before each test
+		jest.clearAllMocks()
+
		mockOptions = {
-			apiModelId: "codestral-latest", // Update to match the actual model ID
+			apiModelId: mistralDefaultModelId,
			mistralApiKey: "test-api-key",
			includeMaxTokens: true,
			modelTemperature: 0,
+			mistralModelStreamingEnabled: true,
+			stopToken: undefined,
+			mistralCodestralUrl: undefined,
		}
		handler = new MistralHandler(mockOptions)
		mockCreate.mockClear()
	})

describe("constructor", () => {
@@ -60,23 +79,114 @@ describe("MistralHandler", () => {
				})
			}).toThrow("Mistral API key is required")
		})
-
-		it("should use custom base URL if provided", () => {
-			const customBaseUrl = "https://custom.mistral.ai/v1"
-			const handlerWithCustomUrl = new MistralHandler({
-				...mockOptions,
-				mistralCodestralUrl: customBaseUrl,
-			})
-			expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
-		})
	})

-	describe("getModel", () => {
-		it("should return correct model info", () => {
-			const model = handler.getModel()
-			expect(model.id).toBe(mockOptions.apiModelId)
-			expect(model.info).toBeDefined()
-			expect(model.info.supportsPromptCache).toBe(false)
-		})
-	})
+	describe("stopToken handling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text", text: "Hello!" }],
+			},
+		]
+
+		async function consumeStream(stream: ApiStream) {
+			for await (const chunk of stream) {
+				// Consume the stream
+			}
+		}
+
+		it("should not include stop parameter when stopToken is undefined", async () => {
+			const handlerWithoutStop = new MistralHandler({
+				...mockOptions,
+				stopToken: undefined,
+			})
+			const stream = handlerWithoutStop.createMessage(systemPrompt, messages)
+			await consumeStream(stream)
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					stop: expect.anything(),
+				}),
+			)
+		})
+
+		it("should not include stop parameter when stopToken is empty string", async () => {
+			const handlerWithEmptyStop = new MistralHandler({
+				...mockOptions,
+				stopToken: "",
+			})
+			const stream = handlerWithEmptyStop.createMessage(systemPrompt, messages)
+			await consumeStream(stream)
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					stop: expect.anything(),
+				}),
+			)
+		})
+
+		it("should not include stop parameter when stopToken contains only whitespace", async () => {
+			const handlerWithWhitespaceStop = new MistralHandler({
+				...mockOptions,
+				stopToken: " ",
+			})
+			const stream = handlerWithWhitespaceStop.createMessage(systemPrompt, messages)
+			await consumeStream(stream)
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					stop: expect.anything(),
+				}),
+			)
+		})
+
+		it("should handle non-empty stop token", async () => {
+			const handlerWithCommasStop = new MistralHandler({
+				...mockOptions,
+				stopToken: ",,,",
+			})
+			const stream = handlerWithCommasStop.createMessage(systemPrompt, messages)
+			await consumeStream(stream)
+
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.model).toBe(mistralDefaultModelId)
+			expect(callArgs.maxTokens).toBe(256000)
+			expect(callArgs.temperature).toBe(0)
+			expect(callArgs.stream).toBe(true)
+			expect(callArgs.stop).toStrictEqual([",,,"] as string[])
+		})
+
+		it("should include stop parameter with single token", async () => {
+			const handlerWithStop = new MistralHandler({
+				...mockOptions,
+				stopToken: "\\n\\n",
+			})
+			const stream = handlerWithStop.createMessage(systemPrompt, messages)
+			await consumeStream(stream)
+
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.model).toBe("codestral-latest")
+			expect(callArgs.maxTokens).toBe(256000)
+			expect(callArgs.temperature).toBe(0)
+			expect(callArgs.stream).toBe(true)
+			expect(callArgs.stop).toStrictEqual(["\\n\\n"] as string[])
+		})
+
+		it("should keep stop token as-is", async () => {
+			const handlerWithMultiStop = new MistralHandler({
+				...mockOptions,
+				stopToken: "\\n\\n,,DONE, ,END,",
+			})
+			const stream = handlerWithMultiStop.createMessage(systemPrompt, messages)
+			await consumeStream(stream)
+
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.model).toBe("codestral-latest")
+			expect(callArgs.maxTokens).toBe(256000)
+			expect(callArgs.temperature).toBe(0)
+			expect(callArgs.stream).toBe(true)
+			expect(callArgs.stop).toStrictEqual(["\\n\\n,,DONE, ,END,"] as string[])
+		})
+	})
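Taken together, these cases pin down the intended normalization: an undefined, empty, or whitespace-only stopToken must not produce a stop parameter at all, while any other string is passed through verbatim as a single-element array, with no comma splitting and no trimming of interior characters, so ",,," and "\\n\\n,,DONE, ,END," each remain one stop sequence. A minimal sketch of logic that would satisfy these tests (the function name is an assumption; the real implementation lives in src/api/providers/mistral.ts, outside this excerpt):

// Hypothetical normalization consistent with the tests above.
function buildStopParam(stopToken?: string): { stop?: string[] } {
	// undefined, "", and whitespace-only tokens yield no stop parameter at all.
	if (!stopToken || stopToken.trim().length === 0) {
		return {}
	}
	// Everything else is kept verbatim as a single array entry.
	return { stop: [stopToken] }
}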

@@ -89,38 +199,76 @@ describe("MistralHandler", () => {
},
]

it("should create message successfully", async () => {
const iterator = handler.createMessage(systemPrompt, messages)
const result = await iterator.next()
async function consumeStream(stream: ApiStream) {
for await (const chunk of stream) {
// Consume the stream
}
}

expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: expect.any(Array),
maxTokens: expect.any(Number),
temperature: 0,
})
it("should create message with streaming enabled", async () => {
const stream = handler.createMessage(systemPrompt, messages)
await consumeStream(stream)

expect(result.value).toBeDefined()
expect(result.done).toBe(false)
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
role: "system",
content: systemPrompt,
}),
]),
stream: true,
}),
)
})

it("should handle streaming response correctly", async () => {
const iterator = handler.createMessage(systemPrompt, messages)
const results: ApiStreamTextChunk[] = []

for await (const chunk of iterator) {
if ("text" in chunk) {
results.push(chunk as ApiStreamTextChunk)
}
}
it("should handle temperature settings", async () => {
const handlerWithTemp = new MistralHandler({
...mockOptions,
modelTemperature: 0.7,
})
const stream = handlerWithTemp.createMessage(systemPrompt, messages)
await consumeStream(stream)

expect(results.length).toBeGreaterThan(0)
expect(results[0].text).toBe("Test response")
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.7)
})

it("should handle errors gracefully", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
it("should transform messages correctly", async () => {
const complexMessages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{ type: "text", text: "Hello!" },
{ type: "text", text: "How are you?" },
],
},
{
role: "assistant",
content: [{ type: "text", text: "I'm doing well!" }],
},
]
const stream = handler.createMessage(systemPrompt, complexMessages)
await consumeStream(stream)

const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.messages).toEqual([
{
role: "system",
content: systemPrompt,
},
{
role: "user",
content: [
{ type: "text", text: "Hello!" },
{ type: "text", text: "How are you?" },
],
},
{
role: "assistant",
content: "I'm doing well!",
},
])
})
})
})
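The last test fixes the message mapping precisely: the system prompt becomes a leading system message, user content keeps its array-of-text-blocks form, and assistant text blocks are flattened into a plain string. A sketch of a converter that would satisfy that expectation (the standalone function and its name are assumptions; the real conversion happens inside MistralHandler):

import { Anthropic } from "@anthropic-ai/sdk"

type MistralChatMessage = {
	role: "system" | "user" | "assistant"
	content: string | { type: "text"; text: string }[]
}

// Hypothetical converter matching the "should transform messages correctly" expectation.
function toMistralMessages(
	systemPrompt: string,
	messages: Anthropic.Messages.MessageParam[],
): MistralChatMessage[] {
	const result: MistralChatMessage[] = [{ role: "system", content: systemPrompt }]
	for (const message of messages) {
		// Keep only text blocks; string content is treated as a single text block.
		const blocks: { type: "text"; text: string }[] =
			typeof message.content === "string"
				? [{ type: "text", text: message.content }]
				: message.content.flatMap((block) =>
						block.type === "text" ? [{ type: "text" as const, text: block.text }] : [],
					)
		if (message.role === "assistant") {
			// Assistant turns are flattened to a plain string ("I'm doing well!").
			result.push({ role: "assistant", content: blocks.map((block) => block.text).join("") })
		} else {
			// User turns keep the array-of-text-blocks shape.
			result.push({ role: "user", content: blocks })
		}
	}
	return result
}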