|
26 | 26 | ImageUrl, |
27 | 27 | ModelRequest, |
28 | 28 | ModelResponse, |
| 29 | + ModelResponsePart, |
29 | 30 | PartDeltaEvent, |
30 | 31 | PartStartEvent, |
31 | 32 | RetryPromptPart, |
@@ -1740,3 +1741,70 @@ async def get_user_country() -> str: |
1740 | 1741 | 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', |
1741 | 1742 | usage_limits=UsageLimits(total_tokens_limit=9, count_tokens_before_request=True), |
1742 | 1743 | ) |
| 1744 | + |
| 1745 | + |
| 1746 | +@pytest.mark.parametrize( |
| 1747 | + 'model_parts,expected_contents', |
| 1748 | + [ |
| 1749 | + pytest.param( |
| 1750 | + [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')], |
| 1751 | + [ |
| 1752 | + { |
| 1753 | + 'role': 'model', |
| 1754 | + 'parts': [ |
| 1755 | + { |
| 1756 | + 'function_call': { |
| 1757 | + 'args': {'param': 'value'}, |
| 1758 | + 'id': 'call_123', |
| 1759 | + 'name': 'test_tool', |
| 1760 | + } |
| 1761 | + }, |
| 1762 | + {'text': 'I have completed the function calls above.'}, |
| 1763 | + ], |
| 1764 | + } |
| 1765 | + ], |
| 1766 | + id='function_call_without_text', |
| 1767 | + ), |
| 1768 | + pytest.param( |
| 1769 | + [], |
| 1770 | + [], |
| 1771 | + id='empty_response_parts', |
| 1772 | + ), |
| 1773 | + pytest.param( |
| 1774 | + [ |
| 1775 | + ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'), |
| 1776 | + TextPart(content='Here is the result:'), |
| 1777 | + ], |
| 1778 | + [ |
| 1779 | + { |
| 1780 | + 'role': 'model', |
| 1781 | + 'parts': [ |
| 1782 | + { |
| 1783 | + 'function_call': { |
| 1784 | + 'args': {'param': 'value'}, |
| 1785 | + 'id': 'call_123', |
| 1786 | + 'name': 'test_tool', |
| 1787 | + } |
| 1788 | + }, |
| 1789 | + {'text': 'Here is the result:'}, |
| 1790 | + ], |
| 1791 | + } |
| 1792 | + ], |
| 1793 | + id='function_call_with_text', |
| 1794 | + ), |
| 1795 | + ], |
| 1796 | +) |
| 1797 | +async def test_google_model_response_part_handling( |
| 1798 | + google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]] |
| 1799 | +): |
| 1800 | + """Test Google model's handling of different response part combinations for API compatibility.""" |
| 1801 | + model = GoogleModel('gemini-2.0-flash', provider=google_provider) |
| 1802 | + |
| 1803 | + model_response = ModelResponse( |
| 1804 | + parts=model_parts, |
| 1805 | + usage=Usage(requests=1, request_tokens=10, response_tokens=5, total_tokens=15), |
| 1806 | + model_name='gemini-2.0-flash', |
| 1807 | + ) |
| 1808 | + |
| 1809 | + _, contents = await model._map_messages([model_response]) # pyright: ignore[reportPrivateUsage] |
| 1810 | + assert contents == expected_contents |
0 commit comments