Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/rotator_library/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,14 @@ async def _execute_with_retry(
f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
)

# Add retry callback to track additional requests (bare 429s, empty responses)
async def on_retry_attempt(m: str) -> None:
await self.usage_manager.increment_request_count(
current_cred, m
)

litellm_kwargs["on_retry_attempt"] = on_retry_attempt

response = await provider_plugin.acompletion(
self.http_client, **litellm_kwargs
)
Expand Down Expand Up @@ -2268,6 +2276,14 @@ async def _streaming_acompletion_with_retry(
f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
)

# Add retry callback to track additional requests (bare 429s, empty responses)
async def on_retry_attempt(m: str) -> None:
await self.usage_manager.increment_request_count(
current_cred, m
)

litellm_kwargs["on_retry_attempt"] = on_retry_attempt

response = await provider_plugin.acompletion(
self.http_client, **litellm_kwargs
)
Expand Down
16 changes: 16 additions & 0 deletions src/rotator_library/providers/antigravity_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Optional,
Expand Down Expand Up @@ -3924,6 +3926,7 @@ async def acompletion(
temperature = kwargs.get("temperature")
max_tokens = kwargs.get("max_tokens")
transaction_context = kwargs.pop("transaction_context", None)
on_retry_attempt = kwargs.pop("on_retry_attempt", None)

# Create provider logger from transaction context
file_logger = AntigravityProviderLogger(transaction_context)
Expand Down Expand Up @@ -4136,6 +4139,7 @@ async def acompletion(
max_tokens,
reasoning_effort,
tool_choice,
on_retry_attempt,
)

if stream:
Expand Down Expand Up @@ -4527,6 +4531,7 @@ async def _streaming_with_retry(
max_tokens: Optional[int] = None,
reasoning_effort: Optional[str] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
on_retry_attempt: Optional[Callable[[str], Awaitable[None]]] = None,
) -> AsyncGenerator[litellm.ModelResponse, None]:
"""
Wrapper around _handle_streaming that retries on empty responses, bare 429s,
Expand All @@ -4538,6 +4543,10 @@ async def _streaming_with_retry(

If MALFORMED_FUNCTION_CALL is detected, inject corrective messages and retry
up to MALFORMED_CALL_MAX_RETRIES times.

Args:
on_retry_attempt: Optional callback invoked on retry attempts (attempt > 0).
Used to track additional requests that consume quota.
"""
empty_error_msg = (
"The model returned an empty response after multiple attempts. "
Expand All @@ -4554,6 +4563,13 @@ async def _streaming_with_retry(
current_payload = payload

for attempt in range(EMPTY_RESPONSE_MAX_ATTEMPTS):
            # Track retry attempts (not the first attempt — that one is counted by record_success/failure)
if attempt > 0 and on_retry_attempt:
try:
await on_retry_attempt(model)
except Exception as e:
lib_logger.warning(f"Retry attempt callback failed: {e}")

chunk_count = 0

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -987,13 +987,15 @@ async def _store_baselines_to_usage_manager(
self,
quota_results: Dict[str, Dict[str, Any]],
usage_manager: "UsageManager",
sync_mode: str = "force",
) -> int:
"""
Store fetched quota baselines into UsageManager.

Args:
quota_results: Dict from fetch_quota_from_api or fetch_initial_baselines
usage_manager: UsageManager instance to store baselines in
sync_mode: How to sync request_count ("force", "if_exhausted", "none")

Returns:
Number of baselines successfully stored
Expand Down Expand Up @@ -1052,7 +1054,12 @@ async def _store_baselines_to_usage_manager(
# Store with provider prefix for consistency with usage tracking
prefixed_model = f"antigravity/{user_model}"
cooldown_info = await usage_manager.update_quota_baseline(
cred_path, prefixed_model, remaining, max_requests, reset_timestamp
cred_path,
prefixed_model,
remaining,
max_requests,
reset_timestamp,
sync_mode=sync_mode,
)

# Aggregate cooldown info if returned
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,13 +507,15 @@ async def _store_baselines_to_usage_manager(
self,
quota_results: Dict[str, Dict[str, Any]],
usage_manager: "UsageManager",
sync_mode: str = "force",
) -> int:
"""
Store fetched quota baselines into UsageManager.

Args:
quota_results: Dict from _fetch_quota_for_credential or fetch_initial_baselines
usage_manager: UsageManager instance to store baselines in
sync_mode: How to sync request_count ("force", "if_exhausted", "none")

Returns:
Number of baselines successfully stored
Expand Down Expand Up @@ -541,7 +543,11 @@ async def _store_baselines_to_usage_manager(

# Store baseline
await usage_manager.update_quota_baseline(
cred_path, prefixed_model, remaining, max_requests=max_requests
cred_path,
prefixed_model,
remaining,
max_requests=max_requests,
sync_mode=sync_mode,
)
stored_count += 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ async def run_background_job(
Refresh quota baselines for credentials.

On first run (startup): Fetches quota for ALL credentials to establish baselines.
On subsequent runs: Only fetches for credentials used since last refresh.
On subsequent runs: Behavior depends on provider:
- Antigravity: Disabled (local counting is authoritative)
- Others: Only fetches for credentials used since last refresh.

Handles both file paths and env:// credential formats.

Expand All @@ -257,6 +259,7 @@ async def run_background_job(
return

provider_name = getattr(self, "provider_env_name", "Provider")
is_antigravity = provider_name.lower() == "antigravity"

if not self._initial_quota_fetch_done:
# First run: fetch ALL credentials to establish baselines
Expand All @@ -265,24 +268,46 @@ async def run_background_job(
)
quota_results = await self.fetch_initial_baselines(credentials)
self._initial_quota_fetch_done = True

if not quota_results:
return

# For Antigravity: use "if_exhausted" to pick up state from other instances
# For others: use "force" (backwards-compatible default)
sync_mode = "if_exhausted" if is_antigravity else "force"
stored = await self._store_baselines_to_usage_manager(
quota_results, usage_manager, sync_mode=sync_mode
)
if stored > 0:
lib_logger.debug(
f"{provider_name} initial quota fetch: updated {stored} model baselines"
)
else:
# Subsequent runs: only recently used credentials (incremental updates)
# Subsequent runs
if is_antigravity:
# Antigravity: Background refresh disabled (local counting is authoritative)
lib_logger.debug(
f"{provider_name}: Background quota refresh disabled (local counting is authoritative)"
)
return

# Other providers: refresh recently used credentials
usage_data = await usage_manager._get_usage_data_snapshot()
quota_results = await self.refresh_active_quota_baselines(
credentials, usage_data
)

if not quota_results:
return
if not quota_results:
return

# Store new baselines in UsageManager
stored = await self._store_baselines_to_usage_manager(
quota_results, usage_manager
)
if stored > 0:
lib_logger.debug(
f"{provider_name} quota refresh: updated {stored} model baselines"
# Store new baselines in UsageManager
stored = await self._store_baselines_to_usage_manager(
quota_results, usage_manager
)
if stored > 0:
lib_logger.debug(
f"{provider_name} quota refresh: updated {stored} model baselines"
)

# =========================================================================
# ABSTRACT METHODS - Must be implemented by providers
Expand Down
Loading