|
4 | 4 | import base64 |
5 | 5 | import json |
6 | 6 | import warnings |
| 7 | +from collections.abc import Iterator |
7 | 8 | from concurrent.futures import ThreadPoolExecutor |
8 | 9 | from typing import Optional, Union |
9 | 10 | from unittest.mock import ANY, Mock, patch |
|
20 | 21 | from langchain_core.load import dumps, loads |
21 | 22 | from langchain_core.messages import ( |
22 | 23 | AIMessage, |
| 24 | + BaseMessage, |
23 | 25 | FunctionMessage, |
24 | 26 | HumanMessage, |
25 | 27 | SystemMessage, |
@@ -917,3 +919,198 @@ def test_response_to_result_grounding_metadata( |
917 | 919 | else {} |
918 | 920 | ) |
919 | 921 | assert grounding_metadata == expected_grounding_metadata |
| 922 | + |
| 923 | + |
@pytest.mark.parametrize(
    "is_async,mock_target,method_name",
    [
        (False, "_chat_with_retry", "_generate"),  # Sync
        (True, "_achat_with_retry", "_agenerate"),  # Async
    ],
)
@pytest.mark.parametrize(
    "instance_timeout,call_timeout,expected_timeout,should_have_timeout",
    [
        (5.0, None, 5.0, True),  # Instance-level timeout
        (5.0, 10.0, 10.0, True),  # Call-level overrides instance
        (None, None, None, False),  # No timeout anywhere
    ],
)
async def test_timeout_parameter_handling(
    is_async: bool,
    mock_target: str,
    method_name: str,
    instance_timeout: Optional[float],
    call_timeout: Optional[float],
    expected_timeout: Optional[float],
    should_have_timeout: bool,
) -> None:
    """Verify sync/async generate forwards the timeout to the retry helper.

    Covers three precedence cases: instance-level only, call-level overriding
    instance-level, and no timeout configured at all (in which case no
    ``timeout`` kwarg may reach the retry helper).
    """
    canned_response = GenerateContentResponse(
        {
            "candidates": [
                {
                    "content": {"parts": [{"text": "Test response"}]},
                    "finish_reason": "STOP",
                }
            ]
        }
    )

    with patch(f"langchain_google_genai.chat_models.{mock_target}") as retry_mock:
        retry_mock.return_value = canned_response

        # Build the constructor kwargs, merging in the instance-level timeout
        # only when one is configured for this parametrize case.
        init_kwargs = {
            "model": "gemini-2.5-flash",
            "google_api_key": SecretStr("test-key"),
            **({} if instance_timeout is None else {"timeout": instance_timeout}),
        }
        llm = ChatGoogleGenerativeAI(**init_kwargs)
        conversation: list[BaseMessage] = [HumanMessage(content="Hello")]

        # Invoke the generation method under test, passing a call-level
        # timeout only when this case supplies one.
        target = getattr(llm, method_name)
        extra = {} if call_timeout is None else {"timeout": call_timeout}
        if is_async:
            await target(conversation, **extra)
        else:
            target(conversation, **extra)

        # The retry helper must have been hit exactly once, with (or without)
        # the expected timeout kwarg.
        retry_mock.assert_called_once()
        forwarded = retry_mock.call_args.kwargs
        if should_have_timeout:
            assert "timeout" in forwarded
            assert forwarded["timeout"] == expected_timeout
        else:
            assert "timeout" not in forwarded
| 993 | + |
@pytest.mark.parametrize(
    "instance_timeout,expected_timeout,should_have_timeout",
    [
        (5.0, 5.0, True),  # Instance-level timeout
        (None, None, False),  # No timeout
    ],
)
@patch("langchain_google_genai.chat_models._chat_with_retry")
def test_timeout_streaming_parameter_handling(
    mock_retry: Mock,
    instance_timeout: Optional[float],
    expected_timeout: Optional[float],
    should_have_timeout: bool,
) -> None:
    """Verify ``_stream`` forwards an instance-level timeout to the retry helper.

    Two cases: a configured instance timeout must appear as a ``timeout``
    kwarg on ``_chat_with_retry``; with no timeout configured, the kwarg
    must be absent.
    """
    # _chat_with_retry is expected to yield response chunks, so hand the
    # mock a single-chunk iterator.
    chunk = GenerateContentResponse(
        {
            "candidates": [
                {
                    "content": {"parts": [{"text": "chunk1"}]},
                    "finish_reason": "STOP",
                }
            ]
        }
    )
    mock_retry.return_value = iter([chunk])

    # Construct the model, adding the instance-level timeout only when this
    # parametrize case configures one.
    init_kwargs = {
        "model": "gemini-2.5-flash",
        "google_api_key": SecretStr("test-key"),
    }
    if instance_timeout is not None:
        init_kwargs["timeout"] = instance_timeout
    llm = ChatGoogleGenerativeAI(**init_kwargs)

    prompt: list[BaseMessage] = [HumanMessage(content="Hello")]
    # _stream is lazy; exhaust the generator so the mocked helper is invoked.
    for _ in llm._stream(prompt):
        pass

    # Exactly one retry-helper call, with (or without) the timeout kwarg.
    mock_retry.assert_called_once()
    forwarded = mock_retry.call_args.kwargs
    if should_have_timeout:
        assert "timeout" in forwarded
        assert forwarded["timeout"] == expected_timeout
    else:
        assert "timeout" not in forwarded
| 1049 | + |
@pytest.mark.parametrize(
    "is_async,mock_target,method_name",
    [
        (False, "_chat_with_retry", "_generate"),  # Sync
        (True, "_achat_with_retry", "_agenerate"),  # Async
    ],
)
@pytest.mark.parametrize(
    "instance_max_retries,call_max_retries,expected_max_retries",
    [
        (1, None, 1),  # Instance-level max_retries
        (3, 5, 5),  # Call-level overrides instance
        (6, None, 6),  # Default instance value
    ],
)
async def test_max_retries_parameter_handling(
    is_async: bool,
    mock_target: str,
    method_name: str,
    instance_max_retries: int,
    call_max_retries: Optional[int],
    expected_max_retries: int,
) -> None:
    """Test max_retries parameter handling for sync and async methods.

    Every parametrize case configures an instance-level ``max_retries``, so
    the retry helper must always receive a ``max_retries`` kwarg — the
    call-level value wins when one is supplied. (A former
    ``should_have_max_retries`` flag was ``True`` in every row, leaving its
    negative branch unreachable; it has been removed.)
    """
    with patch(f"langchain_google_genai.chat_models.{mock_target}") as mock_retry:
        mock_retry.return_value = GenerateContentResponse(
            {
                "candidates": [
                    {
                        "content": {"parts": [{"text": "Test response"}]},
                        "finish_reason": "STOP",
                    }
                ]
            }
        )

        # Instance-level max_retries is always set in these cases.
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash",
            google_api_key=SecretStr("test-key"),
            max_retries=instance_max_retries,
        )
        messages: list[BaseMessage] = [HumanMessage(content="Hello")]

        # Call the appropriate method with optional call-level max_retries.
        method = getattr(llm, method_name)
        call_kwargs: dict[str, int] = {}
        if call_max_retries is not None:
            call_kwargs["max_retries"] = call_max_retries

        if is_async:
            await method(messages, **call_kwargs)
        else:
            method(messages, **call_kwargs)

        # Verify max_retries reached the retry helper with the expected
        # (call-level-over-instance-level) value.
        mock_retry.assert_called_once()
        forwarded = mock_retry.call_args.kwargs
        assert "max_retries" in forwarded
        assert forwarded["max_retries"] == expected_max_retries
0 commit comments