Skip to content

Commit 5f9a80b

Browse files
authored
feat(genai): add timeout and max_retries handling in chat methods (#1180)
Fixes #731. Ensures instance-level (model-defined) `timeout` and `max_retries` params are used if none are provided in the invocation.
1 parent 8add7fd commit 5f9a80b

File tree

2 files changed

+222
-2
lines changed

2 files changed

+222
-2
lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,10 @@ def _chat_with_retry(**kwargs: Any) -> Any:
220220
raise ChatGoogleGenerativeAIError(msg) from e
221221
except ResourceExhausted as e:
222222
# Handle quota-exceeded error with recommended retry delay
223-
if hasattr(e, "retry_after") and e.retry_after < kwargs.get(
223+
if hasattr(e, "retry_after") and getattr(e, "retry_after", 0) < kwargs.get(
224224
"wait_exponential_max", 60.0
225225
):
226-
time.sleep(e.retry_after)
226+
time.sleep(getattr(e, "retry_after"))
227227
raise
228228
except Exception:
229229
raise
@@ -267,6 +267,13 @@ async def _achat_with_retry(**kwargs: Any) -> Any:
267267
# Do not retry for these errors.
268268
msg = f"Invalid argument provided to Gemini: {e}"
269269
raise ChatGoogleGenerativeAIError(msg) from e
270+
except ResourceExhausted as e:
271+
# Handle quota-exceeded error with recommended retry delay
272+
if hasattr(e, "retry_after") and getattr(e, "retry_after", 0) < kwargs.get(
273+
"wait_exponential_max", 60.0
274+
):
275+
time.sleep(getattr(e, "retry_after"))
276+
raise
270277
except Exception:
271278
raise
272279

@@ -1776,6 +1783,10 @@ def _generate(
17761783
tool_choice=tool_choice,
17771784
**kwargs,
17781785
)
1786+
if self.timeout is not None and "timeout" not in kwargs:
1787+
kwargs["timeout"] = self.timeout
1788+
if "max_retries" not in kwargs:
1789+
kwargs["max_retries"] = self.max_retries
17791790
response: GenerateContentResponse = _chat_with_retry(
17801791
request=request,
17811792
**kwargs,
@@ -1824,6 +1835,10 @@ async def _agenerate(
18241835
tool_choice=tool_choice,
18251836
**kwargs,
18261837
)
1838+
if self.timeout is not None and "timeout" not in kwargs:
1839+
kwargs["timeout"] = self.timeout
1840+
if "max_retries" not in kwargs:
1841+
kwargs["max_retries"] = self.max_retries
18271842
response: GenerateContentResponse = await _achat_with_retry(
18281843
request=request,
18291844
**kwargs,
@@ -1859,6 +1874,10 @@ def _stream(
18591874
tool_choice=tool_choice,
18601875
**kwargs,
18611876
)
1877+
if self.timeout is not None and "timeout" not in kwargs:
1878+
kwargs["timeout"] = self.timeout
1879+
if "max_retries" not in kwargs:
1880+
kwargs["max_retries"] = self.max_retries
18621881
response: GenerateContentResponse = _chat_with_retry(
18631882
request=request,
18641883
generation_method=self.client.stream_generate_content,
@@ -1925,6 +1944,10 @@ async def _astream(
19251944
tool_choice=tool_choice,
19261945
**kwargs,
19271946
)
1947+
if self.timeout is not None and "timeout" not in kwargs:
1948+
kwargs["timeout"] = self.timeout
1949+
if "max_retries" not in kwargs:
1950+
kwargs["max_retries"] = self.max_retries
19281951
prev_usage_metadata: UsageMetadata | None = None # cumulative usage
19291952
async for chunk in await _achat_with_retry(
19301953
request=request,

libs/genai/tests/unit_tests/test_chat_models.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import base64
55
import json
66
import warnings
7+
from collections.abc import Iterator
78
from concurrent.futures import ThreadPoolExecutor
89
from typing import Optional, Union
910
from unittest.mock import ANY, Mock, patch
@@ -20,6 +21,7 @@
2021
from langchain_core.load import dumps, loads
2122
from langchain_core.messages import (
2223
AIMessage,
24+
BaseMessage,
2325
FunctionMessage,
2426
HumanMessage,
2527
SystemMessage,
@@ -917,3 +919,198 @@ def test_response_to_result_grounding_metadata(
917919
else {}
918920
)
919921
assert grounding_metadata == expected_grounding_metadata
922+
923+
924+
@pytest.mark.parametrize(
    "is_async,mock_target,method_name",
    [
        (False, "_chat_with_retry", "_generate"),  # Sync
        (True, "_achat_with_retry", "_agenerate"),  # Async
    ],
)
@pytest.mark.parametrize(
    "instance_timeout,call_timeout,expected_timeout,should_have_timeout",
    [
        (5.0, None, 5.0, True),  # Instance-level timeout
        (5.0, 10.0, 10.0, True),  # Call-level overrides instance
        (None, None, None, False),  # No timeout anywhere
    ],
)
async def test_timeout_parameter_handling(
    is_async: bool,
    mock_target: str,
    method_name: str,
    instance_timeout: Optional[float],
    call_timeout: Optional[float],
    expected_timeout: Optional[float],
    should_have_timeout: bool,
) -> None:
    """Verify `timeout` propagation from model/call into the retry helper.

    Covers both `_generate` and `_agenerate`: the call-level `timeout`
    must win over the instance-level one, and no `timeout` kwarg is
    forwarded when neither is set.
    """
    with patch(f"langchain_google_genai.chat_models.{mock_target}") as retry_mock:
        retry_mock.return_value = GenerateContentResponse(
            {
                "candidates": [
                    {
                        "content": {"parts": [{"text": "Test response"}]},
                        "finish_reason": "STOP",
                    }
                ]
            }
        )

        # Build the model, attaching an instance-level timeout only when given.
        init_kwargs = {
            "model": "gemini-2.5-flash",
            "google_api_key": SecretStr("test-key"),
        }
        if instance_timeout is not None:
            init_kwargs["timeout"] = instance_timeout
        model = ChatGoogleGenerativeAI(**init_kwargs)

        prompt: list[BaseMessage] = [HumanMessage(content="Hello")]

        # Invoke the generation method, forwarding a call-level timeout only
        # when one is supplied for this parametrization.
        invoke_kwargs = {} if call_timeout is None else {"timeout": call_timeout}
        target = getattr(model, method_name)
        if is_async:
            await target(prompt, **invoke_kwargs)
        else:
            target(prompt, **invoke_kwargs)

        # The retry helper must receive exactly the expected timeout (or none).
        retry_mock.assert_called_once()
        forwarded = retry_mock.call_args[1]
        if should_have_timeout:
            assert "timeout" in forwarded
            assert forwarded["timeout"] == expected_timeout
        else:
            assert "timeout" not in forwarded
992+
993+
994+
@pytest.mark.parametrize(
    "instance_timeout,expected_timeout,should_have_timeout",
    [
        (5.0, 5.0, True),  # Instance-level timeout
        (None, None, False),  # No timeout
    ],
)
@patch("langchain_google_genai.chat_models._chat_with_retry")
def test_timeout_streaming_parameter_handling(
    mock_retry: Mock,
    instance_timeout: Optional[float],
    expected_timeout: Optional[float],
    should_have_timeout: bool,
) -> None:
    """Verify `_stream` forwards the instance-level `timeout` to `_chat_with_retry`."""

    def fake_stream() -> Iterator[GenerateContentResponse]:
        # One single-chunk response is enough to drive `_stream` to completion.
        yield GenerateContentResponse(
            {
                "candidates": [
                    {
                        "content": {"parts": [{"text": "chunk1"}]},
                        "finish_reason": "STOP",
                    }
                ]
            }
        )

    mock_retry.return_value = fake_stream()

    # Build the model, attaching an instance-level timeout only when given.
    init_kwargs = {
        "model": "gemini-2.5-flash",
        "google_api_key": SecretStr("test-key"),
    }
    if instance_timeout is not None:
        init_kwargs["timeout"] = instance_timeout
    model = ChatGoogleGenerativeAI(**init_kwargs)

    # Drain the generator so the patched retry helper is actually invoked.
    prompt: list[BaseMessage] = [HumanMessage(content="Hello")]
    list(model._stream(prompt))

    # The retry helper must receive exactly the expected timeout (or none).
    mock_retry.assert_called_once()
    forwarded = mock_retry.call_args[1]
    if should_have_timeout:
        assert "timeout" in forwarded
        assert forwarded["timeout"] == expected_timeout
    else:
        assert "timeout" not in forwarded
1048+
1049+
1050+
@pytest.mark.parametrize(
    "is_async,mock_target,method_name",
    [
        (False, "_chat_with_retry", "_generate"),  # Sync
        (True, "_achat_with_retry", "_agenerate"),  # Async
    ],
)
@pytest.mark.parametrize(
    "instance_max_retries,call_max_retries,expected_max_retries,should_have_max_retries",
    [
        (1, None, 1, True),  # Instance-level max_retries
        (3, 5, 5, True),  # Call-level overrides instance
        (6, None, 6, True),  # Default instance value
    ],
)
async def test_max_retries_parameter_handling(
    is_async: bool,
    mock_target: str,
    method_name: str,
    instance_max_retries: int,
    call_max_retries: Optional[int],
    expected_max_retries: int,
    should_have_max_retries: bool,
) -> None:
    """Verify `max_retries` propagation from model/call into the retry helper.

    Covers both `_generate` and `_agenerate`: a call-level `max_retries`
    must win over the instance-level value, which is otherwise forwarded.
    """
    with patch(f"langchain_google_genai.chat_models.{mock_target}") as retry_mock:
        retry_mock.return_value = GenerateContentResponse(
            {
                "candidates": [
                    {
                        "content": {"parts": [{"text": "Test response"}]},
                        "finish_reason": "STOP",
                    }
                ]
            }
        )

        # Instance-level max_retries is always set for this test.
        model = ChatGoogleGenerativeAI(
            **{
                "model": "gemini-2.5-flash",
                "google_api_key": SecretStr("test-key"),
                "max_retries": instance_max_retries,
            }
        )

        prompt: list[BaseMessage] = [HumanMessage(content="Hello")]

        # Invoke the generation method, forwarding a call-level max_retries
        # only when one is supplied for this parametrization.
        invoke_kwargs = (
            {} if call_max_retries is None else {"max_retries": call_max_retries}
        )
        target = getattr(model, method_name)
        if is_async:
            await target(prompt, **invoke_kwargs)
        else:
            target(prompt, **invoke_kwargs)

        # The retry helper must receive exactly the expected max_retries.
        retry_mock.assert_called_once()
        forwarded = retry_mock.call_args[1]
        if should_have_max_retries:
            assert "max_retries" in forwarded
            assert forwarded["max_retries"] == expected_max_retries
        else:
            assert "max_retries" not in forwarded

0 commit comments

Comments
 (0)