1212
1313 # Import xai_sdk components
1414 from xai_sdk import AsyncClient
15- from xai_sdk .chat import assistant , image , system , tool , tool_result , user
15+ from xai_sdk .chat import assistant , file , image , system , tool , tool_result , user
1616 from xai_sdk .tools import code_execution , get_tool_call_type , mcp , web_search # x_search not yet supported
1717except ImportError as _import_error :
1818 raise ImportError (
2525from ..builtin_tools import CodeExecutionTool , MCPServerTool , WebSearchTool
2626from ..exceptions import UserError
2727from ..messages import (
28+ AudioUrl ,
2829 BinaryContent ,
2930 BuiltinToolCallPart ,
3031 BuiltinToolReturnPart ,
32+ CachePoint ,
33+ DocumentUrl ,
3134 FinishReason ,
3235 ImageUrl ,
3336 ModelMessage ,
4346 ToolCallPart ,
4447 ToolReturnPart ,
4548 UserPromptPart ,
49+ VideoUrl ,
4650)
4751from ..models import (
4852 Model ,
4953 ModelRequestParameters ,
5054 StreamedResponse ,
55+ download_item ,
5156)
5257from ..profiles import ModelProfileSpec
5358from ..providers import Provider , infer_provider
@@ -100,28 +105,28 @@ def system(self) -> str:
100105 """The model provider."""
101106 return 'xai'
102107
103- def _map_messages (self , messages : list [ModelMessage ]) -> list [chat_types .chat_pb2 .Message ]:
108+ async def _map_messages (self , messages : list [ModelMessage ]) -> list [chat_types .chat_pb2 .Message ]:
104109 """Convert pydantic_ai messages to xAI SDK messages."""
105110 xai_messages : list [chat_types .chat_pb2 .Message ] = []
106111
107112 for message in messages :
108113 if isinstance (message , ModelRequest ):
109- xai_messages .extend (self ._map_request_parts (message .parts ))
114+ xai_messages .extend (await self ._map_request_parts (message .parts ))
110115 elif isinstance (message , ModelResponse ):
111116 if response_msg := self ._map_response_parts (message .parts ):
112117 xai_messages .append (response_msg )
113118
114119 return xai_messages
115120
116- def _map_request_parts (self , parts : Sequence [ModelRequestPart ]) -> list [chat_types .chat_pb2 .Message ]:
121+ async def _map_request_parts (self , parts : Sequence [ModelRequestPart ]) -> list [chat_types .chat_pb2 .Message ]:
117122 """Map ModelRequest parts to xAI messages."""
118123 xai_messages : list [chat_types .chat_pb2 .Message ] = []
119124
120125 for part in parts :
121126 if isinstance (part , SystemPromptPart ):
122127 xai_messages .append (system (part .content ))
123128 elif isinstance (part , UserPromptPart ):
124- if user_msg := self ._map_user_prompt (part ):
129+ if user_msg := await self ._map_user_prompt (part ):
125130 xai_messages .append (user_msg )
126131 elif isinstance (part , ToolReturnPart ):
127132 xai_messages .append (tool_result (part .model_response_str ()))
@@ -137,7 +142,20 @@ def _map_request_parts(self, parts: Sequence[ModelRequestPart]) -> list[chat_typ
137142
138143 return xai_messages
139144
140- def _map_user_prompt (self , part : UserPromptPart ) -> chat_types .chat_pb2 .Message | None :
145+ async def _upload_file_to_xai (self , data : bytes , filename : str ) -> str :
146+ """Upload a file to xAI files API and return the file ID.
147+
148+ Args:
149+ data: The file content as bytes
150+ filename: The filename to use for the upload
151+
152+ Returns:
153+ The file ID from xAI
154+ """
155+ uploaded_file = await self ._provider .client .files .upload (data , filename = filename )
156+ return uploaded_file .id
157+
158+ async def _map_user_prompt (self , part : UserPromptPart ) -> chat_types .chat_pb2 .Message | None : # noqa: C901
141159 """Map a UserPromptPart to an xAI user message."""
142160 if isinstance (part .content , str ):
143161 return user (part .content )
@@ -158,9 +176,33 @@ def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message
158176 if item .is_image :
159177 # Convert binary content to data URI and use image()
160178 content_items .append (image (item .data_uri , detail = 'auto' ))
161- else :
162- # xAI SDK doesn't support non-image binary content yet
163- pass
179+ elif item .is_audio :
180+ raise NotImplementedError ('AudioUrl/BinaryContent with audio is not supported by xAI SDK' )
181+ elif item .is_document :
182+ # Upload document to xAI files API and reference it
183+ filename = item .identifier or f'document.{ item .format } '
184+ file_id = await self ._upload_file_to_xai (item .data , filename )
185+ content_items .append (file (file_id ))
186+ else : # pragma: no cover
187+ raise RuntimeError (f'Unsupported binary content type: { item .media_type } ' )
188+ elif isinstance (item , AudioUrl ):
189+ raise NotImplementedError ('AudioUrl is not supported by xAI SDK' )
190+ elif isinstance (item , DocumentUrl ):
191+ # Download and upload to xAI files API
192+ downloaded = await download_item (item , data_format = 'bytes' )
193+ filename = item .identifier or 'document'
194+ if 'data_type' in downloaded and downloaded ['data_type' ]:
195+ filename = f'{ filename } .{ downloaded ["data_type" ]} '
196+
197+ file_id = await self ._upload_file_to_xai (downloaded ['data' ], filename )
198+ content_items .append (file (file_id ))
199+ elif isinstance (item , VideoUrl ):
200+ raise NotImplementedError ('VideoUrl is not supported by xAI SDK' )
201+ elif isinstance (item , CachePoint ):
202+ # xAI doesn't support prompt caching via CachePoint, so we filter it out
203+ pass
204+ else :
205+ assert_never (item )
164206
165207 if content_items :
166208 return user (* content_items )
@@ -225,7 +267,7 @@ async def request(
225267 client = self ._provider .client
226268
227269 # Convert messages to xAI format
228- xai_messages = self ._map_messages (messages )
270+ xai_messages = await self ._map_messages (messages )
229271
230272 # Convert tools: combine built-in (server-side) tools and custom (client-side) tools
231273 tools : list [chat_types .chat_pb2 .Tool ] = []
@@ -277,7 +319,7 @@ async def request_stream(
277319 client = self ._provider .client
278320
279321 # Convert messages to xAI format
280- xai_messages = self ._map_messages (messages )
322+ xai_messages = await self ._map_messages (messages )
281323
282324 # Convert tools: combine built-in (server-side) tools and custom (client-side) tools
283325 tools : list [chat_types .chat_pb2 .Tool ] = []
0 commit comments