Skip to content

Commit 8ace0fa

Browse files
ViicosKludex
andauthored
Add Mistral provider (#1118)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 8ddd05c commit 8ace0fa

File tree

13 files changed

+247
-67
lines changed

13 files changed

+247
-67
lines changed

docs/models.md

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ pip/uv-add 'pydantic-ai-slim[mistral]'
507507

508508
To use [Mistral](https://mistral.ai) through their API, go to [console.mistral.ai/api-keys/](https://console.mistral.ai/api-keys/) and follow your nose until you find the place to generate an API key.
509509

510-
[`MistralModelName`][pydantic_ai.models.mistral.MistralModelName] contains a list of the most popular Mistral models.
510+
[`LatestMistralModelNames`][pydantic_ai.models.mistral.LatestMistralModelNames] contains a list of the most popular Mistral models.
511511

512512
### Environment variable
513513

@@ -537,15 +537,37 @@ agent = Agent(model)
537537
...
538538
```
539539

540-
### `api_key` argument
540+
### `provider` argument
541+
542+
You can provide a custom [`Provider`][pydantic_ai.providers.Provider] via the
543+
[`provider` argument][pydantic_ai.models.mistral.MistralModel.__init__]:
544+
545+
```python {title="groq_model_provider.py"}
546+
from pydantic_ai import Agent
547+
from pydantic_ai.models.mistral import MistralModel
548+
from pydantic_ai.providers.mistral import MistralProvider
549+
550+
model = MistralModel(
551+
'mistral-large-latest', provider=MistralProvider(api_key='your-api-key')
552+
)
553+
agent = Agent(model)
554+
...
555+
```
541556

542-
If you don't want to or can't set the environment variable, you can pass it at runtime via the [`api_key` argument][pydantic_ai.models.mistral.MistralModel.__init__]:
557+
You can also customize the provider with a custom `httpx.AsyncHTTPClient`:
558+
559+
```python {title="groq_model_custom_provider.py"}
560+
from httpx import AsyncClient
543561

544-
```python {title="mistral_model_api_key.py"}
545562
from pydantic_ai import Agent
546563
from pydantic_ai.models.mistral import MistralModel
564+
from pydantic_ai.providers.mistral import MistralProvider
547565

548-
model = MistralModel('mistral-small-latest', api_key='your-api-key')
566+
custom_http_client = AsyncClient(timeout=30)
567+
model = MistralModel(
568+
'mistral-large-latest',
569+
provider=MistralProvider(api_key='your-api-key', http_client=custom_http_client),
570+
)
549571
agent = Agent(model)
550572
...
551573
```

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,9 @@ def __init__(
139139

140140
if provider is not None:
141141
if isinstance(provider, str):
142-
self._system = provider
143-
self.client = infer_provider(provider).client
144-
else:
145-
self._system = provider.name
146-
self.client = provider.client
142+
provider = infer_provider(provider)
143+
self._system = provider.name
144+
self.client = provider.client
147145
self._url = str(self.client.base_url)
148146
else:
149147
if api_key is None:

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,8 @@ def __init__(
138138

139139
if provider is not None:
140140
if isinstance(provider, str):
141-
self.client = infer_provider(provider).client
142-
else:
143-
self.client = provider.client
141+
provider = infer_provider(provider)
142+
self.client = provider.client
144143
elif groq_client is not None:
145144
assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
146145
assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from dataclasses import dataclass, field
88
from datetime import datetime, timezone
99
from itertools import chain
10-
from typing import Any, Callable, Literal, Union, cast
10+
from typing import Any, Callable, Literal, Union, cast, overload
1111

1212
import pydantic_core
1313
from httpx import AsyncClient as AsyncHTTPClient, Timeout
14-
from typing_extensions import assert_never
14+
from typing_extensions import assert_never, deprecated
1515

1616
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
1717
from .._utils import now_utc as _now_utc
@@ -31,6 +31,7 @@
3131
ToolReturnPart,
3232
UserPromptPart,
3333
)
34+
from ..providers import Provider, infer_provider
3435
from ..result import Usage
3536
from ..settings import ModelSettings
3637
from ..tools import ToolDefinition
@@ -112,10 +113,33 @@ class MistralModel(Model):
112113
_model_name: MistralModelName = field(repr=False)
113114
_system: str = field(default='mistral_ai', repr=False)
114115

116+
@overload
115117
def __init__(
116118
self,
117119
model_name: MistralModelName,
118120
*,
121+
provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
122+
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
123+
) -> None: ...
124+
125+
@overload
126+
@deprecated('Use the `provider` parameter instead of `api_key`, `client` and `http_client`.')
127+
def __init__(
128+
self,
129+
model_name: MistralModelName,
130+
*,
131+
provider: None = None,
132+
api_key: str | Callable[[], str | None] | None = None,
133+
client: Mistral | None = None,
134+
http_client: AsyncHTTPClient | None = None,
135+
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
136+
) -> None: ...
137+
138+
def __init__(
139+
self,
140+
model_name: MistralModelName,
141+
*,
142+
provider: Literal['mistral'] | Provider[Mistral] | None = None,
119143
api_key: str | Callable[[], str | None] | None = None,
120144
client: Mistral | None = None,
121145
http_client: AsyncHTTPClient | None = None,
@@ -124,6 +148,9 @@ def __init__(
124148
"""Initialize a Mistral model.
125149
126150
Args:
151+
provider: The provider to use for authentication and API access. Can be either the string
152+
'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
153+
created using the other parameters.
127154
model_name: The name of the model to use.
128155
api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
129156
client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
@@ -133,17 +160,22 @@ def __init__(
133160
self._model_name = model_name
134161
self.json_mode_schema_prompt = json_mode_schema_prompt
135162

136-
if client is not None:
163+
if provider is not None:
164+
if isinstance(provider, str):
165+
# TODO(Marcelo): We should add an integration test with VCR when I get the API key.
166+
provider = infer_provider(provider) # pragma: no cover
167+
self.client = provider.client
168+
elif client is not None:
137169
assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
138170
assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
139171
self.client = client
140172
else:
141-
api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
173+
api_key = api_key or os.getenv('MISTRAL_API_KEY')
142174
self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
143175

144176
@property
145177
def base_url(self) -> str:
146-
return str(self.client.sdk_configuration.get_server_details()[0])
178+
return self.client.sdk_configuration.get_server_details()[0]
147179

148180
async def request(
149181
self,

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,8 @@ def __init__(
162162

163163
if provider is not None:
164164
if isinstance(provider, str):
165-
self.client = infer_provider(provider).client
166-
else:
167-
self.client = provider.client
165+
provider = infer_provider(provider)
166+
self.client = provider.client
168167
else: # pragma: no cover
169168
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
170169
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,13 @@ def infer_provider(provider: str) -> Provider[Any]:
6969
from .groq import GroqProvider
7070

7171
return GroqProvider()
72-
elif provider == 'anthropic': # pragma: no cover
72+
elif provider == 'anthropic':
7373
from .anthropic import AnthropicProvider
7474

7575
return AnthropicProvider()
76+
elif provider == 'mistral':
77+
from .mistral import MistralProvider
78+
79+
return MistralProvider()
7680
else: # pragma: no cover
7781
raise ValueError(f'Unknown provider: {provider}')
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import overload
5+
6+
from httpx import AsyncClient as AsyncHTTPClient
7+
8+
from pydantic_ai.models import cached_async_http_client
9+
10+
try:
11+
from mistralai import Mistral
12+
except ImportError as e: # pragma: no cover
13+
raise ImportError(
14+
'Please install the `mistral` package to use the Mistral provider, '
15+
"you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
16+
) from e
17+
18+
19+
from . import Provider
20+
21+
22+
class MistralProvider(Provider[Mistral]):
23+
"""Provider for Mistral API."""
24+
25+
@property
26+
def name(self) -> str:
27+
return 'mistral'
28+
29+
@property
30+
def base_url(self) -> str:
31+
return self.client.sdk_configuration.get_server_details()[0]
32+
33+
@property
34+
def client(self) -> Mistral:
35+
return self._client
36+
37+
@overload
38+
def __init__(self, *, mistral_client: Mistral | None = None) -> None: ...
39+
40+
@overload
41+
def __init__(self, *, api_key: str | None = None, http_client: AsyncHTTPClient | None = None) -> None: ...
42+
43+
def __init__(
44+
self,
45+
*,
46+
api_key: str | None = None,
47+
mistral_client: Mistral | None = None,
48+
http_client: AsyncHTTPClient | None = None,
49+
) -> None:
50+
"""Create a new Mistral provider.
51+
52+
Args:
53+
api_key: The API key to use for authentication, if not provided, the `MISTRAL_API_KEY` environment variable
54+
will be used if available.
55+
mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
56+
http_client: An existing async client to use for making HTTP requests.
57+
"""
58+
api_key = api_key or os.environ.get('MISTRAL_API_KEY')
59+
60+
if api_key is None and mistral_client is None:
61+
raise ValueError(
62+
'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
63+
'to use the Mistral provider.'
64+
)
65+
66+
if mistral_client is not None:
67+
assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
68+
assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
69+
self._client = mistral_client
70+
elif http_client is not None:
71+
self._client = Mistral(api_key=api_key, async_client=http_client)
72+
else:
73+
self._client = Mistral(api_key=api_key, async_client=cached_async_http_client())

0 commit comments

Comments
 (0)