|
10 | 10 | from inline_snapshot import Is, snapshot |
11 | 11 | from pydantic import BaseModel |
12 | 12 | from typing_extensions import TypedDict |
13 | | -from unittest.mock import AsyncMock, MagicMock |
14 | 13 |
|
15 | 14 | from pydantic_ai import ( |
16 | 15 | AudioUrl, |
@@ -1896,6 +1895,7 @@ def dummy() -> None: ... # pragma: no cover |
1896 | 1895 | PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())), |
1897 | 1896 | PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())), |
1898 | 1897 | PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())), |
| 1898 | + PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=IsStr())), |
1899 | 1899 | PartEndEvent( |
1900 | 1900 | index=1, |
1901 | 1901 | part=TextPart( |
@@ -3091,205 +3091,6 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert |
3091 | 3091 | ) |
3092 | 3092 |
|
3093 | 3093 |
|
3094 | | -def _create_google_stream_mock_helpers(): |
3095 | | - from google.genai.types import Candidate, Content, GenerateContentResponse, Part |
3096 | | - |
3097 | | - class _FR(str): |
3098 | | - def __new__(cls, name: str): |
3099 | | - obj = str.__new__(cls, name) |
3100 | | - return obj |
3101 | | - |
3102 | | - @property |
3103 | | - def value(self) -> str: |
3104 | | - return str(self) |
3105 | | - |
3106 | | - def create_mock_response(**kwargs: Any) -> MagicMock: |
3107 | | - mock = MagicMock(spec=GenerateContentResponse) |
3108 | | - for key, value in kwargs.items(): |
3109 | | - setattr(mock, key, value) |
3110 | | - mock.model_version = 'gemini-1.5-flash' |
3111 | | - mock.create_time = None |
3112 | | - mock.response_id = 'test-response-id' |
3113 | | - mock.usage_metadata = None |
3114 | | - return mock |
3115 | | - |
3116 | | - def create_mock_candidate(**kwargs: Any) -> MagicMock: |
3117 | | - mock = MagicMock(spec=Candidate) |
3118 | | - mock.grounding_metadata = None |
3119 | | - if 'finish_reason' in kwargs and isinstance(kwargs['finish_reason'], str): |
3120 | | - kwargs['finish_reason'] = _FR(kwargs['finish_reason']) |
3121 | | - for key, value in kwargs.items(): |
3122 | | - setattr(mock, key, value) |
3123 | | - return mock |
3124 | | - |
3125 | | - def create_mock_content(*parts: MagicMock) -> MagicMock: |
3126 | | - mock = MagicMock(spec=Content) |
3127 | | - mock.parts = list(parts) |
3128 | | - return mock |
3129 | | - |
3130 | | - def create_mock_part(**kwargs: Any) -> MagicMock: |
3131 | | - mock = MagicMock(spec=Part) |
3132 | | - mock.text = None |
3133 | | - mock.function_call = None |
3134 | | - mock.thought_signature = None |
3135 | | - mock.thought = None |
3136 | | - mock.inline_data = None |
3137 | | - mock.executable_code = None |
3138 | | - mock.code_execution_result = None |
3139 | | - mock.function_response = None |
3140 | | - |
3141 | | - for key, value in kwargs.items(): |
3142 | | - setattr(mock, key, value) |
3143 | | - return mock |
3144 | | - |
3145 | | - return create_mock_response, create_mock_candidate, create_mock_content, create_mock_part |
3146 | | - |
3147 | | - |
3148 | | -async def test_google_model_stream_malformed_function_call_continue( |
3149 | | - allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch |
3150 | | -): |
3151 | | - """Test that the model can continue streaming after a malformed function call.""" |
3152 | | - mock_stream = AsyncMock() |
3153 | | - ( |
3154 | | - create_mock_response, |
3155 | | - create_mock_candidate, |
3156 | | - create_mock_content, |
3157 | | - create_mock_part, |
3158 | | - ) = _create_google_stream_mock_helpers() |
3159 | | - create_mock_response: Any |
3160 | | - create_mock_candidate: Any |
3161 | | - create_mock_content: Any |
3162 | | - create_mock_part: Any |
3163 | | - async def mock_stream_iterator(): |
3164 | | - # This part is a bit of a fiction, in reality the tool call and return would not be in the same stream. |
3165 | | - # But for testing the malformed function call in between, this is fine. |
3166 | | - yield create_mock_response( |
3167 | | - candidates=[ |
3168 | | - create_mock_candidate( |
3169 | | - content=create_mock_content( |
3170 | | - create_mock_part(function_call=MagicMock(name='get_capital', args={'country': 'France'})) |
3171 | | - ), |
3172 | | - finish_reason=None, |
3173 | | - ) |
3174 | | - ] |
3175 | | - ) |
3176 | | - |
3177 | | - yield create_mock_response( |
3178 | | - candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')] |
3179 | | - ) |
3180 | | - |
3181 | | - yield create_mock_response( |
3182 | | - candidates=[ |
3183 | | - create_mock_candidate( |
3184 | | - content=create_mock_content(create_mock_part(text='Here is the capital.')), finish_reason=None |
3185 | | - ) |
3186 | | - ] |
3187 | | - ) |
3188 | | - |
3189 | | - yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')]) |
3190 | | - |
3191 | | - mock_stream.return_value = mock_stream_iterator() |
3192 | | - model = GoogleModel('gemini-2.5-flash', provider=google_provider) |
3193 | | - monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream) |
3194 | | - |
3195 | | - agent = Agent(model) |
3196 | | - |
3197 | | - @agent.tool_plain |
3198 | | - def get_capital(country: str) -> str: |
3199 | | - """Get the capital of a country.""" |
3200 | | - return 'Paris' if country == 'France' else 'Unknown' |
3201 | | - |
3202 | | - async with agent.run_stream('What is the capital of France?') as result: |
3203 | | - data = await result.get_output() |
3204 | | - |
3205 | | - assert data == 'Here is the capital.' |
3206 | | - assert mock_stream.call_count == 1 |
3207 | | - |
3208 | | - |
3209 | | -async def test_google_model_stream_malformed_function_call_retry( |
3210 | | - allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch |
3211 | | -): |
3212 | | - """Test that the model retries when the stream only contains a malformed function call.""" |
3213 | | - mock_stream = AsyncMock() |
3214 | | - ( |
3215 | | - create_mock_response, |
3216 | | - create_mock_candidate, |
3217 | | - create_mock_content, |
3218 | | - create_mock_part, |
3219 | | - ) = _create_google_stream_mock_helpers() |
3220 | | - create_mock_response: Any |
3221 | | - create_mock_candidate: Any |
3222 | | - create_mock_content: Any |
3223 | | - create_mock_part: Any |
3224 | | - async def first_call_iterator(): |
3225 | | - yield create_mock_response( |
3226 | | - candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')] |
3227 | | - ) |
3228 | | - |
3229 | | - async def second_call_iterator(): |
3230 | | - yield create_mock_response( |
3231 | | - candidates=[ |
3232 | | - create_mock_candidate( |
3233 | | - content=create_mock_content(create_mock_part(text='Successful response')), finish_reason=None |
3234 | | - ) |
3235 | | - ] |
3236 | | - ) |
3237 | | - yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')]) |
3238 | | - |
3239 | | - mock_stream.side_effect = [first_call_iterator(), second_call_iterator()] |
3240 | | - |
3241 | | - model = GoogleModel('gemini-2.5-flash', provider=google_provider) |
3242 | | - monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream) |
3243 | | - |
3244 | | - agent = Agent(model, retries=1) |
3245 | | - |
3246 | | - async with agent.run_stream('Some prompt') as result: |
3247 | | - data = await result.get_output() |
3248 | | - |
3249 | | - assert data == 'Successful response' |
3250 | | - assert mock_stream.call_count == 2 |
3251 | | - |
3252 | | - |
3253 | | -async def test_google_model_malformed_function_call_retry( |
3254 | | - allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch |
3255 | | -): |
3256 | | - """Test that the model retries when non-streamed response contains a malformed function call.""" |
3257 | | - mock_generate = AsyncMock() |
3258 | | - ( |
3259 | | - create_mock_response, |
3260 | | - create_mock_candidate, |
3261 | | - create_mock_content, |
3262 | | - create_mock_part, |
3263 | | - ) = _create_google_stream_mock_helpers() |
3264 | | - create_mock_response: Any |
3265 | | - create_mock_candidate: Any |
3266 | | - create_mock_content: Any |
3267 | | - create_mock_part: Any |
3268 | | - first_response: MagicMock = create_mock_response( |
3269 | | - candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')] |
3270 | | - ) |
3271 | | - |
3272 | | - second_response: MagicMock = create_mock_response( |
3273 | | - candidates=[ |
3274 | | - create_mock_candidate( |
3275 | | - content=create_mock_content(create_mock_part(text='Successful response')), finish_reason='STOP' |
3276 | | - ) |
3277 | | - ] |
3278 | | - ) |
3279 | | - |
3280 | | - mock_generate.side_effect = [first_response, second_response] |
3281 | | - |
3282 | | - model = GoogleModel('gemini-2.5-flash', provider=google_provider) |
3283 | | - monkeypatch.setattr(model.client.aio.models, 'generate_content', mock_generate) |
3284 | | - |
3285 | | - agent = Agent(model, retries=1) |
3286 | | - |
3287 | | - result = await agent.run('Some prompt') |
3288 | | - |
3289 | | - assert result.output == 'Successful response' |
3290 | | - assert mock_generate.call_count == 2 |
3291 | | - |
3292 | | - |
3293 | 3094 | async def test_google_httpx_client_is_not_closed(allow_model_requests: None, gemini_api_key: str): |
3294 | 3095 | # This should not raise any errors, see https://github.com/pydantic/pydantic-ai/issues/3242. |
3295 | 3096 | agent = Agent(GoogleModel('gemini-2.5-flash-lite', provider=GoogleProvider(api_key=gemini_api_key))) |
|
0 commit comments