Skip to content

Commit 99ad3aa

Browse files
committed
Renmae GrokModel to XaiModel, and update dependencies and tests. Added XaiProvider and move api_key and client to this
1 parent 843f588 commit 99ad3aa

File tree

17 files changed

+423
-278
lines changed

17 files changed

+423
-278
lines changed

docs/models/grok.md

Lines changed: 0 additions & 68 deletions
This file was deleted.

docs/models/xai.md

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# xAI
2+
3+
## Install
4+
5+
To use [`XaiModel`][pydantic_ai.models.xai.XaiModel], you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `xai` optional group:
6+
7+
```bash
8+
pip/uv-add "pydantic-ai-slim[xai]"
9+
```
10+
11+
## Configuration
12+
13+
To use xAI models from [xAI](https://x.ai/api) through their API, go to [console.x.ai](https://console.x.ai/team/default/api-keys) to create an API key.
14+
15+
[`GrokModelName`][pydantic_ai.providers.grok.GrokModelName] contains a list of available xAI models.
16+
17+
## Environment variable
18+
19+
Once you have the API key, you can set it as an environment variable:
20+
21+
```bash
22+
export XAI_API_KEY='your-api-key'
23+
```
24+
25+
You can then use [`XaiModel`][pydantic_ai.models.xai.XaiModel] by name:
26+
27+
```python
28+
from pydantic_ai import Agent
29+
30+
agent = Agent('xai:grok-4-1-fast-non-reasoning')
31+
...
32+
```
33+
34+
Or initialise the model directly:
35+
36+
```python
37+
from pydantic_ai import Agent
38+
from pydantic_ai.models.xai import XaiModel
39+
40+
# Uses XAI_API_KEY environment variable
41+
model = XaiModel('grok-4-1-fast-non-reasoning')
42+
agent = Agent(model)
43+
...
44+
```
45+
46+
You can also customize the [`XaiModel`][pydantic_ai.models.xai.XaiModel] with a custom provider:
47+
48+
```python
49+
from pydantic_ai import Agent
50+
from pydantic_ai.models.xai import XaiModel
51+
from pydantic_ai.providers.xai import XaiProvider
52+
53+
# Custom API key
54+
provider = XaiProvider(api_key='your-api-key')
55+
model = XaiModel('grok-4-1-fast-non-reasoning', provider=provider)
56+
agent = Agent(model)
57+
...
58+
```
59+
60+
Or with a custom `xai_sdk.AsyncClient`:
61+
62+
```python
63+
from xai_sdk import AsyncClient
64+
from pydantic_ai import Agent
65+
from pydantic_ai.models.xai import XaiModel
66+
from pydantic_ai.providers.xai import XaiProvider
67+
68+
xai_client = AsyncClient(api_key='your-api-key')
69+
provider = XaiProvider(xai_client=xai_client)
70+
model = XaiModel('grok-4-1-fast-non-reasoning', provider=provider)
71+
agent = Agent(model)
72+
...
73+
```

examples/pydantic_ai_examples/stock_analysis_agent.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
2. Provides buy analysis for the user
66
"""
77

8-
import os
9-
108
import logfire
119
from pydantic import BaseModel, Field
1210

@@ -15,19 +13,16 @@
1513
BuiltinToolCallPart,
1614
WebSearchTool,
1715
)
18-
from pydantic_ai.models.grok import GrokModel
16+
from pydantic_ai.models.xai import XaiModel
1917

2018
logfire.configure()
2119
logfire.instrument_pydantic_ai()
2220

23-
# Configure for xAI API
24-
xai_api_key = os.getenv('XAI_API_KEY')
25-
if not xai_api_key:
26-
raise ValueError('XAI_API_KEY environment variable is required')
27-
21+
# Configure for xAI API - XAI_API_KEY environment variable is required
22+
# The model will automatically use XaiProvider with the API key from the environment
2823

29-
# Create the model using GrokModel with server-side tools
30-
model = GrokModel('grok-4-fast', api_key=xai_api_key)
24+
# Create the model using XaiModel with server-side tools
25+
model = XaiModel('grok-4-fast')
3126

3227

3328
class StockAnalysis(BaseModel):

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,19 @@
180180
'grok:grok-4-fast-non-reasoning',
181181
'grok:grok-4-fast-reasoning',
182182
'grok:grok-code-fast-1',
183+
'xai:grok-2-image-1212',
184+
'xai:grok-2-vision-1212',
185+
'xai:grok-3',
186+
'xai:grok-3-fast',
187+
'xai:grok-3-mini',
188+
'xai:grok-3-mini-fast',
189+
'xai:grok-4',
190+
'xai:grok-4-0709',
191+
'xai:grok-4-1-fast-non-reasoning',
192+
'xai:grok-4-1-fast-reasoning',
193+
'xai:grok-4-fast-non-reasoning',
194+
'xai:grok-4-fast-reasoning',
195+
'xai:grok-code-fast-1',
183196
'groq:deepseek-r1-distill-llama-70b',
184197
'groq:deepseek-r1-distill-qwen-32b',
185198
'groq:distil-whisper-large-v3-en',
@@ -809,6 +822,7 @@ def infer_model( # noqa: C901
809822
'fireworks',
810823
'github',
811824
'grok',
825+
'xai',
812826
'heroku',
813827
'moonshotai',
814828
'ollama',

pydantic_ai_slim/pydantic_ai/models/grok.py renamed to pydantic_ai_slim/pydantic_ai/models/xai.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1-
"""Grok model implementation using xAI SDK."""
1+
"""xAI model implementation using xAI SDK."""
22

3-
import os
43
from collections.abc import AsyncIterator, Iterator, Sequence
54
from contextlib import asynccontextmanager
65
from dataclasses import dataclass
7-
from typing import Any
6+
from typing import Any, Literal
87

9-
import xai_sdk.chat as chat_types
8+
try:
9+
import xai_sdk.chat as chat_types
1010

11-
# Import xai_sdk components
12-
from xai_sdk import AsyncClient
13-
from xai_sdk.chat import assistant, image, system, tool, tool_result, user
14-
from xai_sdk.tools import code_execution, get_tool_call_type, mcp, web_search # x_search not yet supported
11+
# Import xai_sdk components
12+
from xai_sdk import AsyncClient
13+
from xai_sdk.chat import assistant, image, system, tool, tool_result, user
14+
from xai_sdk.tools import code_execution, get_tool_call_type, mcp, web_search # x_search not yet supported
15+
except ImportError as _import_error:
16+
raise ImportError(
17+
'Please install `xai-sdk` to use the xAI model, '
18+
'you can use the `xai` optional group — `pip install "pydantic-ai-slim[xai]"`'
19+
) from _import_error
1520

1621
from .._run_context import RunContext
1722
from .._utils import now_utc
@@ -41,42 +46,46 @@
4146
ModelRequestParameters,
4247
StreamedResponse,
4348
)
49+
from ..profiles import ModelProfileSpec
50+
from ..providers import Provider, infer_provider
51+
from ..providers.grok import GrokModelName
4452
from ..settings import ModelSettings
4553
from ..usage import RequestUsage
4654

55+
# Type alias for consistency
56+
XaiModelName = GrokModelName
4757

48-
class GrokModel(Model):
49-
"""A model that uses the xAI SDK to interact with Grok."""
58+
59+
class XaiModel(Model):
60+
"""A model that uses the xAI SDK to interact with xAI models."""
5061

5162
_model_name: str
52-
_api_key: str
53-
_client: AsyncClient | None
63+
_provider: Provider[AsyncClient]
5464

5565
def __init__(
5666
self,
57-
model_name: str,
67+
model_name: XaiModelName,
5868
*,
59-
api_key: str | None = None,
60-
client: AsyncClient | None = None,
69+
provider: Literal['xai'] | Provider[AsyncClient] = 'xai',
70+
profile: ModelProfileSpec | None = None,
6171
settings: ModelSettings | None = None,
6272
):
63-
"""Initialize the Grok model.
73+
"""Initialize the xAI model.
6474
6575
Args:
66-
model_name: The name of the Grok model to use (e.g., "grok-4-1-fast-non-reasoning")
67-
api_key: The xAI API key. If not provided, uses XAI_API_KEY environment variable.
68-
client: Optional AsyncClient instance for testing. If provided, api_key is ignored.
76+
model_name: The name of the xAI model to use (e.g., "grok-4-1-fast-non-reasoning")
77+
provider: The provider to use for API calls. Defaults to `'xai'`.
78+
profile: Optional model profile specification. Defaults to a profile picked by the provider based on the model name.
6979
settings: Optional model settings.
7080
"""
71-
super().__init__(settings=settings)
7281
self._model_name = model_name
73-
self._client = client
74-
if client is None:
75-
self._api_key = api_key or os.getenv('XAI_API_KEY') or ''
76-
if not self._api_key:
77-
raise ValueError('XAI API key is required')
78-
else:
79-
self._api_key = api_key or ''
82+
83+
if isinstance(provider, str):
84+
provider = infer_provider(provider)
85+
self._provider = provider
86+
self.client = provider.client
87+
88+
super().__init__(settings=settings, profile=profile or provider.model_profile)
8089

8190
@property
8291
def model_name(self) -> str:
@@ -188,7 +197,7 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -
188197
)
189198
else:
190199
raise UserError(
191-
f'`{builtin_tool.__class__.__name__}` is not supported by `GrokModel`. '
200+
f'`{builtin_tool.__class__.__name__}` is not supported by `XaiModel`. '
192201
f'Supported built-in tools: WebSearchTool, CodeExecutionTool, MCPServerTool. '
193202
f'If XSearchTool should be supported, please file an issue.'
194203
)
@@ -200,9 +209,8 @@ async def request(
200209
model_settings: ModelSettings | None,
201210
model_request_parameters: ModelRequestParameters,
202211
) -> ModelResponse:
203-
"""Make a request to the Grok model."""
204-
# Use injected client or create one in the current async context
205-
client = self._client or AsyncClient(api_key=self._api_key)
212+
"""Make a request to the xAI model."""
213+
client = self._provider.client
206214

207215
# Convert messages to xAI format
208216
xai_messages = self._map_messages(messages)
@@ -253,9 +261,8 @@ async def request_stream(
253261
model_request_parameters: ModelRequestParameters,
254262
run_context: RunContext[Any] | None = None,
255263
) -> AsyncIterator[StreamedResponse]:
256-
"""Make a streaming request to the Grok model."""
257-
# Use injected client or create one in the current async context
258-
client = self._client or AsyncClient(api_key=self._api_key)
264+
"""Make a streaming request to the xAI model."""
265+
client = self._provider.client
259266

260267
# Convert messages to xAI format
261268
xai_messages = self._map_messages(messages)
@@ -294,7 +301,7 @@ async def request_stream(
294301

295302
# Stream the response
296303
response_stream = chat.stream()
297-
streamed_response = GrokStreamedResponse(
304+
streamed_response = XaiStreamedResponse(
298305
model_request_parameters=model_request_parameters,
299306
_model_name=self._model_name,
300307
_response=response_stream,
@@ -409,7 +416,7 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse:
409416

410417
def _map_usage(self, response: chat_types.Response) -> RequestUsage:
411418
"""Extract usage information from xAI SDK response, including reasoning tokens and cache tokens."""
412-
return GrokModel.extract_usage(response)
419+
return XaiModel.extract_usage(response)
413420

414421
@staticmethod
415422
def extract_usage(response: chat_types.Response) -> RequestUsage:
@@ -472,7 +479,7 @@ def extract_usage(response: chat_types.Response) -> RequestUsage:
472479

473480

474481
@dataclass
475-
class GrokStreamedResponse(StreamedResponse):
482+
class XaiStreamedResponse(StreamedResponse):
476483
"""Implementation of `StreamedResponse` for xAI SDK."""
477484

478485
_model_name: str
@@ -486,7 +493,7 @@ def _update_response_state(self, response: Any) -> None:
486493

487494
# Update usage
488495
if hasattr(response, 'usage'):
489-
self._usage = GrokModel.extract_usage(response)
496+
self._usage = XaiModel.extract_usage(response)
490497

491498
# Set provider response ID
492499
if hasattr(response, 'id') and self.provider_response_id is None:

pydantic_ai_slim/pydantic_ai/profiles/grok.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77

88
@dataclass(kw_only=True)
99
class GrokModelProfile(ModelProfile):
10-
"""Profile for models used with GroqModel.
10+
"""Profile for Grok models (used with both GrokProvider and XaiProvider).
1111
12-
ALL FIELDS MUST BE `groq_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
12+
ALL FIELDS MUST BE `grok_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
1313
"""
1414

15-
builtin_tool: bool = False
15+
grok_supports_builtin_tools: bool = False
1616
"""Whether the model always has the web search built-in tool available."""
1717

1818

1919
def grok_model_profile(model_name: str) -> ModelProfile | None:
2020
"""Get the model profile for a Grok model."""
2121
return GrokModelProfile(
2222
# Support tool calling for building tools
23-
builtin_tool=model_name.startswith('grok-4'),
23+
grok_supports_builtin_tools=model_name.startswith('grok-4'),
2424
)

0 commit comments

Comments
 (0)