Skip to content

Commit 185a6be

Browse files
committed
Adding image tests that require an active API_KEY
1 parent 4f15966 commit 185a6be

File tree

4 files changed

+196
-20
lines changed

4 files changed

+196
-20
lines changed

pydantic_ai_slim/pydantic_ai/models/grok.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010

1111
# Import xai_sdk components
1212
from xai_sdk import AsyncClient
13-
from xai_sdk.chat import assistant, system, tool, tool_result, user
13+
from xai_sdk.chat import assistant, image, system, tool, tool_result, user
1414

1515
from .._run_context import RunContext
1616
from .._utils import now_utc
1717
from ..messages import (
18+
BinaryContent,
1819
FinishReason,
20+
ImageUrl,
1921
ModelMessage,
2022
ModelRequest,
2123
ModelRequestPart,
@@ -113,10 +115,28 @@ def _map_user_prompt(self, part: UserPromptPart) -> chat_types.chat_pb2.Message
113115
if isinstance(part.content, str):
114116
return user(part.content)
115117

116-
# Handle complex content (images, etc.)
117-
text_parts: list[str] = [item for item in part.content if isinstance(item, str)]
118-
if text_parts:
119-
return user(' '.join(text_parts))
118+
# Handle complex content (images, text, etc.)
119+
content_items: list[chat_types.Content] = []
120+
121+
for item in part.content:
122+
if isinstance(item, str):
123+
content_items.append(item)
124+
elif isinstance(item, ImageUrl):
125+
# Get detail from vendor_metadata if available
126+
detail: chat_types.ImageDetail = 'auto'
127+
if item.vendor_metadata and 'detail' in item.vendor_metadata:
128+
detail = item.vendor_metadata['detail']
129+
content_items.append(image(item.url, detail=detail))
130+
elif isinstance(item, BinaryContent):
131+
if item.is_image:
132+
# Convert binary content to data URI and use image()
133+
content_items.append(image(item.data_uri, detail='auto'))
134+
else:
135+
# xAI SDK doesn't support non-image binary content yet
136+
pass
137+
138+
if content_items:
139+
return user(*content_items)
120140

121141
return None
122142

@@ -171,6 +191,12 @@ async def request(
171191
xai_settings['stop'] = model_settings['stop_sequences']
172192
if 'seed' in model_settings:
173193
xai_settings['seed'] = model_settings['seed']
194+
if 'parallel_tool_calls' in model_settings:
195+
xai_settings['parallel_tool_calls'] = model_settings['parallel_tool_calls']
196+
if 'presence_penalty' in model_settings:
197+
xai_settings['presence_penalty'] = model_settings['presence_penalty']
198+
if 'frequency_penalty' in model_settings:
199+
xai_settings['frequency_penalty'] = model_settings['frequency_penalty']
174200

175201
# Create chat instance
176202
chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings)
@@ -213,6 +239,12 @@ async def request_stream(
213239
xai_settings['stop'] = model_settings['stop_sequences']
214240
if 'seed' in model_settings:
215241
xai_settings['seed'] = model_settings['seed']
242+
if 'parallel_tool_calls' in model_settings:
243+
xai_settings['parallel_tool_calls'] = model_settings['parallel_tool_calls']
244+
if 'presence_penalty' in model_settings:
245+
xai_settings['presence_penalty'] = model_settings['presence_penalty']
246+
if 'frequency_penalty' in model_settings:
247+
xai_settings['frequency_penalty'] = model_settings['frequency_penalty']
216248

217249
# Create chat instance
218250
chat = client.chat.create(model=self._model_name, messages=xai_messages, tools=tools, **xai_settings)

pydantic_ai_slim/pydantic_ai/settings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class ModelSettings(TypedDict, total=False):
8686
* OpenAI (some models, not o1)
8787
* Groq
8888
* Anthropic
89+
* Grok
8990
"""
9091

9192
seed: int
@@ -112,6 +113,7 @@ class ModelSettings(TypedDict, total=False):
112113
* Gemini
113114
* Mistral
114115
* Outlines (LlamaCpp, SgLang, VLLMOffline)
116+
* Grok
115117
"""
116118

117119
frequency_penalty: float
@@ -125,6 +127,7 @@ class ModelSettings(TypedDict, total=False):
125127
* Gemini
126128
* Mistral
127129
* Outlines (LlamaCpp, SgLang, VLLMOffline)
130+
* Grok
128131
"""
129132

130133
logit_bias: dict[str, int]
@@ -149,6 +152,7 @@ class ModelSettings(TypedDict, total=False):
149152
* Groq
150153
* Cohere
151154
* Google
155+
* Grok
152156
"""
153157

154158
extra_headers: dict[str, str]
@@ -159,6 +163,7 @@ class ModelSettings(TypedDict, total=False):
159163
* OpenAI
160164
* Anthropic
161165
* Groq
166+
* Grok
162167
"""
163168

164169
extra_body: object

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@ def cerebras_api_key() -> str:
412412
return os.getenv('CEREBRAS_API_KEY', 'mock-api-key')
413413

414414

415+
@pytest.fixture(scope='session')
416+
def xai_api_key() -> str:
417+
return os.getenv('XAI_API_KEY', 'mock-api-key')
418+
419+
415420
@pytest.fixture(scope='session')
416421
def bedrock_provider():
417422
try:

tests/models/test_grok.py

Lines changed: 149 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
import json
4+
import os
45
from datetime import timezone
56
from types import SimpleNamespace
67
from typing import Any, cast
@@ -545,23 +546,155 @@ async def test_grok_none_delta(allow_model_requests: None):
545546
# test_openai_o1_mini_system_role - OpenAI specific
546547

547548

549+
@pytest.mark.parametrize('parallel_tool_calls', [True, False])
550+
async def test_grok_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None:
551+
tool_call = create_tool_call(
552+
id='123',
553+
name='final_result',
554+
arguments={'response': [1, 2, 3]},
555+
)
556+
response = create_response(content='', tool_calls=[tool_call], finish_reason='tool_calls')
557+
mock_client = MockGrok.create_mock(response)
558+
m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
559+
agent = Agent(m, output_type=list[int], model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls))
560+
561+
await agent.run('Hello')
562+
assert get_mock_chat_create_kwargs(mock_client)[0]['parallel_tool_calls'] == parallel_tool_calls
563+
564+
565+
async def test_grok_penalty_parameters(allow_model_requests: None) -> None:
566+
response = create_response(content='test response')
567+
mock_client = MockGrok.create_mock(response)
568+
m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
569+
570+
settings = ModelSettings(
571+
temperature=0.7,
572+
presence_penalty=0.5,
573+
frequency_penalty=0.3,
574+
parallel_tool_calls=False,
575+
)
576+
577+
agent = Agent(m, model_settings=settings)
578+
result = await agent.run('Hello')
579+
580+
# Check that all settings were passed to the xAI SDK
581+
kwargs = get_mock_chat_create_kwargs(mock_client)[0]
582+
assert kwargs['temperature'] == 0.7
583+
assert kwargs['presence_penalty'] == 0.5
584+
assert kwargs['frequency_penalty'] == 0.3
585+
assert kwargs['parallel_tool_calls'] is False
586+
assert result.output == 'test response'
587+
588+
589+
async def test_grok_image_url_input(allow_model_requests: None):
590+
response = create_response(content='world')
591+
mock_client = MockGrok.create_mock(response)
592+
m = GrokModel('grok-4-fast-non-reasoning', client=mock_client)
593+
agent = Agent(m)
594+
595+
result = await agent.run(
596+
[
597+
'hello',
598+
ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'),
599+
]
600+
)
601+
assert result.output == 'world'
602+
# Verify that the image URL was included in the messages
603+
assert len(get_mock_chat_create_kwargs(mock_client)) == 1
604+
605+
606+
@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)')
607+
async def test_grok_image_url_tool_response(allow_model_requests: None, xai_api_key: str):
608+
m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key)
609+
agent = Agent(m)
610+
611+
@agent.tool_plain
612+
async def get_image() -> ImageUrl:
613+
return ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg')
614+
615+
result = await agent.run(['What food is in the image you can get from the get_image tool?'])
616+
617+
# Verify structure with matchers for dynamic values
618+
messages = result.all_messages()
619+
assert len(messages) == 4
620+
621+
# Verify message types and key content
622+
assert isinstance(messages[0], ModelRequest)
623+
assert isinstance(messages[1], ModelResponse)
624+
assert isinstance(messages[2], ModelRequest)
625+
assert isinstance(messages[3], ModelResponse)
626+
627+
# Verify tool was called
628+
assert isinstance(messages[1].parts[0], ToolCallPart)
629+
assert messages[1].parts[0].tool_name == 'get_image'
630+
631+
# Verify image was passed back to model
632+
assert isinstance(messages[2].parts[1], UserPromptPart)
633+
assert isinstance(messages[2].parts[1].content, list)
634+
assert any(isinstance(item, ImageUrl) for item in messages[2].parts[1].content)
635+
636+
# Verify model responded about the image
637+
assert isinstance(messages[3].parts[0], TextPart)
638+
assert 'potato' in messages[3].parts[0].content.lower()
639+
640+
641+
@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)')
642+
async def test_grok_image_as_binary_content_tool_response(
643+
allow_model_requests: None, image_content: BinaryContent, xai_api_key: str
644+
):
645+
m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key)
646+
agent = Agent(m)
647+
648+
@agent.tool_plain
649+
async def get_image() -> BinaryContent:
650+
return image_content
651+
652+
result = await agent.run(['What fruit is in the image you can get from the get_image tool?'])
653+
654+
# Verify structure with matchers for dynamic values
655+
messages = result.all_messages()
656+
assert len(messages) == 4
657+
658+
# Verify message types and key content
659+
assert isinstance(messages[0], ModelRequest)
660+
assert isinstance(messages[1], ModelResponse)
661+
assert isinstance(messages[2], ModelRequest)
662+
assert isinstance(messages[3], ModelResponse)
663+
664+
# Verify tool was called
665+
assert isinstance(messages[1].parts[0], ToolCallPart)
666+
assert messages[1].parts[0].tool_name == 'get_image'
667+
668+
# Verify binary image content was passed back to model
669+
assert isinstance(messages[2].parts[1], UserPromptPart)
670+
assert isinstance(messages[2].parts[1].content, list)
671+
has_binary_image = any(isinstance(item, BinaryContent) and item.is_image for item in messages[2].parts[1].content)
672+
assert has_binary_image, 'Expected BinaryContent image in tool response'
673+
674+
# Verify model responded about the image
675+
assert isinstance(messages[3].parts[0], TextPart)
676+
response_text = messages[3].parts[0].content.lower()
677+
assert 'kiwi' in response_text or 'fruit' in response_text
678+
679+
680+
@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)')
681+
async def test_grok_image_as_binary_content_input(
682+
allow_model_requests: None, image_content: BinaryContent, xai_api_key: str
683+
):
684+
"""Test passing binary image content directly as input (not from a tool)."""
685+
m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key)
686+
agent = Agent(m)
687+
688+
result = await agent.run(['What fruit is in the image?', image_content])
689+
690+
# Verify the model received and processed the image
691+
assert result.output
692+
response_text = result.output.lower()
693+
assert 'kiwi' in response_text or 'fruit' in response_text
694+
695+
548696
# Skip tests that are not applicable to Grok model
549697
# The following tests were removed as they are OpenAI-specific:
550-
# - test_system_prompt_role (OpenAI-specific system prompt roles)
551-
# - test_system_prompt_role_o1_mini (OpenAI o1 specific)
552-
# - test_openai_pass_custom_system_prompt_role (OpenAI-specific)
553-
# - test_openai_o1_mini_system_role (OpenAI-specific)
554-
# - test_parallel_tool_calls (OpenAI-specific parameter)
555-
# - test_image_url_input (OpenAI-specific image handling - would need VCR cassettes for Grok)
556-
# - test_image_url_input_force_download (OpenAI-specific)
557-
# - test_image_url_input_force_download_response_api (OpenAI-specific)
558-
# - test_openai_audio_url_input (OpenAI-specific audio)
559-
# - test_document_url_input (OpenAI-specific documents)
560-
# - test_image_url_tool_response (OpenAI-specific)
561-
# - test_image_as_binary_content_tool_response (OpenAI-specific)
562-
# - test_image_as_binary_content_input (OpenAI-specific)
563-
# - test_audio_as_binary_content_input (OpenAI-specific)
564-
# - test_binary_content_input_unknown_media_type (OpenAI-specific)
565698

566699

567700
# Continue with model request/response tests
@@ -691,6 +824,7 @@ async def get_info(query: str) -> str:
691824

692825

693826
# Test for error handling
827+
@pytest.mark.skipif(os.getenv('XAI_API_KEY') is not None, reason='Skipped when XAI_API_KEY is set')
694828
async def test_grok_model_invalid_api_key():
695829
"""Test Grok model with invalid API key."""
696830
with pytest.raises(ValueError, match='XAI API key is required'):

0 commit comments

Comments
 (0)