diff --git a/src/api/providers/__tests__/mistral.test.ts b/src/api/providers/__tests__/mistral.test.ts new file mode 100644 index 00000000000..781cb3dcfc5 --- /dev/null +++ b/src/api/providers/__tests__/mistral.test.ts @@ -0,0 +1,126 @@ +import { MistralHandler } from "../mistral" +import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api" +import { Anthropic } from "@anthropic-ai/sdk" +import { ApiStreamTextChunk } from "../../transform/stream" + +// Mock Mistral client +const mockCreate = jest.fn() +jest.mock("@mistralai/mistralai", () => { + return { + Mistral: jest.fn().mockImplementation(() => ({ + chat: { + stream: mockCreate.mockImplementation(async (options) => { + const stream = { + [Symbol.asyncIterator]: async function* () { + yield { + data: { + choices: [ + { + delta: { content: "Test response" }, + index: 0, + }, + ], + }, + } + }, + } + return stream + }), + }, + })), + } +}) + +describe("MistralHandler", () => { + let handler: MistralHandler + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + apiModelId: "codestral-latest", // Update to match the actual model ID + mistralApiKey: "test-api-key", + includeMaxTokens: true, + modelTemperature: 0, + } + handler = new MistralHandler(mockOptions) + mockCreate.mockClear() + }) + + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(MistralHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) + + it("should throw error if API key is missing", () => { + expect(() => { + new MistralHandler({ + ...mockOptions, + mistralApiKey: undefined, + }) + }).toThrow("Mistral API key is required") + }) + + it("should use custom base URL if provided", () => { + const customBaseUrl = "https://custom.mistral.ai/v1" + const handlerWithCustomUrl = new MistralHandler({ + ...mockOptions, + mistralCodestralUrl: customBaseUrl, + }) + expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler) + }) + }) + + describe("getModel", () => { + it("should return correct model info", () => { + const model = handler.getModel() + expect(model.id).toBe(mockOptions.apiModelId) + expect(model.info).toBeDefined() + expect(model.info.supportsPromptCache).toBe(false) + }) + }) + + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text", text: "Hello!" }], + }, + ] + + it("should create message successfully", async () => { + const iterator = handler.createMessage(systemPrompt, messages) + const result = await iterator.next() + + expect(mockCreate).toHaveBeenCalledWith({ + model: mockOptions.apiModelId, + messages: expect.any(Array), + maxTokens: expect.any(Number), + temperature: 0, + }) + + expect(result.value).toBeDefined() + expect(result.done).toBe(false) + }) + + it("should handle streaming response correctly", async () => { + const iterator = handler.createMessage(systemPrompt, messages) + const results: ApiStreamTextChunk[] = [] + + for await (const chunk of iterator) { + if ("text" in chunk) { + results.push(chunk as ApiStreamTextChunk) + } + } + + expect(results.length).toBeGreaterThan(0) + expect(results[0].text).toBe("Test response") + }) + + it("should handle errors gracefully", async () => { + mockCreate.mockRejectedValueOnce(new Error("API Error")) + await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error") + }) + }) +}) diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index 9ce70a297cb..6582f5d2209 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -21,23 +21,36 @@ export class MistralHandler implements ApiHandler { private client: Mistral constructor(options: ApiHandlerOptions) { + if (!options.mistralApiKey) { + throw new Error("Mistral API key is required") + } + this.options = options + const baseUrl = this.getBaseUrl() + console.debug(`[Roo Code] MistralHandler using baseUrl: ${baseUrl}`) this.client = new Mistral({ - serverURL: "https://codestral.mistral.ai", + serverURL: baseUrl, apiKey: this.options.mistralApiKey, }) } + private getBaseUrl(): string { + const modelId = this.options.apiModelId + if (modelId?.startsWith("codestral-")) { + return this.options.mistralCodestralUrl || "https://codestral.mistral.ai" + } + return "https://api.mistral.ai" + } + async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream { - const stream = await this.client.chat.stream({ - model: this.getModel().id, - // max_completion_tokens: this.getModel().info.maxTokens, + const response = await this.client.chat.stream({ + model: this.options.apiModelId || mistralDefaultModelId, + messages: convertToMistralMessages(messages), + maxTokens: this.options.includeMaxTokens ? this.getModel().info.maxTokens : undefined, temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE, - messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)], - stream: true, }) - for await (const chunk of stream) { + for await (const chunk of response) { const delta = chunk.data.choices[0]?.delta if (delta?.content) { let content: string = "" diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 0d59d78e716..4375cf4da18 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -127,6 +127,7 @@ type GlobalStateKey = | "requestyModelInfo" | "unboundModelInfo" | "modelTemperature" + | "mistralCodestralUrl" | "maxOpenTabsContext" export const GlobalFileNames = { @@ -1637,6 +1638,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openRouterUseMiddleOutTransform, vsCodeLmModelSelector, mistralApiKey, + mistralCodestralUrl, unboundApiKey, unboundModelId, unboundModelInfo, @@ -1682,6 +1684,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform) await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector) await this.storeSecret("mistralApiKey", mistralApiKey) + await this.updateGlobalState("mistralCodestralUrl", mistralCodestralUrl) await this.storeSecret("unboundApiKey", unboundApiKey) await this.updateGlobalState("unboundModelId", unboundModelId) await this.updateGlobalState("unboundModelInfo", unboundModelInfo) @@ -2521,6 +2524,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openAiNativeApiKey, deepSeekApiKey, mistralApiKey, + mistralCodestralUrl, azureApiVersion, openAiStreamingEnabled, openRouterModelId, @@ -2602,6 +2606,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.getSecret("openAiNativeApiKey") as Promise, this.getSecret("deepSeekApiKey") as Promise, this.getSecret("mistralApiKey") as Promise, + this.getGlobalState("mistralCodestralUrl") as Promise, this.getGlobalState("azureApiVersion") as Promise, this.getGlobalState("openAiStreamingEnabled") as Promise, this.getGlobalState("openRouterModelId") as Promise, @@ -2700,6 +2705,7 @@ export class ClineProvider implements vscode.WebviewViewProvider { openAiNativeApiKey, deepSeekApiKey, mistralApiKey, + mistralCodestralUrl, azureApiVersion, openAiStreamingEnabled, openRouterModelId, diff --git a/src/shared/api.ts b/src/shared/api.ts index 7e926b09cfa..5ad9df8dfa7 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -52,6 +52,7 @@ export interface ApiHandlerOptions { geminiApiKey?: string openAiNativeApiKey?: string mistralApiKey?: string + mistralCodestralUrl?: string // New option for Codestral URL azureApiVersion?: string openRouterUseMiddleOutTransform?: boolean openAiStreamingEnabled?: boolean @@ -670,13 +671,53 @@ export type MistralModelId = keyof typeof mistralModels export const mistralDefaultModelId: MistralModelId = "codestral-latest" export const mistralModels = { "codestral-latest": { - maxTokens: 32_768, + maxTokens: 256_000, contextWindow: 256_000, supportsImages: false, supportsPromptCache: false, inputPrice: 0.3, outputPrice: 0.9, }, + "mistral-large-latest": { + maxTokens: 131_000, + contextWindow: 131_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 2.0, + outputPrice: 6.0, + }, + "ministral-8b-latest": { + maxTokens: 131_000, + contextWindow: 131_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.1, + outputPrice: 0.1, + }, + "ministral-3b-latest": { + maxTokens: 131_000, + contextWindow: 131_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.04, + outputPrice: 0.04, + }, + "mistral-small-latest": { + maxTokens: 32_000, + contextWindow: 32_000, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.2, + outputPrice: 0.6, + }, + "pixtral-large-latest": { + maxTokens: 131_000, + contextWindow: 131_000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 2.0, + outputPrice: 6.0, + }, } as const satisfies Record // Unbound Security diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 0d8bb43bb5b..125c3d37857 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -314,6 +314,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A placeholder="Enter API Key..."> Mistral API Key +

- You can get a Mistral API key by signing up here. + You can get a La Plateforme (api.mistral.ai) / Codestral (codestral.mistral.ai) API key + by signing up here. )}

+ + {apiConfiguration?.apiModelId?.startsWith("codestral-") && ( +
+ + Codestral Base URL (Optional) + +

+ Set alternative URL for Codestral model: https://api.mistral.ai +

+
+ )} )}