
Commit e60d3ea

Updated grok to support passing in AsyncClient, added initial tests
1 parent 4513c8f commit e60d3ea

3 files changed: +939 −15 lines changed

pydantic_ai_slim/pydantic_ai/models/grok.py

Lines changed: 25 additions & 15 deletions
@@ -42,26 +42,33 @@ class GrokModel(Model):

     _model_name: str
     _api_key: str
+    _client: AsyncClient | None

     def __init__(
         self,
         model_name: str,
         *,
         api_key: str | None = None,
+        client: AsyncClient | None = None,
         settings: ModelSettings | None = None,
     ):
         """Initialize the Grok model.

         Args:
             model_name: The name of the Grok model to use (e.g., "grok-3", "grok-4-fast-non-reasoning")
             api_key: The xAI API key. If not provided, uses XAI_API_KEY environment variable.
+            client: Optional AsyncClient instance for testing. If provided, api_key is ignored.
             settings: Optional model settings.
         """
         super().__init__(settings=settings)
         self._model_name = model_name
-        self._api_key = api_key or os.getenv('XAI_API_KEY') or ''
-        if not self._api_key:
-            raise ValueError('XAI API key is required')
+        self._client = client
+        if client is None:
+            self._api_key = api_key or os.getenv('XAI_API_KEY') or ''
+            if not self._api_key:
+                raise ValueError('XAI API key is required')
+        else:
+            self._api_key = api_key or ''

     @property
     def model_name(self) -> str:
@@ -141,8 +148,8 @@ async def request(
         model_request_parameters: ModelRequestParameters,
     ) -> ModelResponse:
         """Make a request to the Grok model."""
-        # Create client in the current async context to avoid event loop issues
-        client = AsyncClient(api_key=self._api_key)
+        # Use injected client or create one in the current async context
+        client = self._client or AsyncClient(api_key=self._api_key)

         # Convert messages to xAI format
         xai_messages = self._map_messages(messages)
@@ -183,8 +190,8 @@ async def request_stream(
         run_context: RunContext[Any] | None = None,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the Grok model."""
-        # Create client in the current async context to avoid event loop issues
-        client = AsyncClient(api_key=self._api_key)
+        # Use injected client or create one in the current async context
+        client = self._client or AsyncClient(api_key=self._api_key)

         # Convert messages to xAI format
         xai_messages = self._map_messages(messages)
@@ -318,14 +325,17 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                    yield event

                # Handle tool calls
-               if hasattr(chunk, 'tool_calls'):
-                   for tool_call in chunk.tool_calls:
-                       yield self._parts_manager.handle_tool_call_part(
-                           vendor_part_id=tool_call.id,
-                           tool_name=tool_call.function.name,
-                           args=tool_call.function.arguments,
-                           tool_call_id=tool_call.id,
-                       )
+               # Note: We use the accumulated Response tool calls, not the Chunk deltas,
+               # because pydantic validation needs complete JSON, not partial deltas
+               if hasattr(response, 'tool_calls'):
+                   for tool_call in response.tool_calls:
+                       if hasattr(tool_call.function, 'name') and tool_call.function.name:
+                           yield self._parts_manager.handle_tool_call_part(
+                               vendor_part_id=tool_call.id,
+                               tool_name=tool_call.function.name,
+                               args=tool_call.function.arguments,
+                               tool_call_id=tool_call.id,
+                           )

     @property
     def model_name(self) -> str:
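With this change a test can build a GrokModel around a fake AsyncClient and never touch XAI_API_KEY. A minimal sketch of that usage, assuming the usual pydantic_ai Agent API and the MockGrok/create_response helpers added in tests/models/mock_grok.py below; the test name, prompt, and final assertion are illustrative rather than taken from this commit:

import pytest

from pydantic_ai import Agent
from pydantic_ai.models.grok import GrokModel

from .mock_grok import MockGrok, create_response


@pytest.mark.anyio
async def test_request_with_injected_client():
    # Build a fake AsyncClient that returns one canned Response; because a client
    # is injected, GrokModel skips the XAI_API_KEY lookup and the ValueError path.
    mock_client = MockGrok.create_mock(create_response(content='hello world'))
    model = GrokModel('grok-3', client=mock_client)

    agent = Agent(model)
    result = await agent.run('Say hello')
    # Assumes GrokModel maps the mock Response.content into the run output.
    assert result.output == 'hello world'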

tests/models/mock_grok.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@

from __future__ import annotations as _annotations

from collections.abc import Sequence
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, cast

from ..conftest import raise_if_exception, try_import
from .mock_async_stream import MockAsyncStream

with try_import() as imports_successful:
    import xai_sdk.chat as chat_types
    from xai_sdk import AsyncClient

MockResponse = chat_types.Response | Exception
# xai_sdk streaming returns tuples of (Response, chunk) where chunk type is not explicitly defined
MockResponseChunk = tuple[chat_types.Response, Any] | Exception


@dataclass
class MockGrok:
    """Mock for xAI SDK AsyncClient to simulate Grok API responses."""

    responses: MockResponse | Sequence[MockResponse] | None = None
    stream_data: Sequence[MockResponseChunk] | Sequence[Sequence[MockResponseChunk]] | None = None
    index: int = 0
    chat_create_kwargs: list[dict[str, Any]] = field(default_factory=list)
    api_key: str = 'test-api-key'

    @cached_property
    def chat(self) -> Any:
        """Create mock chat interface."""
        return type('Chat', (), {'create': self.chat_create})

    @classmethod
    def create_mock(
        cls, responses: MockResponse | Sequence[MockResponse], api_key: str = 'test-api-key'
    ) -> AsyncClient:
        """Create a mock AsyncClient for non-streaming responses."""
        return cast(AsyncClient, cls(responses=responses, api_key=api_key))

    @classmethod
    def create_mock_stream(
        cls,
        stream: Sequence[MockResponseChunk] | Sequence[Sequence[MockResponseChunk]],
        api_key: str = 'test-api-key',
    ) -> AsyncClient:
        """Create a mock AsyncClient for streaming responses."""
        return cast(AsyncClient, cls(stream_data=stream, api_key=api_key))

    def chat_create(self, *_args: Any, **kwargs: Any) -> MockChatInstance:
        """Mock the chat.create method."""
        self.chat_create_kwargs.append(kwargs)
        return MockChatInstance(
            responses=self.responses,
            stream_data=self.stream_data,
            index=self.index,
            parent=self,
        )


@dataclass
class MockChatInstance:
    """Mock for the chat instance returned by client.chat.create()."""

    responses: MockResponse | Sequence[MockResponse] | None = None
    stream_data: Sequence[MockResponseChunk] | Sequence[Sequence[MockResponseChunk]] | None = None
    index: int = 0
    parent: MockGrok | None = None

    async def sample(self) -> chat_types.Response:
        """Mock the sample() method for non-streaming responses."""
        assert self.responses is not None, 'you can only use sample() if responses are provided'

        if isinstance(self.responses, Sequence):
            raise_if_exception(self.responses[self.index])
            response = cast(chat_types.Response, self.responses[self.index])
        else:
            raise_if_exception(self.responses)
            response = cast(chat_types.Response, self.responses)

        if self.parent:
            self.parent.index += 1

        return response

    def stream(self) -> MockAsyncStream[MockResponseChunk]:
        """Mock the stream() method for streaming responses."""
        assert self.stream_data is not None, 'you can only use stream() if stream_data is provided'

        # Check if we have nested sequences (multiple streams) vs single stream
        # We need to check if it's a list of tuples (single stream) vs list of lists (multiple streams)
        if isinstance(self.stream_data, list) and len(self.stream_data) > 0:
            first_item = self.stream_data[0]
            # If first item is a list (not a tuple), we have multiple streams
            if isinstance(first_item, list):
                data = cast(list[MockResponseChunk], self.stream_data[self.index])
            else:
                # Single stream - use the data as is
                data = cast(list[MockResponseChunk], self.stream_data)
        else:
            data = cast(list[MockResponseChunk], self.stream_data)

        if self.parent:
            self.parent.index += 1

        return MockAsyncStream(iter(data))


def get_mock_chat_create_kwargs(async_client: AsyncClient) -> list[dict[str, Any]]:
    """Extract the kwargs passed to chat.create from a mock client."""
    if isinstance(async_client, MockGrok):
        return async_client.chat_create_kwargs
    else:  # pragma: no cover
        raise RuntimeError('Not a MockGrok instance')


@dataclass
class MockGrokResponse:
    """Mock Response object that mimics xai_sdk.chat.Response interface."""

    id: str = 'grok-123'
    content: str = ''
    tool_calls: list[Any] = field(default_factory=list)
    finish_reason: str = 'stop'
    usage: Any | None = None  # Would be usage_pb2.SamplingUsage in real xai_sdk


@dataclass
class MockGrokToolCall:
    """Mock ToolCall object that mimics chat_pb2.ToolCall interface."""

    id: str
    function: Any  # Would be chat_pb2.Function with name and arguments


@dataclass
class MockGrokFunction:
    """Mock Function object for tool calls."""

    name: str
    arguments: dict[str, Any]


def create_response(
    content: str = '',
    tool_calls: list[Any] | None = None,
    finish_reason: str = 'stop',
    usage: Any | None = None,
) -> MockGrokResponse:
    """Create a mock Response object for testing.

    Returns a MockGrokResponse that mimics the xai_sdk.chat.Response interface.
    """
    return MockGrokResponse(
        id='grok-123',
        content=content,
        tool_calls=tool_calls or [],
        finish_reason=finish_reason,
        usage=usage,
    )


def create_tool_call(
    id: str,
    name: str,
    arguments: dict[str, Any],
) -> MockGrokToolCall:
    """Create a mock ToolCall object for testing.

    Returns a MockGrokToolCall that mimics the chat_pb2.ToolCall interface.
    """
    return MockGrokToolCall(
        id=id,
        function=MockGrokFunction(name=name, arguments=arguments),
    )


@dataclass
class MockGrokResponseChunk:
    """Mock response chunk for streaming."""

    content: str = ''
    tool_calls: list[Any] = field(default_factory=list)


def create_response_chunk(
    content: str = '',
    tool_calls: list[Any] | None = None,
) -> MockGrokResponseChunk:
    """Create a mock response chunk object for testing.

    Returns a MockGrokResponseChunk for streaming responses.
    """
    return MockGrokResponseChunk(
        content=content,
        tool_calls=tool_calls or [],
    )
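For context, a short sketch of how the streaming helpers above might be wired together in a test; the two-step stream and the texts are illustrative, and only names defined in this file are used:

from .mock_grok import (
    MockGrok,
    create_response,
    create_response_chunk,
    get_mock_chat_create_kwargs,
)

# Each stream item is a (Response, chunk) tuple, matching the shape noted at the
# top of this file; the Response side carries the content accumulated so far.
stream = [
    (create_response(content='hello'), create_response_chunk(content='hello')),
    (create_response(content='hello world'), create_response_chunk(content=' world')),
]
mock_client = MockGrok.create_mock_stream(stream)

# After the code under test has called mock_client.chat.create(...).stream(),
# get_mock_chat_create_kwargs(mock_client) returns the list of kwargs dicts that
# MockGrok recorded, so a test can assert on what GrokModel passed through.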
