diff --git a/.gitignore b/.gitignore index 211d06aa199..76a793a8251 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ coverage/ # Builds bin/ roo-cline-*.vsix +tsconfig.tsbuildinfo # Local prompts and rules /local-prompts diff --git a/package.json b/package.json index c4c38e78571..63a45c79d52 100644 --- a/package.json +++ b/package.json @@ -269,6 +269,11 @@ } }, "description": "Settings for VSCode Language Model API" + }, + "roo-cline.debug.mistral": { + "type": "boolean", + "default": false, + "description": "Enable debug output channel 'Roo Code Mistral' for Mistral API interactions" } } } diff --git a/src/api/providers/__tests__/mistral.test.ts b/src/api/providers/__tests__/mistral.test.ts index 781cb3dcfc5..c56373f3a09 100644 --- a/src/api/providers/__tests__/mistral.test.ts +++ b/src/api/providers/__tests__/mistral.test.ts @@ -1,49 +1,68 @@ import { MistralHandler } from "../mistral" import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api" import { Anthropic } from "@anthropic-ai/sdk" -import { ApiStreamTextChunk } from "../../transform/stream" - -// Mock Mistral client -const mockCreate = jest.fn() -jest.mock("@mistralai/mistralai", () => { - return { - Mistral: jest.fn().mockImplementation(() => ({ - chat: { - stream: mockCreate.mockImplementation(async (options) => { - const stream = { - [Symbol.asyncIterator]: async function* () { - yield { - data: { - choices: [ - { - delta: { content: "Test response" }, - index: 0, - }, - ], - }, - } - }, - } - return stream - }), - }, - })), +import { ApiStream } from "../../transform/stream" + +// Mock Mistral client first +const mockCreate = jest.fn().mockImplementation(() => mockStreamResponse()) + +// Create a mock stream response +const mockStreamResponse = async function* () { + yield { + data: { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + }, } -}) +} + +// Mock the entire module +jest.mock("@mistralai/mistralai", () => ({ + Mistral: jest.fn().mockImplementation(() => ({ + chat: { + stream: mockCreate, + }, + })), +})) + +// Mock vscode +jest.mock("vscode", () => ({ + window: { + createOutputChannel: jest.fn().mockReturnValue({ + appendLine: jest.fn(), + show: jest.fn(), + dispose: jest.fn(), + }), + }, + workspace: { + getConfiguration: jest.fn().mockReturnValue({ + get: jest.fn().mockReturnValue(false), + }), + }, +})) describe("MistralHandler", () => { let handler: MistralHandler let mockOptions: ApiHandlerOptions beforeEach(() => { + // Clear all mocks before each test + jest.clearAllMocks() + mockOptions = { - apiModelId: "codestral-latest", // Update to match the actual model ID + apiModelId: mistralDefaultModelId, mistralApiKey: "test-api-key", includeMaxTokens: true, modelTemperature: 0, + mistralModelStreamingEnabled: true, + stopToken: undefined, + mistralCodestralUrl: undefined, } handler = new MistralHandler(mockOptions) - mockCreate.mockClear() }) describe("constructor", () => { @@ -60,23 +79,114 @@ describe("MistralHandler", () => { }) }).toThrow("Mistral API key is required") }) + }) - it("should use custom base URL if provided", () => { - const customBaseUrl = "https://custom.mistral.ai/v1" - const handlerWithCustomUrl = new MistralHandler({ + describe("stopToken handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text", text: "Hello!" 
}], + }, + ] + + async function consumeStream(stream: ApiStream) { + for await (const chunk of stream) { + // Consume the stream + } + } + + it("should not include stop parameter when stopToken is undefined", async () => { + const handlerWithoutStop = new MistralHandler({ ...mockOptions, - mistralCodestralUrl: customBaseUrl, + stopToken: undefined, }) - expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler) + const stream = handlerWithoutStop.createMessage(systemPrompt, messages) + await consumeStream(stream) + + expect(mockCreate).toHaveBeenCalledWith( + expect.not.objectContaining({ + stop: expect.anything(), + }), + ) }) - }) - describe("getModel", () => { - it("should return correct model info", () => { - const model = handler.getModel() - expect(model.id).toBe(mockOptions.apiModelId) - expect(model.info).toBeDefined() - expect(model.info.supportsPromptCache).toBe(false) + it("should not include stop parameter when stopToken is empty string", async () => { + const handlerWithEmptyStop = new MistralHandler({ + ...mockOptions, + stopToken: "", + }) + const stream = handlerWithEmptyStop.createMessage(systemPrompt, messages) + await consumeStream(stream) + + expect(mockCreate).toHaveBeenCalledWith( + expect.not.objectContaining({ + stop: expect.anything(), + }), + ) + }) + + it("should not include stop parameter when stopToken contains only whitespace", async () => { + const handlerWithWhitespaceStop = new MistralHandler({ + ...mockOptions, + stopToken: " ", + }) + const stream = handlerWithWhitespaceStop.createMessage(systemPrompt, messages) + await consumeStream(stream) + + expect(mockCreate).toHaveBeenCalledWith( + expect.not.objectContaining({ + stop: expect.anything(), + }), + ) + }) + + it("should handle non-empty stop token", async () => { + const handlerWithCommasStop = new MistralHandler({ + ...mockOptions, + stopToken: ",,,", + }) + const stream = handlerWithCommasStop.createMessage(systemPrompt, messages) + await consumeStream(stream) + + const callArgs = mockCreate.mock.calls[0][0] + expect(callArgs.model).toBe(mistralDefaultModelId) + expect(callArgs.maxTokens).toBe(256000) + expect(callArgs.temperature).toBe(0) + expect(callArgs.stream).toBe(true) + expect(callArgs.stop).toStrictEqual([",,,"] as string[]) + }) + + it("should include stop parameter with single token", async () => { + const handlerWithStop = new MistralHandler({ + ...mockOptions, + stopToken: "\\n\\n", + }) + const stream = handlerWithStop.createMessage(systemPrompt, messages) + await consumeStream(stream) + + const callArgs = mockCreate.mock.calls[0][0] + expect(callArgs.model).toBe("codestral-latest") + expect(callArgs.maxTokens).toBe(256000) + expect(callArgs.temperature).toBe(0) + expect(callArgs.stream).toBe(true) + expect(callArgs.stop).toStrictEqual(["\\n\\n"] as string[]) + }) + + it("should keep stop token as-is", async () => { + const handlerWithMultiStop = new MistralHandler({ + ...mockOptions, + stopToken: "\\n\\n,,DONE, ,END,", + }) + const stream = handlerWithMultiStop.createMessage(systemPrompt, messages) + await consumeStream(stream) + + const callArgs = mockCreate.mock.calls[0][0] + expect(callArgs.model).toBe("codestral-latest") + expect(callArgs.maxTokens).toBe(256000) + expect(callArgs.temperature).toBe(0) + expect(callArgs.stream).toBe(true) + expect(callArgs.stop).toStrictEqual(["\\n\\n,,DONE, ,END,"] as string[]) }) }) @@ -89,38 +199,76 @@ describe("MistralHandler", () => { }, ] - it("should create message successfully", async () => { - const iterator = 
handler.createMessage(systemPrompt, messages) - const result = await iterator.next() + async function consumeStream(stream: ApiStream) { + for await (const chunk of stream) { + // Consume the stream + } + } - expect(mockCreate).toHaveBeenCalledWith({ - model: mockOptions.apiModelId, - messages: expect.any(Array), - maxTokens: expect.any(Number), - temperature: 0, - }) + it("should create message with streaming enabled", async () => { + const stream = handler.createMessage(systemPrompt, messages) + await consumeStream(stream) - expect(result.value).toBeDefined() - expect(result.done).toBe(false) + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + role: "system", + content: systemPrompt, + }), + ]), + stream: true, + }), + ) }) - it("should handle streaming response correctly", async () => { - const iterator = handler.createMessage(systemPrompt, messages) - const results: ApiStreamTextChunk[] = [] - - for await (const chunk of iterator) { - if ("text" in chunk) { - results.push(chunk as ApiStreamTextChunk) - } - } + it("should handle temperature settings", async () => { + const handlerWithTemp = new MistralHandler({ + ...mockOptions, + modelTemperature: 0.7, + }) + const stream = handlerWithTemp.createMessage(systemPrompt, messages) + await consumeStream(stream) - expect(results.length).toBeGreaterThan(0) - expect(results[0].text).toBe("Test response") + const callArgs = mockCreate.mock.calls[0][0] + expect(callArgs.temperature).toBe(0.7) }) - it("should handle errors gracefully", async () => { - mockCreate.mockRejectedValueOnce(new Error("API Error")) - await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error") + it("should transform messages correctly", async () => { + const complexMessages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { type: "text", text: "Hello!" }, + { type: "text", text: "How are you?" }, + ], + }, + { + role: "assistant", + content: [{ type: "text", text: "I'm doing well!" }], + }, + ] + const stream = handler.createMessage(systemPrompt, complexMessages) + await consumeStream(stream) + + const callArgs = mockCreate.mock.calls[0][0] + expect(callArgs.messages).toEqual([ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: [ + { type: "text", text: "Hello!" }, + { type: "text", text: "How are you?" 
}, + ], + }, + { + role: "assistant", + content: "I'm doing well!", + }, + ]) }) }) }) diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index 08054c36b6a..b0d4e33867d 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -1,47 +1,104 @@ import { Anthropic } from "@anthropic-ai/sdk" import { Mistral } from "@mistralai/mistralai" import { ApiHandler } from "../" -import { - ApiHandlerOptions, - mistralDefaultModelId, - MistralModelId, - mistralModels, - ModelInfo, - openAiNativeDefaultModelId, - OpenAiNativeModelId, - openAiNativeModels, -} from "../../shared/api" +import { ApiHandlerOptions, mistralDefaultModelId, MistralModelId, mistralModels, ModelInfo } from "../../shared/api" import { convertToMistralMessages } from "../transform/mistral-format" import { ApiStream } from "../transform/stream" +import * as vscode from "vscode" const MISTRAL_DEFAULT_TEMPERATURE = 0 export class MistralHandler implements ApiHandler { private options: ApiHandlerOptions private client: Mistral + private readonly enableDebugOutput: boolean + private readonly outputChannel?: vscode.OutputChannel + private cachedModel: { id: MistralModelId; info: ModelInfo; forModelId: string | undefined } | null = null + + private static readonly outputChannelName = "Roo Code Mistral" + private static sharedOutputChannel: vscode.OutputChannel | undefined constructor(options: ApiHandlerOptions) { if (!options.mistralApiKey) { throw new Error("Mistral API key is required") } - // Set default model ID if not provided + // Clear cached model if options change + this.cachedModel = null + + // Destructure only the options we need + const { + apiModelId, + mistralApiKey, + mistralCodestralUrl, + mistralModelStreamingEnabled, + modelTemperature, + stopToken, + includeMaxTokens, + } = options + this.options = { - ...options, - apiModelId: options.apiModelId || mistralDefaultModelId, + apiModelId: apiModelId || mistralDefaultModelId, + mistralApiKey, + mistralCodestralUrl, + mistralModelStreamingEnabled, + modelTemperature, + stopToken, + includeMaxTokens, + } + + const config = vscode.workspace.getConfiguration("roo-cline") + this.enableDebugOutput = config.get("debug.mistral", false) + + if (this.enableDebugOutput) { + if (!MistralHandler.sharedOutputChannel) { + MistralHandler.sharedOutputChannel = vscode.window.createOutputChannel(MistralHandler.outputChannelName) + } + this.outputChannel = MistralHandler.sharedOutputChannel } + this.logDebug(`Initializing MistralHandler with options: ${JSON.stringify(this.options, null, 2)}`) const baseUrl = this.getBaseUrl() - console.debug(`[Roo Code] MistralHandler using baseUrl: ${baseUrl}`) + this.logDebug(`MistralHandler using baseUrl: ${baseUrl}`) + + const logger = { + group: (message: string) => { + if (this.enableDebugOutput && this.outputChannel) { + this.outputChannel.appendLine(`[Mistral SDK] Group: ${message}`) + } + }, + groupEnd: () => { + if (this.enableDebugOutput && this.outputChannel) { + this.outputChannel.appendLine(`[Mistral SDK] GroupEnd`) + } + }, + log: (...args: any[]) => { + if (this.enableDebugOutput && this.outputChannel) { + const formattedArgs = args + .map((arg) => (typeof arg === "object" ? JSON.stringify(arg, null, 2) : arg)) + .join(" ") + this.outputChannel.appendLine(`[Mistral SDK] ${formattedArgs}`) + } + }, + } + this.client = new Mistral({ serverURL: baseUrl, apiKey: this.options.mistralApiKey, + debugLogger: this.enableDebugOutput ? 
logger : undefined, }) } + private logDebug(message: string | object) { + if (this.enableDebugOutput && this.outputChannel) { + const formattedMessage = typeof message === "object" ? JSON.stringify(message, null, 2) : message + this.outputChannel.appendLine(`[Roo Code] ${formattedMessage}`) + } + } + private getBaseUrl(): string { const modelId = this.options.apiModelId ?? mistralDefaultModelId - console.debug(`[Roo Code] MistralHandler using modelId: ${modelId}`) + this.logDebug(`MistralHandler using modelId: ${modelId}`) if (modelId?.startsWith("codestral-")) { return this.options.mistralCodestralUrl || "https://codestral.mistral.ai" } @@ -49,13 +106,19 @@ export class MistralHandler implements ApiHandler { } async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { + this.logDebug(`Creating message with system prompt: ${systemPrompt}`) + const response = await this.client.chat.stream({ - model: this.options.apiModelId || mistralDefaultModelId, + model: this.options?.apiModelId || mistralDefaultModelId, + maxTokens: this.options?.includeMaxTokens ? this.getModel().info.maxTokens : undefined, messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)], - maxTokens: this.options.includeMaxTokens ? this.getModel().info.maxTokens : undefined, - temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE, + temperature: this.options?.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE, + ...(this.options?.mistralModelStreamingEnabled === true && { stream: true }), + ...(this.options?.stopToken?.trim() && { stop: [this.options.stopToken] }), }) + let completeContent = "" + for await (const chunk of response) { const delta = chunk.data.choices[0]?.delta if (delta?.content) { @@ -65,6 +128,7 @@ export class MistralHandler implements ApiHandler { } else if (Array.isArray(delta.content)) { content = delta.content.map((c) => (c.type === "text" ? 
c.text : "")).join("") } + completeContent += content yield { type: "text", text: content, @@ -72,6 +136,10 @@ export class MistralHandler implements ApiHandler { } if (chunk.data.usage) { + this.logDebug(`Complete content: ${completeContent}`) + this.logDebug( + `Usage - Input tokens: ${chunk.data.usage.promptTokens}, Output tokens: ${chunk.data.usage.completionTokens}`, + ) yield { type: "usage", inputTokens: chunk.data.usage.promptTokens || 0, @@ -82,19 +150,44 @@ export class MistralHandler implements ApiHandler { } getModel(): { id: MistralModelId; info: ModelInfo } { + // Check if cache exists and is for the current model + if (this.cachedModel && this.cachedModel.forModelId === this.options.apiModelId) { + return { + id: this.cachedModel.id, + info: this.cachedModel.info, + } + } + const modelId = this.options.apiModelId if (modelId && modelId in mistralModels) { const id = modelId as MistralModelId - return { id, info: mistralModels[id] } + this.logDebug(`Using model: ${id}`) + this.cachedModel = { + id, + info: mistralModels[id], + forModelId: modelId, + } + return { + id: this.cachedModel.id, + info: this.cachedModel.info, + } } - return { + + this.logDebug(`Using default model: ${mistralDefaultModelId}`) + this.cachedModel = { id: mistralDefaultModelId, info: mistralModels[mistralDefaultModelId], + forModelId: undefined, + } + return { + id: this.cachedModel.id, + info: this.cachedModel.info, } } async completePrompt(prompt: string): Promise { try { + this.logDebug(`Completing prompt: ${prompt}`) const response = await this.client.chat.complete({ model: this.options.apiModelId || mistralDefaultModelId, messages: [{ role: "user", content: prompt }], @@ -103,12 +196,16 @@ export class MistralHandler implements ApiHandler { const content = response.choices?.[0]?.message.content if (Array.isArray(content)) { - return content.map((c) => (c.type === "text" ? c.text : "")).join("") + const result = content.map((c) => (c.type === "text" ? 
c.text : "")).join("") + this.logDebug(`Completion result: ${result}`) + return result } + this.logDebug(`Completion result: ${content}`) return content || "" } catch (error) { if (error instanceof Error) { - throw new Error(`Mistral completion error: ${error.message}`) + this.logDebug(`Completion error: ${error.message}`) + throw new Error(`Mistral completion error: ${error.message}`, { cause: error }) } throw error } diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 6790224ecae..4e92786de76 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -127,7 +127,9 @@ type GlobalStateKey = | "requestyModelInfo" | "unboundModelInfo" | "modelTemperature" + | "stopToken" | "mistralCodestralUrl" + | "mistralModelStreamingEnabled" | "maxOpenTabsContext" export const GlobalFileNames = { @@ -1666,6 +1668,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { vsCodeLmModelSelector, mistralApiKey, mistralCodestralUrl, + mistralModelStreamingEnabled, unboundApiKey, unboundModelId, unboundModelInfo, @@ -1673,6 +1676,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { requestyModelId, requestyModelInfo, modelTemperature, + stopToken, } = apiConfiguration await Promise.all([ this.updateGlobalState("apiProvider", apiProvider), @@ -1706,6 +1710,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.storeSecret("deepSeekApiKey", deepSeekApiKey), this.updateGlobalState("azureApiVersion", azureApiVersion), this.updateGlobalState("openAiStreamingEnabled", openAiStreamingEnabled), + this.updateGlobalState("mistralModelStreamingEnabled", mistralModelStreamingEnabled), this.updateGlobalState("openRouterModelId", openRouterModelId), this.updateGlobalState("openRouterModelInfo", openRouterModelInfo), this.updateGlobalState("openRouterBaseUrl", openRouterBaseUrl), @@ -1720,6 +1725,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.updateGlobalState("requestyModelId", requestyModelId), this.updateGlobalState("requestyModelInfo", requestyModelInfo), this.updateGlobalState("modelTemperature", modelTemperature), + this.updateGlobalState("stopToken", stopToken), ]) if (this.cline) { this.cline.api = buildApiHandler(apiConfiguration) @@ -2553,6 +2559,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { deepSeekApiKey, mistralApiKey, mistralCodestralUrl, + mistralModelStreamingEnabled, azureApiVersion, openAiStreamingEnabled, openRouterModelId, @@ -2602,6 +2609,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { requestyModelId, requestyModelInfo, modelTemperature, + stopToken, maxOpenTabsContext, ] = await Promise.all([ this.getGlobalState("apiProvider") as Promise, @@ -2635,6 +2643,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getSecret("deepSeekApiKey") as Promise, this.getSecret("mistralApiKey") as Promise, this.getGlobalState("mistralCodestralUrl") as Promise, + this.getGlobalState("mistralModelStreamingEnabled") as Promise, this.getGlobalState("azureApiVersion") as Promise, this.getGlobalState("openAiStreamingEnabled") as Promise, this.getGlobalState("openRouterModelId") as Promise, @@ -2684,6 +2693,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getGlobalState("requestyModelId") as Promise, this.getGlobalState("requestyModelInfo") as Promise, this.getGlobalState("modelTemperature") as Promise, + this.getGlobalState("stopToken") as Promise, 
this.getGlobalState("maxOpenTabsContext") as Promise, ]) @@ -2734,6 +2744,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { deepSeekApiKey, mistralApiKey, mistralCodestralUrl, + mistralModelStreamingEnabled, azureApiVersion, openAiStreamingEnabled, openRouterModelId, @@ -2748,6 +2759,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { requestyModelId, requestyModelInfo, modelTemperature, + stopToken, }, lastShownAnnouncementId, customInstructions, diff --git a/src/shared/api.ts b/src/shared/api.ts index 9ecb12c1403..423956106d4 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -52,7 +52,8 @@ export interface ApiHandlerOptions { geminiApiKey?: string openAiNativeApiKey?: string mistralApiKey?: string - mistralCodestralUrl?: string // New option for Codestral URL + mistralCodestralUrl?: string + mistralModelStreamingEnabled?: boolean azureApiVersion?: string openRouterUseMiddleOutTransform?: boolean openAiStreamingEnabled?: boolean @@ -67,6 +68,7 @@ export interface ApiHandlerOptions { requestyModelId?: string requestyModelInfo?: ModelInfo modelTemperature?: number + stopToken?: string } export type ApiConfiguration = ApiHandlerOptions & { diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 1303e79c7ab..756e310ecfe 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -347,6 +347,32 @@ const ApiOptions = ({

 				)}
+
+				{/* Mistral settings additions (original JSX not recoverable from this hunk):
+				    an "Enable streaming" checkbox and an optional stop-token text field labeled
+				    "Optional: Stop Token e.g. \n\n", with the description
+				    "Optional token to stop generation when encountered". */}
+
 				)}
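
The ApiOptions.tsx hunk above adds an "Enable streaming" checkbox and an optional stop-token field to the Mistral provider section, but its JSX markup is not recoverable here. Below is a minimal sketch of what those two controls could look like, assuming the `@vscode/webview-ui-toolkit/react` components and a generic `onChange` callback; the real change lives inline in `ApiOptions.tsx` and is wired through its existing `apiConfiguration`/`handleInputChange` plumbing, and the default shown for `mistralModelStreamingEnabled` is an assumption.

```tsx
import React from "react"
import { VSCodeCheckbox, VSCodeTextField } from "@vscode/webview-ui-toolkit/react"

interface MistralStreamingOptionsProps {
	// Subset of ApiConfiguration relevant to this sketch
	mistralModelStreamingEnabled?: boolean
	stopToken?: string
	onChange: (field: "mistralModelStreamingEnabled" | "stopToken", value: boolean | string) => void
}

// Hypothetical standalone component illustrating the two new Mistral controls;
// component choice, prop wiring, and the streaming default are assumptions, not
// taken from the diff.
const MistralStreamingOptions = ({
	mistralModelStreamingEnabled,
	stopToken,
	onChange,
}: MistralStreamingOptionsProps) => (
	<div>
		<VSCodeCheckbox
			checked={mistralModelStreamingEnabled ?? true} // default assumed; the diff does not show it
			onChange={(e: any) => onChange("mistralModelStreamingEnabled", e.target.checked)}>
			Enable streaming
		</VSCodeCheckbox>
		<VSCodeTextField
			value={stopToken || ""}
			style={{ width: "100%" }}
			placeholder="e.g. \n\n"
			onInput={(e: any) => onChange("stopToken", e.target.value)}>
			<span style={{ fontWeight: 500 }}>Optional: Stop Token</span>
		</VSCodeTextField>
		<p style={{ fontSize: "12px", color: "var(--vscode-descriptionForeground)" }}>
			Optional token to stop generation when encountered
		</p>
	</div>
)

export default MistralStreamingOptions
```

Note that the provider only sends `stop` when `stopToken` is non-empty after trimming (`this.options?.stopToken?.trim()`), so leaving the field blank matches the "no stop parameter" test cases in `mistral.test.ts`.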