|
1 | 1 | from __future__ import annotations as _annotations |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import os |
4 | 5 | from datetime import timezone |
5 | 6 | from types import SimpleNamespace |
6 | 7 | from typing import Any, cast |
@@ -545,23 +546,155 @@ async def test_grok_none_delta(allow_model_requests: None): |
545 | 546 | # test_openai_o1_mini_system_role - OpenAI specific |
546 | 547 |
|
547 | 548 |
|
| 549 | +@pytest.mark.parametrize('parallel_tool_calls', [True, False]) |
| 550 | +async def test_grok_parallel_tool_calls(allow_model_requests: None, parallel_tool_calls: bool) -> None: |
| 551 | + tool_call = create_tool_call( |
| 552 | + id='123', |
| 553 | + name='final_result', |
| 554 | + arguments={'response': [1, 2, 3]}, |
| 555 | + ) |
| 556 | + response = create_response(content='', tool_calls=[tool_call], finish_reason='tool_calls') |
| 557 | + mock_client = MockGrok.create_mock(response) |
| 558 | + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) |
| 559 | + agent = Agent(m, output_type=list[int], model_settings=ModelSettings(parallel_tool_calls=parallel_tool_calls)) |
| 560 | + |
| 561 | + await agent.run('Hello') |
| 562 | + assert get_mock_chat_create_kwargs(mock_client)[0]['parallel_tool_calls'] == parallel_tool_calls |
| 563 | + |
| 564 | + |
| 565 | +async def test_grok_penalty_parameters(allow_model_requests: None) -> None: |
| 566 | + response = create_response(content='test response') |
| 567 | + mock_client = MockGrok.create_mock(response) |
| 568 | + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) |
| 569 | + |
| 570 | + settings = ModelSettings( |
| 571 | + temperature=0.7, |
| 572 | + presence_penalty=0.5, |
| 573 | + frequency_penalty=0.3, |
| 574 | + parallel_tool_calls=False, |
| 575 | + ) |
| 576 | + |
| 577 | + agent = Agent(m, model_settings=settings) |
| 578 | + result = await agent.run('Hello') |
| 579 | + |
| 580 | + # Check that all settings were passed to the xAI SDK |
| 581 | + kwargs = get_mock_chat_create_kwargs(mock_client)[0] |
| 582 | + assert kwargs['temperature'] == 0.7 |
| 583 | + assert kwargs['presence_penalty'] == 0.5 |
| 584 | + assert kwargs['frequency_penalty'] == 0.3 |
| 585 | + assert kwargs['parallel_tool_calls'] is False |
| 586 | + assert result.output == 'test response' |
| 587 | + |
| 588 | + |
| 589 | +async def test_grok_image_url_input(allow_model_requests: None): |
| 590 | + response = create_response(content='world') |
| 591 | + mock_client = MockGrok.create_mock(response) |
| 592 | + m = GrokModel('grok-4-fast-non-reasoning', client=mock_client) |
| 593 | + agent = Agent(m) |
| 594 | + |
| 595 | + result = await agent.run( |
| 596 | + [ |
| 597 | + 'hello', |
| 598 | + ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'), |
| 599 | + ] |
| 600 | + ) |
| 601 | + assert result.output == 'world' |
| 602 | + # Verify that the image URL was included in the messages |
| 603 | + assert len(get_mock_chat_create_kwargs(mock_client)) == 1 |
| 604 | + |
| 605 | + |
| 606 | +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') |
| 607 | +async def test_grok_image_url_tool_response(allow_model_requests: None, xai_api_key: str): |
| 608 | + m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) |
| 609 | + agent = Agent(m) |
| 610 | + |
| 611 | + @agent.tool_plain |
| 612 | + async def get_image() -> ImageUrl: |
| 613 | + return ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg') |
| 614 | + |
| 615 | + result = await agent.run(['What food is in the image you can get from the get_image tool?']) |
| 616 | + |
| 617 | + # Verify structure with matchers for dynamic values |
| 618 | + messages = result.all_messages() |
| 619 | + assert len(messages) == 4 |
| 620 | + |
| 621 | + # Verify message types and key content |
| 622 | + assert isinstance(messages[0], ModelRequest) |
| 623 | + assert isinstance(messages[1], ModelResponse) |
| 624 | + assert isinstance(messages[2], ModelRequest) |
| 625 | + assert isinstance(messages[3], ModelResponse) |
| 626 | + |
| 627 | + # Verify tool was called |
| 628 | + assert isinstance(messages[1].parts[0], ToolCallPart) |
| 629 | + assert messages[1].parts[0].tool_name == 'get_image' |
| 630 | + |
| 631 | + # Verify image was passed back to model |
| 632 | + assert isinstance(messages[2].parts[1], UserPromptPart) |
| 633 | + assert isinstance(messages[2].parts[1].content, list) |
| 634 | + assert any(isinstance(item, ImageUrl) for item in messages[2].parts[1].content) |
| 635 | + |
| 636 | + # Verify model responded about the image |
| 637 | + assert isinstance(messages[3].parts[0], TextPart) |
| 638 | + assert 'potato' in messages[3].parts[0].content.lower() |
| 639 | + |
| 640 | + |
| 641 | +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') |
| 642 | +async def test_grok_image_as_binary_content_tool_response( |
| 643 | + allow_model_requests: None, image_content: BinaryContent, xai_api_key: str |
| 644 | +): |
| 645 | + m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) |
| 646 | + agent = Agent(m) |
| 647 | + |
| 648 | + @agent.tool_plain |
| 649 | + async def get_image() -> BinaryContent: |
| 650 | + return image_content |
| 651 | + |
| 652 | + result = await agent.run(['What fruit is in the image you can get from the get_image tool?']) |
| 653 | + |
| 654 | + # Verify structure with matchers for dynamic values |
| 655 | + messages = result.all_messages() |
| 656 | + assert len(messages) == 4 |
| 657 | + |
| 658 | + # Verify message types and key content |
| 659 | + assert isinstance(messages[0], ModelRequest) |
| 660 | + assert isinstance(messages[1], ModelResponse) |
| 661 | + assert isinstance(messages[2], ModelRequest) |
| 662 | + assert isinstance(messages[3], ModelResponse) |
| 663 | + |
| 664 | + # Verify tool was called |
| 665 | + assert isinstance(messages[1].parts[0], ToolCallPart) |
| 666 | + assert messages[1].parts[0].tool_name == 'get_image' |
| 667 | + |
| 668 | + # Verify binary image content was passed back to model |
| 669 | + assert isinstance(messages[2].parts[1], UserPromptPart) |
| 670 | + assert isinstance(messages[2].parts[1].content, list) |
| 671 | + has_binary_image = any(isinstance(item, BinaryContent) and item.is_image for item in messages[2].parts[1].content) |
| 672 | + assert has_binary_image, 'Expected BinaryContent image in tool response' |
| 673 | + |
| 674 | + # Verify model responded about the image |
| 675 | + assert isinstance(messages[3].parts[0], TextPart) |
| 676 | + response_text = messages[3].parts[0].content.lower() |
| 677 | + assert 'kiwi' in response_text or 'fruit' in response_text |
| 678 | + |
| 679 | + |
| 680 | +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is None, reason='Requires XAI_API_KEY (gRPC, no cassettes)') |
| 681 | +async def test_grok_image_as_binary_content_input( |
| 682 | + allow_model_requests: None, image_content: BinaryContent, xai_api_key: str |
| 683 | +): |
| 684 | + """Test passing binary image content directly as input (not from a tool).""" |
| 685 | + m = GrokModel('grok-4-fast-non-reasoning', api_key=xai_api_key) |
| 686 | + agent = Agent(m) |
| 687 | + |
| 688 | + result = await agent.run(['What fruit is in the image?', image_content]) |
| 689 | + |
| 690 | + # Verify the model received and processed the image |
| 691 | + assert result.output |
| 692 | + response_text = result.output.lower() |
| 693 | + assert 'kiwi' in response_text or 'fruit' in response_text |
| 694 | + |
| 695 | + |
548 | 696 | # Skip tests that are not applicable to Grok model |
549 | 697 | # The following tests were removed as they are OpenAI-specific: |
550 | | -# - test_system_prompt_role (OpenAI-specific system prompt roles) |
551 | | -# - test_system_prompt_role_o1_mini (OpenAI o1 specific) |
552 | | -# - test_openai_pass_custom_system_prompt_role (OpenAI-specific) |
553 | | -# - test_openai_o1_mini_system_role (OpenAI-specific) |
554 | | -# - test_parallel_tool_calls (OpenAI-specific parameter) |
555 | | -# - test_image_url_input (OpenAI-specific image handling - would need VCR cassettes for Grok) |
556 | | -# - test_image_url_input_force_download (OpenAI-specific) |
557 | | -# - test_image_url_input_force_download_response_api (OpenAI-specific) |
558 | | -# - test_openai_audio_url_input (OpenAI-specific audio) |
559 | | -# - test_document_url_input (OpenAI-specific documents) |
560 | | -# - test_image_url_tool_response (OpenAI-specific) |
561 | | -# - test_image_as_binary_content_tool_response (OpenAI-specific) |
562 | | -# - test_image_as_binary_content_input (OpenAI-specific) |
563 | | -# - test_audio_as_binary_content_input (OpenAI-specific) |
564 | | -# - test_binary_content_input_unknown_media_type (OpenAI-specific) |
565 | 698 |
|
566 | 699 |
|
567 | 700 | # Continue with model request/response tests |
@@ -691,6 +824,7 @@ async def get_info(query: str) -> str: |
691 | 824 |
|
692 | 825 |
|
693 | 826 | # Test for error handling |
| 827 | +@pytest.mark.skipif(os.getenv('XAI_API_KEY') is not None, reason='Skipped when XAI_API_KEY is set') |
694 | 828 | async def test_grok_model_invalid_api_key(): |
695 | 829 | """Test Grok model with invalid API key.""" |
696 | 830 | with pytest.raises(ValueError, match='XAI API key is required'): |
|
0 commit comments