Skip to content

Commit 2b1e3f5

Browse files
committed
Add Mistral API provider
1 parent 06146d5 commit 2b1e3f5

File tree

10 files changed

+248
-2
lines changed

10 files changed

+248
-2
lines changed

package-lock.json

Lines changed: 11 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
"@anthropic-ai/sdk": "^0.26.0",
170170
"@anthropic-ai/vertex-sdk": "^0.4.1",
171171
"@google/generative-ai": "^0.18.0",
172+
"@mistralai/mistralai": "^1.3.6",
172173
"@modelcontextprotocol/sdk": "^1.0.1",
173174
"@types/clone-deep": "^4.0.4",
174175
"@types/get-folder-size": "^3.0.4",

src/api/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { GeminiHandler } from "./providers/gemini"
1111
import { OpenAiNativeHandler } from "./providers/openai-native"
1212
import { ApiStream } from "./transform/stream"
1313
import { DeepSeekHandler } from "./providers/deepseek"
14+
import { MistralHandler } from "./providers/mistral"
1415

1516
export interface ApiHandler {
1617
createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
@@ -40,6 +41,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
4041
return new OpenAiNativeHandler(options)
4142
case "deepseek":
4243
return new DeepSeekHandler(options)
44+
case "mistral":
45+
return new MistralHandler(options)
4346
default:
4447
return new AnthropicHandler(options)
4548
}

src/api/providers/mistral.ts

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import { Anthropic } from "@anthropic-ai/sdk"
2+
import { Mistral } from "@mistralai/mistralai"
3+
import { ApiHandler } from "../"
4+
import {
5+
ApiHandlerOptions,
6+
mistralDefaultModelId,
7+
MistralModelId,
8+
mistralModels,
9+
ModelInfo,
10+
openAiNativeDefaultModelId,
11+
OpenAiNativeModelId,
12+
openAiNativeModels,
13+
} from "../../shared/api"
14+
import { convertToMistralMessages } from "../transform/mistral-format"
15+
import { ApiStream } from "../transform/stream"
16+
17+
export class MistralHandler implements ApiHandler {
18+
private options: ApiHandlerOptions
19+
private client: Mistral
20+
21+
constructor(options: ApiHandlerOptions) {
22+
this.options = options
23+
this.client = new Mistral({
24+
serverURL: "https://codestral.mistral.ai",
25+
apiKey: this.options.mistralApiKey,
26+
})
27+
}
28+
29+
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
30+
const stream = await this.client.chat.stream({
31+
model: this.getModel().id,
32+
// max_completion_tokens: this.getModel().info.maxTokens,
33+
temperature: 0,
34+
messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
35+
stream: true,
36+
})
37+
38+
for await (const chunk of stream) {
39+
const delta = chunk.data.choices[0]?.delta
40+
if (delta?.content) {
41+
let content: string = ""
42+
if (typeof delta.content === "string") {
43+
content = delta.content
44+
} else if (Array.isArray(delta.content)) {
45+
content = delta.content.map((c) => (c.type === "text" ? c.text : "")).join("")
46+
}
47+
yield {
48+
type: "text",
49+
text: content,
50+
}
51+
}
52+
53+
if (chunk.data.usage) {
54+
yield {
55+
type: "usage",
56+
inputTokens: chunk.data.usage.promptTokens || 0,
57+
outputTokens: chunk.data.usage.completionTokens || 0,
58+
}
59+
}
60+
}
61+
}
62+
63+
getModel(): { id: MistralModelId; info: ModelInfo } {
64+
const modelId = this.options.apiModelId
65+
if (modelId && modelId in mistralModels) {
66+
const id = modelId as MistralModelId
67+
return { id, info: mistralModels[id] }
68+
}
69+
return {
70+
id: mistralDefaultModelId,
71+
info: mistralModels[mistralDefaultModelId],
72+
}
73+
}
74+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import { Anthropic } from "@anthropic-ai/sdk"
2+
import { Mistral } from "@mistralai/mistralai"
3+
import { AssistantMessage } from "@mistralai/mistralai/models/components/assistantmessage"
4+
import { SystemMessage } from "@mistralai/mistralai/models/components/systemmessage"
5+
import { ToolMessage } from "@mistralai/mistralai/models/components/toolmessage"
6+
import { UserMessage } from "@mistralai/mistralai/models/components/usermessage"
7+
8+
export type MistralMessage =
9+
| (SystemMessage & { role: "system" })
10+
| (UserMessage & { role: "user" })
11+
| (AssistantMessage & { role: "assistant" })
12+
| (ToolMessage & { role: "tool" })
13+
14+
export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): MistralMessage[] {
15+
const mistralMessages: MistralMessage[] = []
16+
for (const anthropicMessage of anthropicMessages) {
17+
if (typeof anthropicMessage.content === "string") {
18+
mistralMessages.push({
19+
role: anthropicMessage.role,
20+
content: anthropicMessage.content,
21+
})
22+
} else {
23+
if (anthropicMessage.role === "user") {
24+
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
25+
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
26+
toolMessages: Anthropic.ToolResultBlockParam[]
27+
}>(
28+
(acc, part) => {
29+
if (part.type === "tool_result") {
30+
acc.toolMessages.push(part)
31+
} else if (part.type === "text" || part.type === "image") {
32+
acc.nonToolMessages.push(part)
33+
} // user cannot send tool_use messages
34+
return acc
35+
},
36+
{ nonToolMessages: [], toolMessages: [] },
37+
)
38+
39+
if (nonToolMessages.length > 0) {
40+
mistralMessages.push({
41+
role: "user",
42+
content: nonToolMessages.map((part) => {
43+
if (part.type === "image") {
44+
return {
45+
type: "image_url",
46+
imageUrl: {
47+
url: `data:${part.source.media_type};base64,${part.source.data}`,
48+
},
49+
}
50+
}
51+
return { type: "text", text: part.text }
52+
}),
53+
})
54+
}
55+
} else if (anthropicMessage.role === "assistant") {
56+
const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{
57+
nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[]
58+
toolMessages: Anthropic.ToolUseBlockParam[]
59+
}>(
60+
(acc, part) => {
61+
if (part.type === "tool_use") {
62+
acc.toolMessages.push(part)
63+
} else if (part.type === "text" || part.type === "image") {
64+
acc.nonToolMessages.push(part)
65+
} // assistant cannot send tool_result messages
66+
return acc
67+
},
68+
{ nonToolMessages: [], toolMessages: [] },
69+
)
70+
71+
let content: string | undefined
72+
if (nonToolMessages.length > 0) {
73+
content = nonToolMessages
74+
.map((part) => {
75+
if (part.type === "image") {
76+
return "" // impossible as the assistant cannot send images
77+
}
78+
return part.text
79+
})
80+
.join("\n")
81+
}
82+
83+
mistralMessages.push({
84+
role: "assistant",
85+
content,
86+
})
87+
}
88+
}
89+
}
90+
91+
return mistralMessages
92+
}

src/core/webview/ClineProvider.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type SecretKey =
4141
| "geminiApiKey"
4242
| "openAiNativeApiKey"
4343
| "deepSeekApiKey"
44+
| "mistralApiKey"
4445
type GlobalStateKey =
4546
| "apiProvider"
4647
| "apiModelId"
@@ -392,6 +393,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
392393
geminiApiKey,
393394
openAiNativeApiKey,
394395
deepSeekApiKey,
396+
mistralApiKey,
395397
azureApiVersion,
396398
openRouterModelId,
397399
openRouterModelInfo,
@@ -418,6 +420,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
418420
await this.storeSecret("geminiApiKey", geminiApiKey)
419421
await this.storeSecret("openAiNativeApiKey", openAiNativeApiKey)
420422
await this.storeSecret("deepSeekApiKey", deepSeekApiKey)
423+
await this.storeSecret("mistralApiKey", mistralApiKey)
421424
await this.updateGlobalState("azureApiVersion", azureApiVersion)
422425
await this.updateGlobalState("openRouterModelId", openRouterModelId)
423426
await this.updateGlobalState("openRouterModelInfo", openRouterModelInfo)
@@ -1023,6 +1026,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
10231026
geminiApiKey,
10241027
openAiNativeApiKey,
10251028
deepSeekApiKey,
1029+
mistralApiKey,
10261030
azureApiVersion,
10271031
openRouterModelId,
10281032
openRouterModelInfo,
@@ -1054,6 +1058,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
10541058
this.getSecret("geminiApiKey") as Promise<string | undefined>,
10551059
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
10561060
this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
1061+
this.getSecret("mistralApiKey") as Promise<string | undefined>,
10571062
this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
10581063
this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
10591064
this.getGlobalState("openRouterModelInfo") as Promise<ModelInfo | undefined>,
@@ -1102,6 +1107,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
11021107
geminiApiKey,
11031108
openAiNativeApiKey,
11041109
deepSeekApiKey,
1110+
mistralApiKey,
11051111
azureApiVersion,
11061112
openRouterModelId,
11071113
openRouterModelInfo,
@@ -1187,6 +1193,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
11871193
"geminiApiKey",
11881194
"openAiNativeApiKey",
11891195
"deepSeekApiKey",
1196+
"mistralApiKey",
11901197
]
11911198
for (const key of secretKeys) {
11921199
await this.storeSecret(key, undefined)

src/shared/api.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export type ApiProvider =
99
| "gemini"
1010
| "openai-native"
1111
| "deepseek"
12+
| "mistral"
1213

1314
export interface ApiHandlerOptions {
1415
apiModelId?: string
@@ -34,6 +35,7 @@ export interface ApiHandlerOptions {
3435
geminiApiKey?: string
3536
openAiNativeApiKey?: string
3637
deepSeekApiKey?: string
38+
mistralApiKey?: string
3739
azureApiVersion?: string
3840
}
3941

@@ -374,3 +376,18 @@ export const deepSeekModels = {
374376
cacheReadsPrice: 0.014,
375377
},
376378
} as const satisfies Record<string, ModelInfo>
379+
380+
// Mistral
381+
// https://docs.mistral.ai/getting-started/models/models_overview/
382+
export type MistralModelId = keyof typeof mistralModels
383+
export const mistralDefaultModelId: MistralModelId = "codestral-latest"
384+
export const mistralModels = {
385+
"codestral-latest": {
386+
maxTokens: 32_768,
387+
contextWindow: 256_000,
388+
supportsImages: false,
389+
supportsPromptCache: false,
390+
inputPrice: 0.3,
391+
outputPrice: 0.9,
392+
},
393+
} as const satisfies Record<string, ModelInfo>

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import {
2121
deepSeekModels,
2222
geminiDefaultModelId,
2323
geminiModels,
24+
mistralDefaultModelId,
25+
mistralModels,
2426
openAiModelInfoSaneDefaults,
2527
openAiNativeDefaultModelId,
2628
openAiNativeModels,
@@ -142,6 +144,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
142144
<VSCodeOption value="anthropic">Anthropic</VSCodeOption>
143145
<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
144146
<VSCodeOption value="deepseek">DeepSeek</VSCodeOption>
147+
<VSCodeOption value="mistral">Mistral</VSCodeOption>
145148
<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
146149
<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
147150
<VSCodeOption value="openai-native">OpenAI</VSCodeOption>
@@ -270,6 +273,37 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
270273
</div>
271274
)}
272275

276+
{selectedProvider === "mistral" && (
277+
<div>
278+
<VSCodeTextField
279+
value={apiConfiguration?.mistralApiKey || ""}
280+
style={{ width: "100%" }}
281+
type="password"
282+
onInput={handleInputChange("mistralApiKey")}
283+
placeholder="Enter API Key...">
284+
<span style={{ fontWeight: 500 }}>Mistral API Key</span>
285+
</VSCodeTextField>
286+
<p
287+
style={{
288+
fontSize: "12px",
289+
marginTop: 3,
290+
color: "var(--vscode-descriptionForeground)",
291+
}}>
292+
This key is stored locally and only used to make API requests from this extension.
293+
{!apiConfiguration?.mistralApiKey && (
294+
<VSCodeLink
295+
href="https://console.mistral.ai/codestral/"
296+
style={{
297+
display: "inline",
298+
fontSize: "inherit",
299+
}}>
300+
You can get a Mistral API key by signing up here.
301+
</VSCodeLink>
302+
)}
303+
</p>
304+
</div>
305+
)}
306+
273307
{selectedProvider === "openrouter" && (
274308
<div>
275309
<VSCodeTextField
@@ -697,6 +731,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage, modelIdErrorMessage }:
697731
{selectedProvider === "gemini" && createDropdown(geminiModels)}
698732
{selectedProvider === "openai-native" && createDropdown(openAiNativeModels)}
699733
{selectedProvider === "deepseek" && createDropdown(deepSeekModels)}
734+
{selectedProvider === "mistral" && createDropdown(mistralModels)}
700735
</div>
701736

702737
<ModelInfoView
@@ -893,6 +928,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
893928
return getProviderData(openAiNativeModels, openAiNativeDefaultModelId)
894929
case "deepseek":
895930
return getProviderData(deepSeekModels, deepSeekDefaultModelId)
931+
case "mistral":
932+
return getProviderData(mistralModels, mistralDefaultModelId)
896933
case "openrouter":
897934
return {
898935
selectedProvider: provider,

0 commit comments

Comments
 (0)