Handle function calls without text #2557
Changes from all commits:

```diff
@@ -27,6 +27,7 @@
     ImageUrl,
     ModelRequest,
     ModelResponse,
+    ModelResponsePart,
     PartDeltaEvent,
     PartStartEvent,
     RetryPromptPart,
@@ -60,7 +61,12 @@
     Outcome,
 )
 
-from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, _metadata_as_usage  # type: ignore
+from pydantic_ai.models.google import (
+    GoogleModel,
+    GoogleModelSettings,
+    _content_model_response,  # pyright: ignore[reportPrivateUsage]
+    _metadata_as_usage,
+)
 from pydantic_ai.providers.google import GoogleProvider
 
 pytestmark = [
@@ -1828,3 +1834,184 @@ class CityLocation(BaseModel):
     agent = Agent(m, output_type=PromptedOutput(CityLocation), builtin_tools=[UrlContextTool()])
     result = await agent.run('What is the largest city in Mexico?')
     assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico'))
+
+
+@pytest.mark.parametrize(
```
> **Collaborator:** It would be great if we had a minimal reproducible example that triggers the error against the actual API, so we can verify which parts of the fix are necessary (e.g. is sending back thinking parts enough? does a one-space string work? is it only Vertex, or also GLA?). Right now, we're effectively duplicating the fix in the tests rather than verifying that a previously failing example now succeeds.
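For reference, the kind of minimal repro being requested might look like the sketch below. This is an assumption, not a verified failing example: the model string, tool, and prompt are placeholders, and it would need to be run against both the GLA and Vertex providers to answer the questions above.

```python
# Hypothetical minimal repro (untested): force a tool-only model turn against
# the live Gemini API, so the follow-up request replays a model message that
# contains a function call but no text -- the shape this PR special-cases.
from pydantic_ai import Agent

# Swap in 'google-vertex:gemini-2.0-flash' to check whether Vertex and GLA behave the same.
agent = Agent('google-gla:gemini-2.0-flash')


@agent.tool_plain
def get_temperature(city: str) -> str:
    """Placeholder tool; any tool the model decides to call will do."""
    return '21 degrees C'


# If the bug reproduces, the second request (which echoes the function-call-only
# model turn back to the API) is the one that should fail without this fix.
result = agent.run_sync('What is the temperature in Paris? Use the tool.')
print(result.output)
```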
The new tests continue:

```diff
+    'model_parts,expected_contents',
+    [
+        pytest.param(
+            [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
+            [
+                {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'function_call': {
+                                'args': {'param': 'value'},
+                                'id': 'call_123',
+                                'name': 'test_tool',
+                            }
+                        },
+                        {'text': 'I have completed the function calls above.'},
+                    ],
+                }
+            ],
+            id='function_call_without_text',
+        ),
+        pytest.param(
+            [],
+            [],
+            id='empty_response_parts',
+        ),
+        pytest.param(
+            [
+                ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
+                TextPart(content='Here is the result:'),
+            ],
+            [
+                {
+                    'role': 'model',
+                    'parts': [
+                        {
+                            'function_call': {
+                                'args': {'param': 'value'},
+                                'id': 'call_123',
+                                'name': 'test_tool',
+                            }
+                        },
+                        {'text': 'Here is the result:'},
+                    ],
+                }
+            ],
+            id='function_call_with_text',
+        ),
+        pytest.param(
+            [ThinkingPart(content='Let me think about this...')],
+            [],
+            id='thinking_only_response_skipped',
+        ),
+    ],
+)
+async def test_google_model_response_part_handling(
+    google_provider: GoogleProvider, model_parts: list[ModelResponsePart], expected_contents: list[dict[str, Any]]
+):
+    """Test Google model's handling of different response part combinations for API compatibility."""
+    model = GoogleModel('gemini-2.0-flash', provider=google_provider)
+
+    model_response = ModelResponse(
+        parts=model_parts,
+        usage=RequestUsage(input_tokens=10, output_tokens=5),
+        model_name='gemini-2.0-flash',
+    )
+
+    _, contents = await model._map_messages([model_response])  # pyright: ignore[reportPrivateUsage]
+    assert contents == expected_contents
+
+
+@pytest.mark.parametrize(
+    'model_parts,expected_parts',
+    [
+        pytest.param(
+            [ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123')],
+            [
+                {
+                    'function_call': {
+                        'args': {'param': 'value'},
+                        'id': 'call_123',
+                        'name': 'test_tool',
+                    }
+                },
+                {'text': 'I have completed the function calls above.'},
+            ],
+            id='function_call_only_adds_text',
+        ),
+        pytest.param(
+            [
+                ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
+                TextPart(content='Here is the result:'),
+            ],
+            [
+                {
+                    'function_call': {
+                        'args': {'param': 'value'},
+                        'id': 'call_123',
+                        'name': 'test_tool',
+                    }
+                },
+                {'text': 'Here is the result:'},
+            ],
+            id='function_call_with_text_no_addition',
+        ),
+        pytest.param(
+            [TextPart(content='Just text response')],
+            [{'text': 'Just text response'}],
+            id='text_only_no_addition',
+        ),
+        pytest.param(
+            [
+                ToolCallPart(tool_name='tool1', args={'a': 1}, tool_call_id='call_1'),
+                ToolCallPart(tool_name='tool2', args={'b': 2}, tool_call_id='call_2'),
+            ],
+            [
+                {
+                    'function_call': {
+                        'args': {'a': 1},
+                        'id': 'call_1',
+                        'name': 'tool1',
+                    }
+                },
+                {
+                    'function_call': {
+                        'args': {'b': 2},
+                        'id': 'call_2',
+                        'name': 'tool2',
+                    }
+                },
+                {'text': 'I have completed the function calls above.'},
+            ],
+            id='multiple_function_calls_only',
+        ),
+        pytest.param(
+            [ThinkingPart(content='Let me think...')],
+            [],
+            id='thinking_only_empty_parts',
+        ),
+        pytest.param(
+            [
+                ThinkingPart(content='Let me think...'),
+                ToolCallPart(tool_name='test_tool', args={'param': 'value'}, tool_call_id='call_123'),
+            ],
+            [
+                {
+                    'function_call': {
+                        'args': {'param': 'value'},
+                        'id': 'call_123',
+                        'name': 'test_tool',
+                    }
+                },
+                {'text': 'I have completed the function calls above.'},
+            ],
+            id='thinking_and_function_call',
+        ),
+        pytest.param(
+            [],
+            [],
+            id='empty_parts',
+        ),
+    ],
+)
+def test_content_model_response_function_call_handling(
+    model_parts: list[ModelResponsePart], expected_parts: list[dict[str, Any]]
+):
+    """Test _content_model_response function's handling of function calls without text."""
+
+    model_response = ModelResponse(
+        parts=model_parts,
+        usage=RequestUsage(input_tokens=10, output_tokens=5),
+        model_name='gemini-2.0-flash',
+    )
+
+    result = _content_model_response(model_response)
+
+    expected_result = {'role': 'model', 'parts': expected_parts}
+    assert result == expected_result
```
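Reading the expected outputs together, the invariant these tests pin down can be summarized as the sketch below. This is an approximation, not the actual `_content_model_response` source: thinking parts are dropped, and a filler text part is appended whenever a model turn would otherwise contain only function calls.

```python
# Sketch of the behavior under test (assumption, not the library source).
# Part types come from pydantic_ai.messages.
from pydantic_ai.messages import ModelResponse, TextPart, ThinkingPart, ToolCallPart


def sketch_content_model_response(m: ModelResponse) -> dict:
    parts: list[dict] = []
    for item in m.parts:
        if isinstance(item, ToolCallPart):
            parts.append(
                {'function_call': {'args': item.args, 'id': item.tool_call_id, 'name': item.tool_name}}
            )
        elif isinstance(item, TextPart):
            parts.append({'text': item.content})
        elif isinstance(item, ThinkingPart):
            continue  # thinking parts are not sent back to the API (see #2594)
    # Gemini rejects a model turn that has function calls but no text, hence the filler.
    if parts and all('function_call' in p for p in parts):
        parts.append({'text': 'I have completed the function calls above.'})
    return {'role': 'model', 'parts': parts}
```

The first test additionally shows that `_map_messages` drops the content entirely when the mapped parts come out empty (the `empty_response_parts` and `thinking_only_response_skipped` cases).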
> **Collaborator:** @ethanabrooks Does this mean this would be a non-issue once we start sending back thinking parts? (#2594) I mean to work on that next week.
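If #2594 lands, the skip branch in the sketch above would presumably turn into a mapping instead. The fragment below is purely hypothetical: the `thought` flag is an assumption based on google-genai's `Part.thought` field and is not confirmed by anything in this PR.

```python
# Hypothetical follow-up for #2594 (not in this PR): map ThinkingPart back
# instead of skipping it, so a thinking + tool-call turn is no longer
# function-call-only and might not need the filler text at all.
from pydantic_ai.messages import ThinkingPart


def map_thinking_part(item: ThinkingPart) -> dict:
    # 'thought': True mirrors google-genai's Part.thought field (assumption);
    # the exact mapping would be decided as part of #2594.
    return {'text': item.content, 'thought': True}
```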