|
10 | 10 | from inline_snapshot import Is, snapshot |
11 | 11 | from pydantic import BaseModel |
12 | 12 | from typing_extensions import TypedDict |
| 13 | +from unittest.mock import AsyncMock |
| 14 | +from uuid import uuid4 |
13 | 15 |
|
14 | 16 | from pydantic_ai import ( |
15 | 17 | AudioUrl, |
@@ -2929,3 +2931,154 @@ async def test_google_vertexai_image_generation(allow_model_requests: None, vert |
2929 | 2931 | identifier='f3edd8', |
2930 | 2932 | ) |
2931 | 2933 | ) |
| 2934 | + |
| 2935 | + |
| 2936 | +async def test_google_model_stream_malformed_function_call_continue( |
| 2937 | + allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch |
| 2938 | +): |
| 2939 | + """Test that the model can continue streaming after a malformed function call.""" |
| 2940 | + from unittest.mock import MagicMock |
| 2941 | + |
| 2942 | + from google.genai.types import Candidate, Content, GenerateContentResponse, Part |
| 2943 | + |
| 2944 | + mock_stream = AsyncMock() |
| 2945 | + |
| 2946 | + def create_mock_response(**kwargs): |
| 2947 | + mock = MagicMock(spec=GenerateContentResponse) |
| 2948 | + for key, value in kwargs.items(): |
| 2949 | + setattr(mock, key, value) |
| 2950 | + mock.model_version = 'gemini-1.5-flash' |
| 2951 | + mock.create_time = None |
| 2952 | + mock.response_id = 'test-response-id' |
| 2953 | + return mock |
| 2954 | + |
| 2955 | + def create_mock_candidate(**kwargs): |
| 2956 | + mock = MagicMock(spec=Candidate) |
| 2957 | + for key, value in kwargs.items(): |
| 2958 | + setattr(mock, key, value) |
| 2959 | + return mock |
| 2960 | + |
| 2961 | + def create_mock_content(*parts): |
| 2962 | + mock = MagicMock(spec=Content) |
| 2963 | + mock.parts = list(parts) |
| 2964 | + return mock |
| 2965 | + |
| 2966 | + def create_mock_part(**kwargs): |
| 2967 | + mock = MagicMock(spec=Part) |
| 2968 | + mock.text = None |
| 2969 | + mock.function_call = None |
| 2970 | + for key, value in kwargs.items(): |
| 2971 | + setattr(mock, key, value) |
| 2972 | + return mock |
| 2973 | + |
| 2974 | + |
| 2975 | + async def mock_stream_iterator(): |
| 2976 | + # This part is a bit of a fiction, in reality the tool call and return would not be in the same stream. |
| 2977 | + # But for testing the malformed function call in between, this is fine. |
| 2978 | + yield create_mock_response( |
| 2979 | + candidates=[ |
| 2980 | + create_mock_candidate( |
| 2981 | + content=create_mock_content( |
| 2982 | + create_mock_part(function_call=MagicMock(name='get_capital', args={'country': 'France'})) |
| 2983 | + ), |
| 2984 | + finish_reason=None, |
| 2985 | + ) |
| 2986 | + ] |
| 2987 | + ) |
| 2988 | + |
| 2989 | + yield create_mock_response( |
| 2990 | + candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')] |
| 2991 | + ) |
| 2992 | + |
| 2993 | + yield create_mock_response( |
| 2994 | + candidates=[ |
| 2995 | + create_mock_candidate( |
| 2996 | + content=create_mock_content(create_mock_part(text='Here is the capital.')), finish_reason=None |
| 2997 | + ) |
| 2998 | + ] |
| 2999 | + ) |
| 3000 | + |
| 3001 | + yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')]) |
| 3002 | + |
| 3003 | + mock_stream.return_value = mock_stream_iterator() |
| 3004 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 3005 | + monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream) |
| 3006 | + |
| 3007 | + agent = Agent(model) |
| 3008 | + |
| 3009 | + @agent.tool_plain |
| 3010 | + def get_capital(country: str) -> str: |
| 3011 | + """Get the capital of a country.""" |
| 3012 | + return 'Paris' if country == 'France' else 'Unknown' |
| 3013 | + |
| 3014 | + async with agent.run_stream('What is the capital of France?') as result: |
| 3015 | + data = await result.get_output() |
| 3016 | + |
| 3017 | + assert data == 'Here is the capital.' |
| 3018 | + assert mock_stream.call_count == 1 |
| 3019 | + |
| 3020 | + |
| 3021 | +async def test_google_model_stream_malformed_function_call_retry( |
| 3022 | + allow_model_requests: None, google_provider: GoogleProvider, monkeypatch: pytest.MonkeyPatch |
| 3023 | +): |
| 3024 | + """Test that the model retries when the stream only contains a malformed function call.""" |
| 3025 | + from unittest.mock import MagicMock |
| 3026 | + |
| 3027 | + from google.genai.types import Candidate, Content, GenerateContentResponse, Part |
| 3028 | + |
| 3029 | + mock_stream = AsyncMock() |
| 3030 | + |
| 3031 | + def create_mock_response(**kwargs): |
| 3032 | + mock = MagicMock(spec=GenerateContentResponse) |
| 3033 | + for key, value in kwargs.items(): |
| 3034 | + setattr(mock, key, value) |
| 3035 | + mock.model_version = 'gemini-1.5-flash' |
| 3036 | + mock.create_time = None |
| 3037 | + mock.response_id = 'test-response-id' |
| 3038 | + return mock |
| 3039 | + |
| 3040 | + def create_mock_candidate(**kwargs): |
| 3041 | + mock = MagicMock(spec=Candidate) |
| 3042 | + for key, value in kwargs.items(): |
| 3043 | + setattr(mock, key, value) |
| 3044 | + return mock |
| 3045 | + |
| 3046 | + def create_mock_content(*parts): |
| 3047 | + mock = MagicMock(spec=Content) |
| 3048 | + mock.parts = list(parts) |
| 3049 | + return mock |
| 3050 | + |
| 3051 | + def create_mock_part(**kwargs): |
| 3052 | + mock = MagicMock(spec=Part) |
| 3053 | + mock.text = None |
| 3054 | + for key, value in kwargs.items(): |
| 3055 | + setattr(mock, key, value) |
| 3056 | + return mock |
| 3057 | + |
| 3058 | + async def first_call_iterator(): |
| 3059 | + yield create_mock_response( |
| 3060 | + candidates=[create_mock_candidate(content=None, finish_reason='MALFORMED_FUNCTION_CALL')] |
| 3061 | + ) |
| 3062 | + |
| 3063 | + async def second_call_iterator(): |
| 3064 | + yield create_mock_response( |
| 3065 | + candidates=[ |
| 3066 | + create_mock_candidate( |
| 3067 | + content=create_mock_content(create_mock_part(text='Successful response')), finish_reason=None |
| 3068 | + ) |
| 3069 | + ] |
| 3070 | + ) |
| 3071 | + yield create_mock_response(candidates=[create_mock_candidate(content=None, finish_reason='STOP')]) |
| 3072 | + |
| 3073 | + mock_stream.side_effect = [first_call_iterator(), second_call_iterator()] |
| 3074 | + |
| 3075 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 3076 | + monkeypatch.setattr(model.client.aio.models, 'generate_content_stream', mock_stream) |
| 3077 | + |
| 3078 | + agent = Agent(model, retries=1) |
| 3079 | + |
| 3080 | + async with agent.run_stream('Some prompt') as result: |
| 3081 | + data = await result.get_output() |
| 3082 | + |
| 3083 | + assert data == 'Successful response' |
| 3084 | + assert mock_stream.call_count == 2 |
0 commit comments