|
2 | 2 |
|
3 | 3 | import datetime |
4 | 4 | import os |
5 | | -from typing import Any |
| 5 | +from typing import Any, Union |
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 | from httpx import Timeout |
|
26 | 26 | ImageUrl, |
27 | 27 | ModelRequest, |
28 | 28 | ModelResponse, |
| 29 | + ModelResponsePart, |
29 | 30 | PartDeltaEvent, |
30 | 31 | PartStartEvent, |
31 | 32 | RetryPromptPart, |
@@ -1740,3 +1741,191 @@ async def get_user_country() -> str: |
1740 | 1741 | 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', |
1741 | 1742 | usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True), |
1742 | 1743 | ) |
| 1744 | + |
| 1745 | + |
| 1746 | +@pytest.mark.parametrize( |
| 1747 | + 'model_parts,expected_contents', |
| 1748 | + [ |
| 1749 | + pytest.param( |
| 1750 | + [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], |
| 1751 | + [ |
| 1752 | + { |
| 1753 | + 'role': 'model', |
| 1754 | + 'parts': [ |
| 1755 | + { |
| 1756 | + 'function_call': { |
| 1757 | + 'args': {'param': 'value'}, |
| 1758 | + 'id': 'call_123', |
| 1759 | + 'name': 'test_tool', |
| 1760 | + } |
| 1761 | + }, |
| 1762 | + {'text': 'I have completed the function calls above.'}, |
| 1763 | + ], |
| 1764 | + } |
| 1765 | + ], |
| 1766 | + id='function_call_without_text', |
| 1767 | + ), |
| 1768 | + pytest.param( |
| 1769 | + [], |
| 1770 | + [], |
| 1771 | + id='empty_response_parts', |
| 1772 | + ), |
| 1773 | + pytest.param( |
| 1774 | + [ |
| 1775 | + ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'), |
| 1776 | + TextPart(content='Here is the result:'), |
| 1777 | + ], |
| 1778 | + [ |
| 1779 | + { |
| 1780 | + 'role': 'model', |
| 1781 | + 'parts': [ |
| 1782 | + { |
| 1783 | + 'function_call': { |
| 1784 | + 'args': {'param': 'value'}, |
| 1785 | + 'id': 'call_123', |
| 1786 | + 'name': 'test_tool', |
| 1787 | + } |
| 1788 | + }, |
| 1789 | + {'text': 'Here is the result:'}, |
| 1790 | + ], |
| 1791 | + } |
| 1792 | + ], |
| 1793 | + id='function_call_with_text', |
| 1794 | + ), |
| 1795 | + pytest.param( |
| 1796 | + [ThinkingPart(content='Let me think about this...')], |
| 1797 | + [], |
| 1798 | + id='thinking_only_response_skipped', |
| 1799 | + ), |
| 1800 | + ], |
| 1801 | +) |
| 1802 | +async def test_google_model_response_part_handling( |
| 1803 | + google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]] |
| 1804 | +): |
| 1805 | + """Test Google model's handling of different response part combinations for API compatibility.""" |
| 1806 | + model = GoogleModel('gemini-2.0-flash', provider=google_provider) |
| 1807 | + |
| 1808 | + model_response = ModelResponse( |
| 1809 | + parts=model_parts, |
| 1810 | + usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15), |
| 1811 | + model_name='gemini-2.0-flash', |
| 1812 | + ) |
| 1813 | + |
| 1814 | + _, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage] |
| 1815 | + assert contents == expected_contents |
| 1816 | + |
| 1817 | + |
| 1818 | +class FunctionCallDict(TypedDict): |
| 1819 | + name: str |
| 1820 | + args: dict[str, Any] |
| 1821 | + id: str |
| 1822 | + |
| 1823 | + |
| 1824 | +class FunctionCallPartDict(TypedDict): |
| 1825 | + function_call: FunctionCallDict |
| 1826 | + |
| 1827 | + |
| 1828 | +class TextPartDict(TypedDict): |
| 1829 | + text: str |
| 1830 | + |
| 1831 | + |
| 1832 | +class OtherPartDict(TypedDict, total=False): |
| 1833 | + other_field: str |
| 1834 | + |
| 1835 | + |
| 1836 | +# Union of all possible part types we're testing |
| 1837 | +TestPartDict = Union[FunctionCallPartDict, TextPartDict, OtherPartDict, str] # str for non-dict parts |
| 1838 | + |
| 1839 | + |
| 1840 | +class MockContentResponse(TypedDict, total=False): |
| 1841 | + role: str |
| 1842 | + parts: list[TestPartDict] |
| 1843 | + |
| 1844 | + |
| 1845 | +class ExpectedContent(TypedDict, total=False): |
| 1846 | + role: str |
| 1847 | + parts: list[TestPartDict] |
| 1848 | + |
| 1849 | + |
| 1850 | +@pytest.mark.parametrize( |
| 1851 | + 'mock_content_response,expected_contents', |
| 1852 | + [ |
| 1853 | + pytest.param( |
| 1854 | + MockContentResponse( |
| 1855 | + { |
| 1856 | + 'role': 'model', |
| 1857 | + 'parts': [ |
| 1858 | + 'not_a_dict', # Non-dict part to test isinstance check |
| 1859 | + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
| 1860 | + ], |
| 1861 | + } |
| 1862 | + ), |
| 1863 | + [ |
| 1864 | + ExpectedContent( |
| 1865 | + { |
| 1866 | + 'role': 'model', |
| 1867 | + 'parts': [ |
| 1868 | + 'not_a_dict', |
| 1869 | + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
| 1870 | + {'text': 'I have completed the function calls above.'}, |
| 1871 | + ], |
| 1872 | + } |
| 1873 | + ) |
| 1874 | + ], |
| 1875 | + id='non_dict_parts_with_function_call', |
| 1876 | + ), |
| 1877 | + pytest.param( |
| 1878 | + MockContentResponse( |
| 1879 | + { |
| 1880 | + 'role': 'model', |
| 1881 | + 'parts': [ |
| 1882 | + {'other_field': 'value'}, # Dict without function_call or text |
| 1883 | + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
| 1884 | + ], |
| 1885 | + } |
| 1886 | + ), |
| 1887 | + [ |
| 1888 | + ExpectedContent( |
| 1889 | + { |
| 1890 | + 'role': 'model', |
| 1891 | + 'parts': [ |
| 1892 | + {'other_field': 'value'}, |
| 1893 | + {'function_call': {'name': 'test', 'args': {}, 'id': '123'}}, |
| 1894 | + {'text': 'I have completed the function calls above.'}, |
| 1895 | + ], |
| 1896 | + } |
| 1897 | + ) |
| 1898 | + ], |
| 1899 | + id='dict_parts_without_function_call_or_text', |
| 1900 | + ), |
| 1901 | + pytest.param( |
| 1902 | + MockContentResponse({'role': 'model'}), # No 'parts' key |
| 1903 | + [], |
| 1904 | + id='no_parts_key', |
| 1905 | + ), |
| 1906 | + pytest.param( |
| 1907 | + MockContentResponse({'role': 'model', 'parts': []}), # Empty parts |
| 1908 | + [], |
| 1909 | + id='empty_parts_list', |
| 1910 | + ), |
| 1911 | + ], |
| 1912 | +) |
| 1913 | +async def test_google_model_response_edge_cases( |
| 1914 | + google_provider: GoogleProvider, |
| 1915 | + mock_content_response: MockContentResponse, |
| 1916 | + expected_contents: list[ExpectedContent], |
| 1917 | +): |
| 1918 | + """Test Google model's _map_messages method with various edge cases for function call handling.""" |
| 1919 | + from unittest.mock import patch |
| 1920 | + |
| 1921 | + model = GoogleModel('gemini-2.0-flash', provider=google_provider) |
| 1922 | + model_response = ModelResponse( |
| 1923 | + parts=[ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], |
| 1924 | + usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15), |
| 1925 | + model_name='gemini-2.0-flash', |
| 1926 | + ) |
| 1927 | + |
| 1928 | + with patch('pydantic_ai.models.google._content_model_response') as mock_content: |
| 1929 | + mock_content.return_value = mock_content_response |
| 1930 | + _, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage] |
| 1931 | + assert contents == expected_contents |
0 commit comments