Commit 4349e65

Mateusz committed
feat(gemini): implement auto-retry on 401 Unauthorized errors
- Add proactive token refresh and retry logic for 401 responses
- Support both streaming and non-streaming request retries
- Prevent infinite retry loops with _auth_retry_attempted flag
- Add comprehensive unit and behavior tests for auth retry scenarios
1 parent 11bdca6 commit 4349e65
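
In outline, the change wraps each Code Assist request in a retry-once guard: on a 401, refresh the token (with a timeout) and replay the request a single time, with the _auth_retry_attempted flag stopping a second 401 from looping. The sketch below is a minimal, generic illustration of that pattern, not the connector code itself; send_request and refresh_token are hypothetical stand-ins for the connector's request logic and _refresh_token_if_needed, and PermissionError stands in for the connector's AuthenticationError.

import asyncio


async def call_with_auth_retry(send_request, refresh_token, *, _auth_retry_attempted=False):
    """Run send_request, refreshing credentials and retrying once on a 401-style failure."""
    try:
        return await send_request()
    except PermissionError:  # stand-in for the connector's AuthenticationError
        if _auth_retry_attempted:
            raise  # already retried once; surface the error to the caller
        # Refresh with a timeout so a hung refresh cannot stall the request path.
        refreshed = await asyncio.wait_for(refresh_token(), timeout=30.0)
        if not refreshed:
            raise  # refresh did not produce a usable token
        return await call_with_auth_retry(
            send_request, refresh_token, _auth_retry_attempted=True
        )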

File tree

4 files changed, +536 -17 lines changed


src/connectors/gemini_base/connector.py

Lines changed: 85 additions & 0 deletions
@@ -1398,6 +1398,7 @@ async def _chat_completions_code_assist(
         processed_messages: list[Any],
         effective_model: str,
         _in_graceful_degradation: bool = False,
+        _auth_retry_attempted: bool = False,
         **kwargs: Any,
     ) -> ResponseEnvelope | StreamingResponseEnvelope:
         """Handle chat completions using the Code Assist API.
@@ -1486,6 +1487,44 @@ async def _chat_completions_code_assist(
             )

         except AuthenticationError as e:
+            # Handle 401 authentication errors with token refresh and retry
+            if not _auth_retry_attempted:
+                logger.info(
+                    "Received 401 Unauthorized in non-streaming request, attempting token refresh and retry..."
+                )
+                try:
+                    # Use 30s timeout for refresh, leaving room for retry request
+                    AUTH_RETRY_TIMEOUT = 30.0
+                    refreshed = await asyncio.wait_for(
+                        self._refresh_token_if_needed(),
+                        timeout=AUTH_RETRY_TIMEOUT,
+                    )
+                    if refreshed:
+                        logger.info(
+                            "Token refresh successful, retrying non-streaming request..."
+                        )
+                        return await self._chat_completions_code_assist(
+                            request_data=request_data,
+                            processed_messages=processed_messages,
+                            effective_model=effective_model,
+                            _in_graceful_degradation=_in_graceful_degradation,
+                            _auth_retry_attempted=True,  # Prevent infinite retry loops
+                            **kwargs,
+                        )
+                    else:
+                        logger.warning(
+                            "Token refresh failed; will raise 401 error to caller"
+                        )
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Token refresh timed out after {AUTH_RETRY_TIMEOUT}s; raising 401 to caller"
+                    )
+                except Exception as refresh_error:
+                    logger.error(
+                        f"Error during token refresh attempt: {refresh_error}",
+                        exc_info=True,
+                    )
+            # If we reach here, refresh failed or already retried - raise original error
             logger.error(f"Authentication error during API call: {e}", exc_info=True)
             raise
         except BackendError as e:
@@ -1561,6 +1600,7 @@ async def stream_generator(
             *,
             _allow_tool_retry: bool = True,
             without_tools: bool = False,
+            _auth_retry_attempted: bool = False,
         ) -> AsyncGenerator[ProcessedResponse, None]:
             import json

@@ -1727,9 +1767,54 @@ def _build_error_chunk(
                 code = "quota_exceeded"
             elif response.status_code == 429:
                 code = "rate_limit_exceeded"
+            elif response.status_code == 401:
+                code = "auth_error"
             elif isinstance(error_detail, str) and error_detail.strip():
                 error_message = error_detail

+            # Handle 401 authentication errors with token refresh and retry
+            if response.status_code == 401 and not _auth_retry_attempted:
+                logger.info(
+                    "Received 401 Unauthorized from backend, attempting token refresh and retry..."
+                )
+                with contextlib.suppress(Exception):
+                    response.close()
+
+                # Trigger proactive token refresh with timeout
+                # Use 30s timeout for refresh, leaving room for retry request
+                AUTH_RETRY_TIMEOUT = 30.0
+                try:
+                    refreshed = await asyncio.wait_for(
+                        self._refresh_token_if_needed(),
+                        timeout=AUTH_RETRY_TIMEOUT,
+                    )
+                    if refreshed:
+                        logger.info(
+                            "Token refresh successful, retrying streaming request..."
+                        )
+                        # Recursively call stream_generator with retry flag set
+                        async for retry_chunk in stream_generator(
+                            _allow_tool_retry=_allow_tool_retry,
+                            without_tools=without_tools,
+                            _auth_retry_attempted=True,  # Prevent infinite retry loops
+                        ):
+                            yield retry_chunk
+                        return  # Successfully handled via retry
+                    else:
+                        logger.warning(
+                            "Token refresh failed; will return 401 error to client"
+                        )
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Token refresh timed out after {AUTH_RETRY_TIMEOUT}s; returning 401 to client"
+                    )
+                except Exception as refresh_error:
+                    logger.error(
+                        f"Error during token refresh attempt: {refresh_error}",
+                        exc_info=True,
+                    )
+                # If we reach here, refresh failed - continue to raise error below
+
             # Attach retry-after hint when available
             retry_delay = None
             if response.status_code == 429:
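
The commit message mentions unit and behavior tests, but the test files are not shown in this excerpt. As a rough idea of the behavior being pinned down, here is a self-contained check of the retry-once pattern, reusing the generic call_with_auth_retry helper sketched above; it is an assumption-level illustration, not the repository's actual test code.

import asyncio


async def _exercise_retry_once():
    calls = {"request": 0, "refresh": 0}

    async def send_request():
        calls["request"] += 1
        if calls["request"] == 1:
            raise PermissionError("401 Unauthorized")  # first attempt fails
        return "ok"

    async def refresh_token():
        calls["refresh"] += 1
        return True  # pretend the refresh produced a fresh token

    result = await call_with_auth_retry(send_request, refresh_token)
    assert result == "ok"
    assert calls == {"request": 2, "refresh": 1}  # exactly one refresh, one retry


asyncio.run(_exercise_retry_once())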

src/connectors/gemini_base/graceful_degradation.py

Lines changed: 30 additions & 17 deletions
@@ -24,8 +24,8 @@
 logger = logging.getLogger(__name__)


-def is_rate_limit_like_error(error: BackendError) -> bool:
-    """Determine whether an error should trigger graceful degradation retries.
+def is_rate_limit_like_error(error: BackendError) -> bool:
+    """Determine whether an error should trigger graceful degradation retries.

     Args:
         error: The BackendError to check.
@@ -113,8 +113,8 @@ def calculate_retry_delay(
     return max(min_delay, base_delay + jitter)


-class GracefulDegradationManager:
-    """Manages graceful degradation state for a connector."""
+class GracefulDegradationManager:
+    """Manages graceful degradation state for a connector."""

     def __init__(
         self,
@@ -132,9 +132,9 @@ def __init__(
         self.model_retry_states: dict[str, ModelRetryState] = {}
         self.permanently_failed = False

-    def is_rate_limit_like_error(self, error: BackendError) -> bool:
-        """Determine whether an error should trigger graceful degradation."""
-        return is_rate_limit_like_error(error)
+    def is_rate_limit_like_error(self, error: BackendError) -> bool:
+        """Determine whether an error should trigger graceful degradation."""
+        return is_rate_limit_like_error(error)

     def is_in_cooldown(self, model: str) -> bool:
         """Check if a model is currently in cooldown."""
@@ -163,9 +163,22 @@ def get_or_create_state(self, model: str) -> ModelRetryState:
             self.model_retry_states[model] = ModelRetryState()
         return self.model_retry_states[model]

-    def get_models_to_try(self, original_model: str, disable_fallback: bool = False) -> list[str]:
-        """Return only the original model; fallbacks are handled upstream."""
-        return [original_model]
+    def get_models_to_try(
+        self, original_model: str, disable_fallback: bool = False
+    ) -> list[str]:
+        """Return only the original model; fallbacks are handled upstream.
+
+        Args:
+            original_model: The model to use.
+            disable_fallback: Reserved for API compatibility; fallbacks handled upstream.
+
+        Returns:
+            List containing only the original model.
+        """
+        # disable_fallback is intentionally unused - fallbacks are handled upstream
+        # by the connector layer. This parameter is kept for API compatibility.
+        _ = disable_fallback  # Explicitly acknowledge the parameter
+        return [original_model]

     def record_attempt(self) -> None:
         """Record an attempt in metrics."""
@@ -201,10 +214,10 @@ def get_metrics(self) -> dict[str, Any]:
         return self.metrics.as_dict()


-__all__ = [
-    "GracefulDegradationManager",
-    "calculate_retry_delay",
-    "is_model_in_cooldown",
-    "is_rate_limit_like_error",
-    "set_model_cooldown",
-]
+__all__ = [
+    "GracefulDegradationManager",
+    "calculate_retry_delay",
+    "is_model_in_cooldown",
+    "is_rate_limit_like_error",
+    "set_model_cooldown",
+]
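
For reference, the documented contract of the reworked get_models_to_try is that disable_fallback is accepted but ignored. A quick check of that contract, assuming the manager can be constructed with defaults (its real __init__ parameters are not shown in this hunk):

# Constructor arguments assumed for illustration; the real signature may differ.
manager = GracefulDegradationManager()
assert manager.get_models_to_try("gemini-pro") == ["gemini-pro"]
assert manager.get_models_to_try("gemini-pro", disable_fallback=True) == ["gemini-pro"]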
