Skip to content

Commit 5224e56

Browse files
committed
Add support for DocumentUrl, and downloading an dummy pdf with casette
1 parent fe93dad commit 5224e56

File tree

6 files changed

+515
-21
lines changed

6 files changed

+515
-21
lines changed

examples/pydantic_ai_examples/stock_analysis_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# The model will automatically use XaiProvider with the API key from the environment
2323

2424
# Create the model using XaiModel with server-side tools
25-
model = XaiModel('grok-4-fast')
25+
model = XaiModel('grok-4-1-fast-non-reasoning')
2626

2727

2828
class StockAnalysis(BaseModel):

pydantic_ai_slim/pydantic_ai/models/xai.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
# Import xai_sdk components
1414
from xai_sdk import AsyncClient
15-
from xai_sdk.chat import assistant, image, system, tool, tool_result, user
15+
from xai_sdk.chat import assistant, file, image, system, tool, tool_result, user
1616
from xai_sdk.tools import code_execution, get_tool_call_type, mcp, web_search # x_search not yet supported
1717
except ImportError as _import_error:
1818
raise ImportError(
@@ -25,9 +25,12 @@
2525
from ..builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool
2626
from ..exceptions import UserError
2727
from ..messages import (
28+
AudioUrl,
2829
BinaryContent,
2930
BuiltinToolCallPart,
3031
BuiltinToolReturnPart,
32+
CachePoint,
33+
DocumentUrl,
3134
FinishReason,
3235
ImageUrl,
3336
ModelMessage,
@@ -43,11 +46,13 @@
4346
ToolCallPart,
4447
ToolReturnPart,
4548
UserPromptPart,
49+
VideoUrl,
4650
)
4751
from ..models import (
4852
Model,
4953
ModelRequestParameters,
5054
StreamedResponse,
55+
download_item,
5156
)
5257
from ..profiles import ModelProfileSpec
5358
from ..providers import Provider, infer_provider
@@ -100,28 +105,28 @@ def system(self) -> str:
100105
"""The model provider."""
101106
return 'xai'
102107

103-
def _map_messages(self, messages: list[ModelMessage]) -> list[chat_types.chat_pb2.Message]:
108+
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat_types.chat_pb2.Message]:
104109
"""Convert pydantic_ai messages to xAI SDK messages."""
105110
xai_messages: list[chat_types.chat_pb2.Message] = []
106111

107112
for message in messages:
108113
if isinstance(message, ModelRequest):
109-
xai_messages.extend(self._map_request_parts(message.parts))
114+
xai_messages.extend(await self._map_request_parts(message.parts))
110115
elif isinstance(message, ModelResponse):
111116
if response_msg := self._map_response_parts(message.parts):
112117
xai_messages.append(response_msg)
113118

114119
return xai_messages
115120

116-
def _map_request_parts(self, parts: Sequence[ModelRequestPart]) -> list[chat_types.chat_pb2.Message]:
121+
async def _map_request_parts(self, parts: Sequence[ModelRequestPart]) -> list[chat_types.chat_pb2.Message]:
117122
"""Map ModelRequest parts to xAI messages."""
118123
xai_messages: list[chat_types.chat_pb2.Message] = []
119124

120125
for part in parts:
121126
if isinstance(part, SystemPromptPart):
122127
xai_messages.append(system(part.content))
123128
elif isinstance(part, UserPromptPart):
124-
if user_msg := self._map_user_prompt(part):
129+
if user_msg := await self._map_user_prompt(part):
125130
xai_messages.append(user_msg)
126131
elif isinstance(part, ToolReturnPart):
127132
xai_messages.append(tool_result(part.model_response_str()))
@@ -137,7 +142,20 @@ def _map_request_parts(self, parts: Sequence[ModelRequestPart]) -> list[chat_typ
137142

138143
return xai_messages
139144

140-
def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message | None:
145+
async def _upload_file_to_xai(self, data: bytes, filename: str) -> str:
146+
"""Upload a file to xAI files API and return the file ID.
147+
148+
Args:
149+
data: The file content as bytes
150+
filename: The filename to use for the upload
151+
152+
Returns:
153+
The file ID from xAI
154+
"""
155+
uploaded_file = await self._provider.client.files.upload(data, filename=filename)
156+
return uploaded_file.id
157+
158+
async def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message | None: # noqa: C901
141159
"""Map a UserPromptPart to an xAI user message."""
142160
if isinstance(part.content, str):
143161
return user(part.content)
@@ -158,9 +176,33 @@ def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message
158176
if item.is_image:
159177
# Convert binary content to data URI and use image()
160178
content_items.append(image(item.data_uri, detail='auto'))
161-
else:
162-
# xAI SDK doesn't support non-image binary content yet
163-
pass
179+
elif item.is_audio:
180+
raise NotImplementedError('AudioUrl/BinaryContent with audio is not supported by xAI SDK')
181+
elif item.is_document:
182+
# Upload document to xAI files API and reference it
183+
filename = item.identifier or f'document.{item.format}'
184+
file_id = await self._upload_file_to_xai(item.data, filename)
185+
content_items.append(file(file_id))
186+
else: # pragma: no cover
187+
raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
188+
elif isinstance(item, AudioUrl):
189+
raise NotImplementedError('AudioUrl is not supported by xAI SDK')
190+
elif isinstance(item, DocumentUrl):
191+
# Download and upload to xAI files API
192+
downloaded = await download_item(item, data_format='bytes')
193+
filename = item.identifier or 'document'
194+
if 'data_type' in downloaded and downloaded['data_type']:
195+
filename = f'{filename}.{downloaded["data_type"]}'
196+
197+
file_id = await self._upload_file_to_xai(downloaded['data'], filename)
198+
content_items.append(file(file_id))
199+
elif isinstance(item, VideoUrl):
200+
raise NotImplementedError('VideoUrl is not supported by xAI SDK')
201+
elif isinstance(item, CachePoint):
202+
# xAI doesn't support prompt caching via CachePoint, so we filter it out
203+
pass
204+
else:
205+
assert_never(item)
164206

165207
if content_items:
166208
return user(*content_items)
@@ -225,7 +267,7 @@ async def request(
225267
client = self._provider.client
226268

227269
# Convert messages to xAI format
228-
xai_messages = self._map_messages(messages)
270+
xai_messages = await self._map_messages(messages)
229271

230272
# Convert tools: combine built-in (server-side) tools and custom (client-side) tools
231273
tools: list[chat_types.chat_pb2.Tool] = []
@@ -277,7 +319,7 @@ async def request_stream(
277319
client = self._provider.client
278320

279321
# Convert messages to xAI format
280-
xai_messages = self._map_messages(messages)
322+
xai_messages = await self._map_messages(messages)
281323

282324
# Convert tools: combine built-in (server-side) tools and custom (client-side) tools
283325
tools: list[chat_types.chat_pb2.Tool] = []

0 commit comments

Comments
 (0)