|
2 | 2 |
|
3 | 3 | import datetime |
4 | 4 | import os |
5 | | -from typing import Any, Union |
| 5 | +from typing import Any |
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 | from httpx import Timeout |
|
50 | 50 | with try_import() as imports_successful: |
51 | 51 | from google.genai.types import CodeExecutionResult, HarmBlockThreshold, HarmCategory, Language, Outcome |
52 | 52 |
|
53 | | - from pydantic_ai.models.google import GoogleModel, GoogleModelSettings |
| 53 | + from pydantic_ai.models.google import ( |
| 54 | + GoogleModel, |
| 55 | + GoogleModelSettings, |
| 56 | + _content_model_response, # pyright: ignore[reportPrivateUsage] |
| 57 | + ) |
54 | 58 | from pydantic_ai.providers.google import GoogleProvider |
55 | 59 |
|
56 | 60 | pytestmark = [ |
@@ -1815,117 +1819,110 @@ async def test_google_model_response_part_handling( |
1815 | 1819 | assert contents == expected_contents |
1816 | 1820 |
|
1817 | 1821 |
|
1818 | | -class FunctionCallDict(TypedDict): |
1819 | | - name: str |
1820 | | - args: dict[str, Any] |
1821 | | - id: str |
1822 | | - |
1823 | | - |
1824 | | -class FunctionCallPartDict(TypedDict): |
1825 | | - function_call: FunctionCallDict |
1826 | | - |
1827 | | - |
1828 | | -class TextPartDict(TypedDict): |
1829 | | - text: str |
1830 | | - |
1831 | | - |
1832 | | -class OtherPartDict(TypedDict, total=False): |
1833 | | - other_field: str |
1834 | | - |
1835 | | - |
1836 | | -# Union of all possible part types we're testing |
1837 | | -TestPartDict = Union[FunctionCallPartDict, TextPartDict, OtherPartDict, str] # str for non-dict parts |
1838 | | - |
1839 | | - |
1840 | | -class MockContentResponse(TypedDict, total=False): |
1841 | | - role: str |
1842 | | - parts: list[TestPartDict] |
1843 | | - |
1844 | | - |
1845 | | -class ExpectedContent(TypedDict, total=False): |
1846 | | - role: str |
1847 | | - parts: list[TestPartDict] |
1848 | | - |
1849 | | - |
1850 | 1822 | @pytest.mark.parametrize( |
1851 | | - 'mock_content_response,expected_contents', |
| 1823 | + 'model_parts,expected_parts', |
1852 | 1824 | [ |
1853 | 1825 | pytest.param( |
1854 | | - MockContentResponse( |
1855 | | - { |
1856 | | - 'role': 'model', |
1857 | | - 'parts': [ |
1858 | | - 'not_a_dict', # Non-dict part to test isinstance check |
1859 | | - {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
1860 | | - ], |
1861 | | - } |
1862 | | - ), |
| 1826 | + [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], |
1863 | 1827 | [ |
1864 | | - ExpectedContent( |
1865 | | - { |
1866 | | - 'role': 'model', |
1867 | | - 'parts': [ |
1868 | | - 'not_a_dict', |
1869 | | - {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
1870 | | - {'text': 'I have completed the function calls above.'}, |
1871 | | - ], |
| 1828 | + { |
| 1829 | + 'function_call': { |
| 1830 | + 'args': {'param': 'value'}, |
| 1831 | + 'id': 'call_123', |
| 1832 | + 'name': 'test_tool', |
1872 | 1833 | } |
1873 | | - ) |
| 1834 | + }, |
| 1835 | + {'text': 'I have completed the function calls above.'}, |
1874 | 1836 | ], |
1875 | | - id='non_dict_parts_with_function_call', |
| 1837 | + id='function_call_only_adds_text', |
1876 | 1838 | ), |
1877 | 1839 | pytest.param( |
1878 | | - MockContentResponse( |
| 1840 | + [ |
| 1841 | + ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'), |
| 1842 | + TextPart(content='Here is the result:'), |
| 1843 | + ], |
| 1844 | + [ |
1879 | 1845 | { |
1880 | | - 'role': 'model', |
1881 | | - 'parts': [ |
1882 | | - {'other_field': 'value'}, # Dict without function_call or text |
1883 | | - {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
1884 | | - ], |
1885 | | - } |
1886 | | - ), |
| 1846 | + 'function_call': { |
| 1847 | + 'args': {'param': 'value'}, |
| 1848 | + 'id': 'call_123', |
| 1849 | + 'name': 'test_tool', |
| 1850 | + } |
| 1851 | + }, |
| 1852 | + {'text': 'Here is the result:'}, |
| 1853 | + ], |
| 1854 | + id='function_call_with_text_no_addition', |
| 1855 | + ), |
| 1856 | + pytest.param( |
| 1857 | + [TextPart(content='Just text response')], |
| 1858 | + [{'text': 'Just text response'}], |
| 1859 | + id='text_only_no_addition', |
| 1860 | + ), |
| 1861 | + pytest.param( |
| 1862 | + [ |
| 1863 | + ToolCallPart(tool_name='tool1', args={'a': 1}, tool_call_id='call_1'), |
| 1864 | + ToolCallPart(tool_name='tool2', args={'b': 2}, tool_call_id='call_2'), |
| 1865 | + ], |
1887 | 1866 | [ |
1888 | | - ExpectedContent( |
1889 | | - { |
1890 | | - 'role': 'model', |
1891 | | - 'parts': [ |
1892 | | - {'other_field': 'value'}, |
1893 | | - {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
1894 | | - {'text': 'I have completed the function calls above.'}, |
1895 | | - ], |
| 1867 | + { |
| 1868 | + 'function_call': { |
| 1869 | + 'args': {'a': 1}, |
| 1870 | + 'id': 'call_1', |
| 1871 | + 'name': 'tool1', |
1896 | 1872 | } |
1897 | | - ) |
| 1873 | + }, |
| 1874 | + { |
| 1875 | + 'function_call': { |
| 1876 | + 'args': {'b': 2}, |
| 1877 | + 'id': 'call_2', |
| 1878 | + 'name': 'tool2', |
| 1879 | + } |
| 1880 | + }, |
| 1881 | + {'text': 'I have completed the function calls above.'}, |
1898 | 1882 | ], |
1899 | | - id='dict_parts_without_function_call_or_text', |
| 1883 | + id='multiple_function_calls_only', |
1900 | 1884 | ), |
1901 | 1885 | pytest.param( |
1902 | | - MockContentResponse({'role': 'model'}), # No 'parts' key |
| 1886 | + [ThinkingPart(content='Let me think...')], |
1903 | 1887 | [], |
1904 | | - id='no_parts_key', |
| 1888 | + id='thinking_only_empty_parts', |
| 1889 | + ), |
| 1890 | + pytest.param( |
| 1891 | + [ |
| 1892 | + ThinkingPart(content='Let me think...'), |
| 1893 | + ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'), |
| 1894 | + ], |
| 1895 | + [ |
| 1896 | + { |
| 1897 | + 'function_call': { |
| 1898 | + 'args': {'param': 'value'}, |
| 1899 | + 'id': 'call_123', |
| 1900 | + 'name': 'test_tool', |
| 1901 | + } |
| 1902 | + }, |
| 1903 | + {'text': 'I have completed the function calls above.'}, |
| 1904 | + ], |
| 1905 | + id='thinking_and_function_call', |
1905 | 1906 | ), |
1906 | 1907 | pytest.param( |
1907 | | - MockContentResponse({'role': 'model', 'parts': []}), # Empty parts |
1908 | 1908 | [], |
1909 | | - id='empty_parts_list', |
| 1909 | + [], |
| 1910 | + id='empty_parts', |
1910 | 1911 | ), |
1911 | 1912 | ], |
1912 | 1913 | ) |
1913 | | -async def test_google_model_response_edge_cases( |
1914 | | - google_provider: GoogleProvider, |
1915 | | - mock_content_response: MockContentResponse, |
1916 | | - expected_contents: list[ExpectedContent], |
| 1914 | +def test_content_model_response_function_call_handling( |
| 1915 | + model_parts: list[ModelResponsePart], expected_parts: list[dict[str, Any]] |
1917 | 1916 | ): |
1918 | | - """Test Google model's _map_messages method with various edge cases for function call handling.""" |
1919 | | - from unittest.mock import patch |
| 1917 | + """Test _content_model_response function's handling of function calls without text.""" |
1920 | 1918 |
|
1921 | | - model = GoogleModel('gemini-2.0-flash', provider=google_provider) |
1922 | 1919 | model_response = ModelResponse( |
1923 | | - parts=[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], |
| 1920 | + parts=model_parts, |
1924 | 1921 | usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15), |
1925 | 1922 | model_name='gemini-2.0-flash', |
1926 | 1923 | ) |
1927 | 1924 |
|
1928 | | - with patch('pydantic_ai.models.google._content_model_response') as mock_content: |
1929 | | - mock_content.return_value = mock_content_response |
1930 | | - _, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage] |
1931 | | - assert contents == expected_contents |
| 1925 | + result = _content_model_response(model_response) |
| 1926 | + |
| 1927 | + expected_result = {'role': 'model', 'parts': expected_parts} |
| 1928 | + assert result == expected_result |
0 commit comments