Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions src/api/providers/__tests__/mistral.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import { MistralHandler } from "../mistral"
import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiStreamTextChunk } from "../../transform/stream"

// Mock Mistral client
const mockCreate = jest.fn()
jest.mock("@mistralai/mistralai", () => {
return {
Mistral: jest.fn().mockImplementation(() => ({
chat: {
stream: mockCreate.mockImplementation(async (options) => {
const stream = {
[Symbol.asyncIterator]: async function* () {
yield {
data: {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
},
}
},
}
return stream
}),
},
})),
}
})

describe("MistralHandler", () => {
let handler: MistralHandler
let mockOptions: ApiHandlerOptions

beforeEach(() => {
mockOptions = {
apiModelId: "codestral-latest", // Update to match the actual model ID
mistralApiKey: "test-api-key",
includeMaxTokens: true,
modelTemperature: 0,
}
handler = new MistralHandler(mockOptions)
mockCreate.mockClear()
})

describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(MistralHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
})

it("should throw error if API key is missing", () => {
expect(() => {
new MistralHandler({
...mockOptions,
mistralApiKey: undefined,
})
}).toThrow("Mistral API key is required")
})

it("should use custom base URL if provided", () => {
const customBaseUrl = "https://custom.mistral.ai/v1"
const handlerWithCustomUrl = new MistralHandler({
...mockOptions,
mistralCodestralUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
})
})

describe("getModel", () => {
it("should return correct model info", () => {
const model = handler.getModel()
expect(model.id).toBe(mockOptions.apiModelId)
expect(model.info).toBeDefined()
expect(model.info.supportsPromptCache).toBe(false)
})
})

describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [{ type: "text", text: "Hello!" }],
},
]

it("should create message successfully", async () => {
const iterator = handler.createMessage(systemPrompt, messages)
const result = await iterator.next()

expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: expect.any(Array),
maxTokens: expect.any(Number),
temperature: 0,
})

expect(result.value).toBeDefined()
expect(result.done).toBe(false)
})

it("should handle streaming response correctly", async () => {
const iterator = handler.createMessage(systemPrompt, messages)
const results: ApiStreamTextChunk[] = []

for await (const chunk of iterator) {
if ("text" in chunk) {
results.push(chunk as ApiStreamTextChunk)
}
}

expect(results.length).toBeGreaterThan(0)
expect(results[0].text).toBe("Test response")
})

it("should handle errors gracefully", async () => {
mockCreate.mockRejectedValueOnce(new Error("API Error"))
await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
})
})
})
27 changes: 20 additions & 7 deletions src/api/providers/mistral.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,36 @@ export class MistralHandler implements ApiHandler {
private client: Mistral

constructor(options: ApiHandlerOptions) {
if (!options.mistralApiKey) {
throw new Error("Mistral API key is required")
}

this.options = options
const baseUrl = this.getBaseUrl()
console.debug(`[Roo Code] MistralHandler using baseUrl: ${baseUrl}`)
this.client = new Mistral({
serverURL: "https://codestral.mistral.ai",
serverURL: baseUrl,
apiKey: this.options.mistralApiKey,
})
}

private getBaseUrl(): string {
const modelId = this.options.apiModelId
if (modelId?.startsWith("codestral-")) {
return this.options.mistralCodestralUrl || "https://codestral.mistral.ai"
}
return "https://api.mistral.ai"
}

async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const stream = await this.client.chat.stream({
model: this.getModel().id,
// max_completion_tokens: this.getModel().info.maxTokens,
const response = await this.client.chat.stream({
model: this.options.apiModelId || mistralDefaultModelId,
messages: convertToMistralMessages(messages),
maxTokens: this.options.includeMaxTokens ? this.getModel().info.maxTokens : undefined,
temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
stream: true,
})

for await (const chunk of stream) {
for await (const chunk of response) {
const delta = chunk.data.choices[0]?.delta
if (delta?.content) {
let content: string = ""
Expand Down
6 changes: 6 additions & 0 deletions src/core/webview/ClineProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type GlobalStateKey =
| "requestyModelInfo"
| "unboundModelInfo"
| "modelTemperature"
| "mistralCodestralUrl"
| "maxOpenTabsContext"

export const GlobalFileNames = {
Expand Down Expand Up @@ -1637,6 +1638,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openRouterUseMiddleOutTransform,
vsCodeLmModelSelector,
mistralApiKey,
mistralCodestralUrl,
unboundApiKey,
unboundModelId,
unboundModelInfo,
Expand Down Expand Up @@ -1682,6 +1684,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector)
await this.storeSecret("mistralApiKey", mistralApiKey)
await this.updateGlobalState("mistralCodestralUrl", mistralCodestralUrl)
await this.storeSecret("unboundApiKey", unboundApiKey)
await this.updateGlobalState("unboundModelId", unboundModelId)
await this.updateGlobalState("unboundModelInfo", unboundModelInfo)
Expand Down Expand Up @@ -2521,6 +2524,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiNativeApiKey,
deepSeekApiKey,
mistralApiKey,
mistralCodestralUrl,
azureApiVersion,
openAiStreamingEnabled,
openRouterModelId,
Expand Down Expand Up @@ -2602,6 +2606,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
this.getSecret("mistralApiKey") as Promise<string | undefined>,
this.getGlobalState("mistralCodestralUrl") as Promise<string | undefined>,
this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
Expand Down Expand Up @@ -2700,6 +2705,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
openAiNativeApiKey,
deepSeekApiKey,
mistralApiKey,
mistralCodestralUrl,
azureApiVersion,
openAiStreamingEnabled,
openRouterModelId,
Expand Down
43 changes: 42 additions & 1 deletion src/shared/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export interface ApiHandlerOptions {
geminiApiKey?: string
openAiNativeApiKey?: string
mistralApiKey?: string
mistralCodestralUrl?: string // New option for Codestral URL
azureApiVersion?: string
openRouterUseMiddleOutTransform?: boolean
openAiStreamingEnabled?: boolean
Expand Down Expand Up @@ -670,13 +671,53 @@ export type MistralModelId = keyof typeof mistralModels
export const mistralDefaultModelId: MistralModelId = "codestral-latest"
export const mistralModels = {
"codestral-latest": {
maxTokens: 32_768,
maxTokens: 256_000,
contextWindow: 256_000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.3,
outputPrice: 0.9,
},
"mistral-large-latest": {
maxTokens: 131_000,
contextWindow: 131_000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 2.0,
outputPrice: 6.0,
},
"ministral-8b-latest": {
maxTokens: 131_000,
contextWindow: 131_000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.1,
outputPrice: 0.1,
},
"ministral-3b-latest": {
maxTokens: 131_000,
contextWindow: 131_000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.04,
outputPrice: 0.04,
},
"mistral-small-latest": {
maxTokens: 32_000,
contextWindow: 32_000,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.2,
outputPrice: 0.6,
},
"pixtral-large-latest": {
maxTokens: 131_000,
contextWindow: 131_000,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 2.0,
outputPrice: 6.0,
},
} as const satisfies Record<string, ModelInfo>

// Unbound Security
Expand Down
27 changes: 25 additions & 2 deletions webview-ui/src/components/settings/ApiOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A
placeholder="Enter API Key...">
<span style={{ fontWeight: 500 }}>Mistral API Key</span>
</VSCodeTextField>

<p
style={{
fontSize: "12px",
Expand All @@ -323,15 +324,37 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A
This key is stored locally and only used to make API requests from this extension.
{!apiConfiguration?.mistralApiKey && (
<VSCodeLink
href="https://console.mistral.ai/codestral/"
href="https://console.mistral.ai/"
style={{
display: "inline",
fontSize: "inherit",
}}>
You can get a Mistral API key by signing up here.
You can get a La Plateforme (api.mistral.ai) / Codestral (codestral.mistral.ai) API key
by signing up here.
</VSCodeLink>
)}
</p>

{apiConfiguration?.apiModelId?.startsWith("codestral-") && (
<div>
<VSCodeTextField
value={apiConfiguration?.mistralCodestralUrl || ""}
style={{ width: "100%", marginTop: "10px" }}
type="url"
onBlur={handleInputChange("mistralCodestralUrl")}
placeholder="Default: https://codestral.mistral.ai">
<span style={{ fontWeight: 500 }}>Codestral Base URL (Optional)</span>
</VSCodeTextField>
<p
style={{
fontSize: "12px",
marginTop: 3,
color: "var(--vscode-descriptionForeground)",
}}>
Set alternative URL for Codestral model: https://api.mistral.ai
</p>
</div>
)}
</div>
)}

Expand Down
Loading