33from collections .abc import AsyncIterator , Iterator , Sequence
44from contextlib import asynccontextmanager
55from dataclasses import dataclass
6- from typing import Any , Literal
6+ from typing import Any , Literal , cast
77
88try :
99 import xai_sdk .chat as chat_types
@@ -312,12 +312,10 @@ async def request_stream(
312312
313313 def _process_response (self , response : chat_types .Response ) -> ModelResponse :
314314 """Convert xAI SDK response to pydantic_ai ModelResponse."""
315- from typing import cast
316-
317315 parts : list [ModelResponsePart ] = []
318316
319317 # Add reasoning/thinking content first if present
320- if hasattr ( response , 'reasoning_content' ) and response .reasoning_content :
318+ if response .reasoning_content :
321319 # reasoning_content is the human-readable summary
322320 parts .append (
323321 ThinkingPart (
@@ -326,7 +324,7 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse:
326324 provider_name = 'xai' ,
327325 )
328326 )
329- elif hasattr ( response , 'encrypted_content' ) and response .encrypted_content :
327+ elif response .encrypted_content :
330328 # encrypted_content is a signature that can be sent back for reasoning continuity
331329 parts .append (
332330 ThinkingPart (
@@ -343,14 +341,13 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse:
343341 # In real responses, we can use get_tool_call_type()
344342 # In mock responses, we default to client-side tools
345343 is_server_side_tool = False
346- if hasattr (tool_call , 'type' ):
347- try :
348- tool_type = get_tool_call_type (tool_call )
349- # If it's not a client-side tool, it's a server-side tool
350- is_server_side_tool = tool_type != 'client_side_tool'
351- except Exception :
352- # If we can't determine the type, treat as client-side
353- pass
344+ try :
345+ tool_type = get_tool_call_type (tool_call )
346+ # If it's not a client-side tool, it's a server-side tool
347+ is_server_side_tool = tool_type != 'client_side_tool'
348+ except Exception :
349+ # If we can't determine the type, treat as client-side
350+ pass
354351
355352 if is_server_side_tool :
356353 # Server-side tools are executed by xAI, so we add both call and return parts
@@ -410,7 +407,7 @@ def _process_response(self, response: chat_types.Response) -> ModelResponse:
410407 model_name = self ._model_name ,
411408 timestamp = now_utc (),
412409 provider_name = 'xai' ,
413- provider_response_id = response .id if hasattr ( response , 'id' ) else None ,
410+ provider_response_id = response .id ,
414411 finish_reason = finish_reason ,
415412 )
416413
@@ -427,43 +424,38 @@ def extract_usage(response: chat_types.Response) -> RequestUsage:
427424 - cache_read_tokens: Tokens read from prompt cache
428425 - server_side_tools_used: Count of server-side (built-in) tools executed
429426 """
430- if not hasattr ( response , ' usage' ) :
427+ if not response . usage :
431428 return RequestUsage ()
432429
433- usage_obj = getattr (response , 'usage' , None )
434- if not usage_obj :
435- return RequestUsage ()
430+ usage_obj = response .usage
436431
437- prompt_tokens = getattr ( usage_obj , ' prompt_tokens' , 0 )
438- completion_tokens = getattr ( usage_obj , ' completion_tokens' , 0 )
432+ prompt_tokens = usage_obj . prompt_tokens or 0
433+ completion_tokens = usage_obj . completion_tokens or 0
439434
440435 # Build details dict for additional usage metrics
441436 details : dict [str , int ] = {}
442437
443- # Add reasoning tokens if available
444- if hasattr (usage_obj , 'reasoning_tokens' ):
445- reasoning_tokens = getattr (usage_obj , 'reasoning_tokens' , 0 )
446- if reasoning_tokens :
447- details ['reasoning_tokens' ] = reasoning_tokens
448-
449- # Add cached prompt tokens if available
450- if hasattr (usage_obj , 'cached_prompt_text_tokens' ):
451- cached_tokens = getattr (usage_obj , 'cached_prompt_text_tokens' , 0 )
452- if cached_tokens :
453- details ['cache_read_tokens' ] = cached_tokens
454-
455- # Add server-side tools used count if available
456- if hasattr (usage_obj , 'server_side_tools_used' ):
457- server_side_tools = getattr (usage_obj , 'server_side_tools_used' , None )
438+ # Add reasoning tokens if available (optional attribute)
439+ reasoning_tokens = getattr (usage_obj , 'reasoning_tokens' , None )
440+ if reasoning_tokens :
441+ details ['reasoning_tokens' ] = reasoning_tokens
442+
443+ # Add cached prompt tokens if available (optional attribute)
444+ cached_tokens = getattr (usage_obj , 'cached_prompt_text_tokens' , None )
445+ if cached_tokens :
446+ details ['cache_read_tokens' ] = cached_tokens
447+
448+ # Add server-side tools used count if available (optional attribute)
449+ server_side_tools = getattr (usage_obj , 'server_side_tools_used' , None )
450+ if server_side_tools :
458451 # server_side_tools_used is a repeated field (list-like) in the real SDK
459452 # but may be an int in mocks for simplicity
460- if server_side_tools :
461- if isinstance (server_side_tools , int ):
462- tools_count = server_side_tools
463- else :
464- tools_count = len (server_side_tools )
465- if tools_count :
466- details ['server_side_tools_used' ] = tools_count
453+ if isinstance (server_side_tools , int ):
454+ tools_count = server_side_tools
455+ else :
456+ tools_count = len (server_side_tools )
457+ if tools_count :
458+ details ['server_side_tools_used' ] = tools_count
467459
468460 if details :
469461 return RequestUsage (
@@ -489,18 +481,16 @@ class XaiStreamedResponse(StreamedResponse):
489481
490482 def _update_response_state (self , response : Any ) -> None :
491483 """Update response state including usage, response ID, and finish reason."""
492- from typing import cast
493-
494484 # Update usage
495- if hasattr ( response , ' usage' ) :
485+ if response . usage :
496486 self ._usage = XaiModel .extract_usage (response )
497487
498488 # Set provider response ID
499- if hasattr ( response , 'id' ) and self .provider_response_id is None :
489+ if response . id and self .provider_response_id is None :
500490 self .provider_response_id = response .id
501491
502492 # Handle finish reason
503- if hasattr ( response , 'finish_reason' ) and response .finish_reason :
493+ if response .finish_reason :
504494 finish_reason_map = {
505495 'stop' : 'stop' ,
506496 'length' : 'length' ,
@@ -517,15 +507,15 @@ def _handle_reasoning_content(self, response: Any, reasoning_handled: bool) -> I
517507 if reasoning_handled :
518508 return
519509
520- if hasattr ( response , 'reasoning_content' ) and response .reasoning_content :
510+ if response .reasoning_content :
521511 # reasoning_content is the human-readable summary
522512 thinking_part = ThinkingPart (
523513 content = response .reasoning_content ,
524514 signature = None ,
525515 provider_name = 'xai' ,
526516 )
527517 yield self ._parts_manager .handle_part (vendor_part_id = 'reasoning' , part = thinking_part )
528- elif hasattr ( response , 'encrypted_content' ) and response .encrypted_content :
518+ elif response .encrypted_content :
529519 # encrypted_content is a signature that can be sent back for reasoning continuity
530520 thinking_part = ThinkingPart (
531521 content = '' , # No readable content for encrypted-only reasoning
@@ -536,7 +526,7 @@ def _handle_reasoning_content(self, response: Any, reasoning_handled: bool) -> I
536526
537527 def _handle_text_delta (self , chunk : Any ) -> Iterator [ModelResponseStreamEvent ]:
538528 """Handle text content delta from chunk."""
539- if hasattr ( chunk , 'content' ) and chunk .content :
529+ if chunk .content :
540530 event = self ._parts_manager .handle_text_delta (
541531 vendor_part_id = 'content' ,
542532 content = chunk .content ,
@@ -546,17 +536,16 @@ def _handle_text_delta(self, chunk: Any) -> Iterator[ModelResponseStreamEvent]:
546536
547537 def _handle_single_tool_call (self , tool_call : Any ) -> Iterator [ModelResponseStreamEvent ]:
548538 """Handle a single tool call, routing to server-side or client-side handler."""
549- if not ( hasattr ( tool_call .function , 'name' ) and tool_call . function . name ) :
539+ if not tool_call .function . name :
550540 return
551541
552542 # Determine if this is a server-side (built-in) tool
553543 is_server_side_tool = False
554- if hasattr (tool_call , 'type' ):
555- try :
556- tool_type = get_tool_call_type (tool_call )
557- is_server_side_tool = tool_type != 'client_side_tool'
558- except Exception :
559- pass # Treat as client-side if we can't determine
544+ try :
545+ tool_type = get_tool_call_type (tool_call )
546+ is_server_side_tool = tool_type != 'client_side_tool'
547+ except Exception :
548+ pass # Treat as client-side if we can't determine
560549
561550 if is_server_side_tool :
562551 # Server-side tools - create BuiltinToolCallPart and BuiltinToolReturnPart
@@ -588,7 +577,7 @@ def _handle_single_tool_call(self, tool_call: Any) -> Iterator[ModelResponseStre
588577
589578 def _handle_tool_calls (self , response : Any ) -> Iterator [ModelResponseStreamEvent ]:
590579 """Handle tool calls (both client-side and server-side)."""
591- if not hasattr ( response , ' tool_calls' ) :
580+ if not response . tool_calls :
592581 return
593582
594583 for tool_call in response .tool_calls :
0 commit comments