Skip to content

Commit 9ebde8e

Browse files
authored
Merge branch 'main' into mcp-json-env-vars
2 parents c19413f + b8d2904 commit 9ebde8e

File tree

10 files changed

+187
-60
lines changed

10 files changed

+187
-60
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ jobs:
8080
- run: make docs
8181

8282
- run: make docs-insiders
83-
if: github.event.pull_request.head.repo.full_name == github.repository || github.ref == 'refs/heads/main'
83+
if: (github.event.pull_request.head.repo.full_name == github.repository || github.ref == 'refs/heads/main') && github.repository == 'pydantic/pydantic-ai'
8484
env:
8585
PPPR_TOKEN: ${{ secrets.PPPR_TOKEN }}
8686

@@ -103,7 +103,7 @@ jobs:
103103
test-live:
104104
runs-on: ubuntu-latest
105105
timeout-minutes: 5
106-
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'push'
106+
if: (github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'push') && github.repository == 'pydantic/pydantic-ai'
107107
steps:
108108
- uses: actions/checkout@v4
109109

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@ node_modules/
2121
/test_tmp/
2222
.mcp.json
2323
.claude/
24+
/.cursor/
25+
/.devcontainer/

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def __init__(self, message: str):
4444
def __eq__(self, other: Any) -> bool:
4545
return isinstance(other, self.__class__) and other.message == self.message
4646

47+
def __hash__(self) -> int:
48+
return hash((self.__class__, self.message))
49+
4750
@classmethod
4851
def __get_pydantic_core_schema__(cls, _: Any, __: Any) -> core_schema.CoreSchema:
4952
"""Pydantic core schema to allow `ModelRetry` to be (de)serialized."""

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,10 +776,11 @@ def model_response_str(self) -> str:
776776
def model_response_object(self) -> dict[str, Any]:
777777
"""Return a dictionary representation of the content, wrapping non-dict types appropriately."""
778778
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
779-
if isinstance(self.content, dict):
780-
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
779+
json_content = tool_return_ta.dump_python(self.content, mode='json')
780+
if isinstance(json_content, dict):
781+
return json_content # type: ignore[reportUnknownReturn]
781782
else:
782-
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
783+
return {'return_value': json_content}
783784

784785
def otel_event(self, settings: InstrumentationSettings) -> Event:
785786
return Event(

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -677,13 +677,18 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
677677
provider_name=self.provider_name,
678678
)
679679

680-
if part.text:
681-
if part.thought:
682-
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
683-
else:
684-
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
685-
if maybe_event is not None: # pragma: no branch
686-
yield maybe_event
680+
if part.text is not None:
681+
if len(part.text) > 0:
682+
if part.thought:
683+
yield self._parts_manager.handle_thinking_delta(
684+
vendor_part_id='thinking', content=part.text
685+
)
686+
else:
687+
maybe_event = self._parts_manager.handle_text_delta(
688+
vendor_part_id='content', content=part.text
689+
)
690+
if maybe_event is not None: # pragma: no branch
691+
yield maybe_event
687692
elif part.function_call:
688693
maybe_event = self._parts_manager.handle_tool_call_delta(
689694
vendor_part_id=uuid4(),
@@ -822,7 +827,10 @@ def _process_response_from_parts(
822827
elif part.code_execution_result is not None:
823828
assert code_execution_tool_call_id is not None
824829
item = _map_code_execution_result(part.code_execution_result, provider_name, code_execution_tool_call_id)
825-
elif part.text:
830+
elif part.text is not None:
831+
# Google sometimes sends empty text parts, we don't want to add them to the response
832+
if len(part.text) == 0:
833+
continue
826834
if part.thought:
827835
item = ThinkingPart(content=part.text)
828836
else:

pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_event_stream.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic_core import to_json
1010

1111
from ...messages import (
12+
BaseToolReturnPart,
1213
BuiltinToolCallPart,
1314
BuiltinToolReturnPart,
1415
FilePart,
@@ -155,21 +156,23 @@ async def handle_tool_call_delta(self, delta: ToolCallPartDelta) -> AsyncIterato
155156
)
156157

157158
async def handle_tool_call_end(self, part: ToolCallPart) -> AsyncIterator[BaseChunk]:
158-
yield ToolInputAvailableChunk(tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=part.args)
159+
yield ToolInputAvailableChunk(
160+
tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=part.args_as_dict()
161+
)
159162

160163
async def handle_builtin_tool_call_end(self, part: BuiltinToolCallPart) -> AsyncIterator[BaseChunk]:
161164
yield ToolInputAvailableChunk(
162165
tool_call_id=part.tool_call_id,
163166
tool_name=part.tool_name,
164-
input=part.args,
167+
input=part.args_as_dict(),
165168
provider_executed=True,
166169
provider_metadata={'pydantic_ai': {'provider_name': part.provider_name}},
167170
)
168171

169172
async def handle_builtin_tool_return(self, part: BuiltinToolReturnPart) -> AsyncIterator[BaseChunk]:
170173
yield ToolOutputAvailableChunk(
171174
tool_call_id=part.tool_call_id,
172-
output=part.content,
175+
output=self._tool_return_output(part),
173176
provider_executed=True,
174177
)
175178

@@ -178,10 +181,15 @@ async def handle_file(self, part: FilePart) -> AsyncIterator[BaseChunk]:
178181
yield FileChunk(url=file.data_uri, media_type=file.media_type)
179182

180183
async def handle_function_tool_result(self, event: FunctionToolResultEvent) -> AsyncIterator[BaseChunk]:
181-
result = event.result
182-
if isinstance(result, RetryPromptPart):
183-
yield ToolOutputErrorChunk(tool_call_id=result.tool_call_id, error_text=result.model_response())
184+
part = event.result
185+
if isinstance(part, RetryPromptPart):
186+
yield ToolOutputErrorChunk(tool_call_id=part.tool_call_id, error_text=part.model_response())
184187
else:
185-
yield ToolOutputAvailableChunk(tool_call_id=result.tool_call_id, output=result.content)
188+
yield ToolOutputAvailableChunk(tool_call_id=part.tool_call_id, output=self._tool_return_output(part))
186189

187190
# ToolCallResultEvent.content may hold user parts (e.g. text, images) that Vercel AI does not currently have events for
191+
192+
def _tool_return_output(self, part: BaseToolReturnPart) -> Any:
193+
output = part.model_response_object()
194+
# Unwrap the return value from the output dictionary if it exists
195+
return output.get('return_value', output)

tests/models/test_google.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import datetime
44
import os
55
import re
6+
from collections.abc import AsyncIterator
67
from typing import Any
78

89
import pytest
@@ -47,6 +48,7 @@
4748
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
4849
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
4950
)
51+
from pydantic_ai.models import ModelRequestParameters
5052
from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
5153
from pydantic_ai.settings import ModelSettings
5254
from pydantic_ai.usage import RequestUsage, RunUsage, UsageLimits
@@ -56,6 +58,7 @@
5658

5759
with try_import() as imports_successful:
5860
from google.genai.types import (
61+
FinishReason as GoogleFinishReason,
5962
GenerateContentResponse,
6063
GenerateContentResponseUsageMetadata,
6164
HarmBlockThreshold,
@@ -64,7 +67,12 @@
6467
ModalityTokenCount,
6568
)
6669

67-
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings, _metadata_as_usage # type: ignore
70+
from pydantic_ai.models.google import (
71+
GeminiStreamedResponse,
72+
GoogleModel,
73+
GoogleModelSettings,
74+
_metadata_as_usage, # pyright: ignore[reportPrivateUsage]
75+
)
6876
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
6977
from pydantic_ai.providers.google import GoogleProvider
7078
from pydantic_ai.providers.openai import OpenAIProvider
@@ -3063,3 +3071,52 @@ async def test_google_httpx_client_is_not_closed(allow_model_requests: None, gem
30633071
agent = Agent(GoogleModel('gemini-2.5-flash-lite', provider=GoogleProvider(api_key=gemini_api_key)))
30643072
result = await agent.run('What is the capital of Mexico?')
30653073
assert result.output == snapshot('The capital of Mexico is **Mexico City**.')
3074+
3075+
3076+
def test_google_process_response_filters_empty_text_parts(google_provider: GoogleProvider):
3077+
model = GoogleModel('gemini-2.5-pro', provider=google_provider)
3078+
response = _generate_response_with_texts(response_id='resp-123', texts=['', 'first', '', 'second'])
3079+
3080+
result = model._process_response(response) # pyright: ignore[reportPrivateUsage]
3081+
3082+
assert result.parts == snapshot([TextPart(content='first'), TextPart(content='second')])
3083+
3084+
3085+
async def test_gemini_streamed_response_emits_text_events_for_non_empty_parts():
3086+
chunk = _generate_response_with_texts('stream-1', ['', 'streamed text'])
3087+
3088+
async def response_iterator() -> AsyncIterator[GenerateContentResponse]:
3089+
yield chunk
3090+
3091+
streamed_response = GeminiStreamedResponse(
3092+
model_request_parameters=ModelRequestParameters(),
3093+
_model_name='gemini-test',
3094+
_response=response_iterator(),
3095+
_timestamp=datetime.datetime.now(datetime.timezone.utc),
3096+
_provider_name='test-provider',
3097+
)
3098+
3099+
events = [event async for event in streamed_response._get_event_iterator()] # pyright: ignore[reportPrivateUsage]
3100+
assert events == snapshot([PartStartEvent(index=0, part=TextPart(content='streamed text'))])
3101+
3102+
3103+
def _generate_response_with_texts(response_id: str, texts: list[str]) -> GenerateContentResponse:
3104+
return GenerateContentResponse.model_validate(
3105+
{
3106+
'response_id': response_id,
3107+
'model_version': 'gemini-test',
3108+
'usage_metadata': GenerateContentResponseUsageMetadata(
3109+
prompt_token_count=0,
3110+
candidates_token_count=0,
3111+
),
3112+
'candidates': [
3113+
{
3114+
'finish_reason': GoogleFinishReason.STOP,
3115+
'content': {
3116+
'role': 'model',
3117+
'parts': [{'text': text} for text in texts],
3118+
},
3119+
}
3120+
],
3121+
}
3122+
)

tests/test_agent.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3724,13 +3724,11 @@ def test_tool_return_part_binary_content_serialization():
37243724

37253725
assert tool_return.model_response_object() == snapshot(
37263726
{
3727-
'return_value': {
3728-
'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=',
3729-
'media_type': 'image/png',
3730-
'vendor_metadata': None,
3731-
'_identifier': None,
3732-
'kind': 'binary',
3733-
}
3727+
'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=',
3728+
'media_type': 'image/png',
3729+
'vendor_metadata': None,
3730+
'_identifier': None,
3731+
'kind': 'binary',
37343732
}
37353733
)
37363734

tests/test_exceptions.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Tests for exception classes."""
2+
3+
from collections.abc import Callable
4+
from typing import Any
5+
6+
import pytest
7+
8+
from pydantic_ai import ModelRetry
9+
from pydantic_ai.exceptions import (
10+
AgentRunError,
11+
ApprovalRequired,
12+
CallDeferred,
13+
IncompleteToolCall,
14+
ModelHTTPError,
15+
UnexpectedModelBehavior,
16+
UsageLimitExceeded,
17+
UserError,
18+
)
19+
20+
21+
@pytest.mark.parametrize(
22+
'exc_factory',
23+
[
24+
lambda: ModelRetry('test'),
25+
lambda: CallDeferred(),
26+
lambda: ApprovalRequired(),
27+
lambda: UserError('test'),
28+
lambda: AgentRunError('test'),
29+
lambda: UnexpectedModelBehavior('test'),
30+
lambda: UsageLimitExceeded('test'),
31+
lambda: ModelHTTPError(500, 'model'),
32+
lambda: IncompleteToolCall('test'),
33+
],
34+
ids=[
35+
'ModelRetry',
36+
'CallDeferred',
37+
'ApprovalRequired',
38+
'UserError',
39+
'AgentRunError',
40+
'UnexpectedModelBehavior',
41+
'UsageLimitExceeded',
42+
'ModelHTTPError',
43+
'IncompleteToolCall',
44+
],
45+
)
46+
def test_exceptions_hashable(exc_factory: Callable[[], Any]):
47+
"""Test that all exception classes are hashable and usable as keys."""
48+
exc = exc_factory()
49+
50+
# Does not raise TypeError
51+
_ = hash(exc)
52+
53+
# Can be used in sets and dicts
54+
s = {exc}
55+
d = {exc: 'value'}
56+
57+
assert exc in s
58+
assert d[exc] == 'value'

0 commit comments

Comments
 (0)