diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py index a0ec4dfa..1ae5e3ad 100644 --- a/src/rotator_library/client.py +++ b/src/rotator_library/client.py @@ -1495,6 +1495,14 @@ async def _execute_with_retry( f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" ) + # Add retry callback to track additional requests (bare 429s, empty responses) + async def on_retry_attempt(m: str) -> None: + await self.usage_manager.increment_request_count( + current_cred, m + ) + + litellm_kwargs["on_retry_attempt"] = on_retry_attempt + response = await provider_plugin.acompletion( self.http_client, **litellm_kwargs ) @@ -2268,6 +2276,14 @@ async def _streaming_acompletion_with_retry( f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" ) + # Add retry callback to track additional requests (bare 429s, empty responses) + async def on_retry_attempt(m: str) -> None: + await self.usage_manager.increment_request_count( + current_cred, m + ) + + litellm_kwargs["on_retry_attempt"] = on_retry_attempt + response = await provider_plugin.acompletion( self.http_client, **litellm_kwargs ) diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index b8a02fc9..0b1073c4 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -31,6 +31,8 @@ from typing import ( Any, AsyncGenerator, + Awaitable, + Callable, Dict, List, Optional, @@ -3924,6 +3926,7 @@ async def acompletion( temperature = kwargs.get("temperature") max_tokens = kwargs.get("max_tokens") transaction_context = kwargs.pop("transaction_context", None) + on_retry_attempt = kwargs.pop("on_retry_attempt", None) # Create provider logger from transaction context file_logger = AntigravityProviderLogger(transaction_context) @@ -4136,6 +4139,7 @@ async def acompletion( max_tokens, reasoning_effort, tool_choice, + on_retry_attempt, ) if stream: @@ -4527,6 +4531,7 @@ async def _streaming_with_retry( max_tokens: Optional[int] = None, reasoning_effort: Optional[str] = None, tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + on_retry_attempt: Optional[Callable[[str], Awaitable[None]]] = None, ) -> AsyncGenerator[litellm.ModelResponse, None]: """ Wrapper around _handle_streaming that retries on empty responses, bare 429s, @@ -4538,6 +4543,10 @@ async def _streaming_with_retry( If MALFORMED_FUNCTION_CALL is detected, inject corrective messages and retry up to MALFORMED_CALL_MAX_RETRIES times. + + Args: + on_retry_attempt: Optional callback invoked on retry attempts (attempt > 0). + Used to track additional requests that consume quota. """ empty_error_msg = ( "The model returned an empty response after multiple attempts. " @@ -4554,6 +4563,13 @@ async def _streaming_with_retry( current_payload = payload for attempt in range(EMPTY_RESPONSE_MAX_ATTEMPTS): + # Track retry attempts (not first attempt - that's counted by record_success/failure) + if attempt > 0 and on_retry_attempt: + try: + await on_retry_attempt(model) + except Exception as e: + lib_logger.warning(f"Retry attempt callback failed: {e}") + chunk_count = 0 try: diff --git a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py index 0352d07b..d732708b 100644 --- a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py +++ b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py @@ -987,6 +987,7 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + sync_mode: str = "force", ) -> int: """ Store fetched quota baselines into UsageManager. @@ -994,6 +995,7 @@ async def _store_baselines_to_usage_manager( Args: quota_results: Dict from fetch_quota_from_api or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + sync_mode: How to sync request_count ("force", "if_exhausted", "none") Returns: Number of baselines successfully stored @@ -1052,7 +1054,12 @@ async def _store_baselines_to_usage_manager( # Store with provider prefix for consistency with usage tracking prefixed_model = f"antigravity/{user_model}" cooldown_info = await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests, reset_timestamp + cred_path, + prefixed_model, + remaining, + max_requests, + reset_timestamp, + sync_mode=sync_mode, ) # Aggregate cooldown info if returned diff --git a/src/rotator_library/providers/utilities/base_quota_tracker.py b/src/rotator_library/providers/utilities/base_quota_tracker.py index 6fbacc4c..b91e3c39 100644 --- a/src/rotator_library/providers/utilities/base_quota_tracker.py +++ b/src/rotator_library/providers/utilities/base_quota_tracker.py @@ -507,6 +507,7 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + sync_mode: str = "force", ) -> int: """ Store fetched quota baselines into UsageManager. @@ -514,6 +515,7 @@ async def _store_baselines_to_usage_manager( Args: quota_results: Dict from _fetch_quota_for_credential or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + sync_mode: How to sync request_count ("force", "if_exhausted", "none") Returns: Number of baselines successfully stored @@ -541,7 +543,11 @@ async def _store_baselines_to_usage_manager( # Store baseline await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests=max_requests + cred_path, + prefixed_model, + remaining, + max_requests=max_requests, + sync_mode=sync_mode, ) stored_count += 1 diff --git a/src/rotator_library/providers/utilities/gemini_credential_manager.py b/src/rotator_library/providers/utilities/gemini_credential_manager.py index 83d02ffc..37ade25a 100644 --- a/src/rotator_library/providers/utilities/gemini_credential_manager.py +++ b/src/rotator_library/providers/utilities/gemini_credential_manager.py @@ -245,7 +245,9 @@ async def run_background_job( Refresh quota baselines for credentials. On first run (startup): Fetches quota for ALL credentials to establish baselines. - On subsequent runs: Only fetches for credentials used since last refresh. + On subsequent runs: Behavior depends on provider: + - Antigravity: Disabled (local counting is authoritative) + - Others: Only fetches for credentials used since last refresh. Handles both file paths and env:// credential formats. @@ -257,6 +259,7 @@ async def run_background_job( return provider_name = getattr(self, "provider_env_name", "Provider") + is_antigravity = provider_name.lower() == "antigravity" if not self._initial_quota_fetch_done: # First run: fetch ALL credentials to establish baselines @@ -265,24 +268,46 @@ async def run_background_job( ) quota_results = await self.fetch_initial_baselines(credentials) self._initial_quota_fetch_done = True + + if not quota_results: + return + + # For Antigravity: use "if_exhausted" to pick up state from other instances + # For others: use "force" (backwards-compatible default) + sync_mode = "if_exhausted" if is_antigravity else "force" + stored = await self._store_baselines_to_usage_manager( + quota_results, usage_manager, sync_mode=sync_mode + ) + if stored > 0: + lib_logger.debug( + f"{provider_name} initial quota fetch: updated {stored} model baselines" + ) else: - # Subsequent runs: only recently used credentials (incremental updates) + # Subsequent runs + if is_antigravity: + # Antigravity: Background refresh disabled (local counting is authoritative) + lib_logger.debug( + f"{provider_name}: Background quota refresh disabled (local counting is authoritative)" + ) + return + + # Other providers: refresh recently used credentials usage_data = await usage_manager._get_usage_data_snapshot() quota_results = await self.refresh_active_quota_baselines( credentials, usage_data ) - if not quota_results: - return + if not quota_results: + return - # Store new baselines in UsageManager - stored = await self._store_baselines_to_usage_manager( - quota_results, usage_manager - ) - if stored > 0: - lib_logger.debug( - f"{provider_name} quota refresh: updated {stored} model baselines" + # Store new baselines in UsageManager + stored = await self._store_baselines_to_usage_manager( + quota_results, usage_manager ) + if stored > 0: + lib_logger.debug( + f"{provider_name} quota refresh: updated {stored} model baselines" + ) # ========================================================================= # ABSTRACT METHODS - Must be implemented by providers diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index dd2bd480..902d388a 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -2843,6 +2843,37 @@ async def record_success( await self._save_usage() + async def increment_request_count(self, key: str, model: str) -> None: + """ + Increment request_count for a model. Used for tracking retry attempts + that don't go through record_success/record_failure (e.g., bare 429 retries). + """ + await self._lazy_init() + model = self._normalize_model(key, model) + + async with self._data_lock: + key_data = self._usage_data.get(key) + if not key_data or "models" not in key_data: + return + + model_data = key_data["models"].get(model) + if not model_data: + return + + model_data["request_count"] = model_data.get("request_count", 0) + 1 + + # Sync across quota group + group = self._get_model_quota_group(key, model) + if group: + new_count = model_data["request_count"] + for grouped_model in self._get_grouped_models(key, group): + if grouped_model != model: + other = key_data["models"].get(grouped_model) + if other: + other["request_count"] = new_count + + await self._save_usage() + async def record_failure( self, key: str, @@ -2953,6 +2984,14 @@ async def record_failure( model_data["quota_display"] = f"{max_req}/{max_req}" new_request_count = model_data["request_count"] + # Track measured max requests (highest count before exhaustion) + measured_max = model_data.get("measured_max_requests") + if measured_max is None or new_request_count > measured_max: + model_data["measured_max_requests"] = new_request_count + lib_logger.info( + f"New measured max for {model}: {new_request_count} requests" + ) + # Apply to all models in the same quota group group = self._get_model_quota_group(key, model) if group: @@ -2975,6 +3014,11 @@ async def record_failure( group_model_data["quota_reset_ts"] = quota_reset_ts # Sync request_count across quota group group_model_data["request_count"] = new_request_count + # Sync measured_max_requests across quota group + if model_data.get("measured_max_requests"): + group_model_data["measured_max_requests"] = model_data[ + "measured_max_requests" + ] # Also sync quota_max_requests if set max_req = model_data.get("quota_max_requests") if max_req: @@ -3152,6 +3196,7 @@ async def update_quota_baseline( remaining_fraction: float, max_requests: Optional[int] = None, reset_timestamp: Optional[float] = None, + sync_mode: str = "force", ) -> Optional[Dict[str, Any]]: """ Update quota baseline data for a credential/model after fetching from API. @@ -3170,6 +3215,10 @@ async def update_quota_baseline( reset_timestamp: Unix timestamp when quota resets. Only trusted when remaining_fraction < 1.0 (quota has been used). API returns garbage reset times for unused quota (100%). + sync_mode: How to sync request_count from API: + - "force": Always overwrite with API value (default, backwards-compatible) + - "if_exhausted": Use max() but overwrite if exhausted (first refresh) + - "none": Don't touch request_count (local counting is authoritative) Returns: None if no cooldown was set/updated, otherwise: @@ -3243,9 +3292,24 @@ async def update_quota_baseline( used_requests = model_data.get("request_count", 0) max_requests = model_data.get("quota_max_requests") - # Sync local request count to API's authoritative value - model_data["request_count"] = used_requests - model_data["requests_at_baseline"] = used_requests + # Sync request_count based on sync_mode + current_count = model_data.get("request_count", 0) + if sync_mode == "force": + # Force refresh: always overwrite with API value + synced_count = used_requests + elif sync_mode == "if_exhausted": + if remaining_fraction <= 0.0: + # Exhausted: accept API value (pick up state from other instances) + synced_count = used_requests + else: + # Not exhausted: use max() to prevent downward reset + synced_count = max(current_count, used_requests) + else: # sync_mode == "none" + # Don't touch request_count (local counting is authoritative) + synced_count = current_count + + model_data["request_count"] = synced_count + model_data["requests_at_baseline"] = synced_count # Update baseline fields model_data["baseline_remaining_fraction"] = remaining_fraction @@ -3254,7 +3318,7 @@ async def update_quota_baseline( # Update max_requests and quota_display if max_requests is not None: model_data["quota_max_requests"] = max_requests - model_data["quota_display"] = f"{used_requests}/{max_requests}" + model_data["quota_display"] = f"{synced_count}/{max_requests}" # Handle reset_timestamp: only trust it when quota has been used (< 100%) # API returns garbage reset times for unused quota @@ -3270,11 +3334,13 @@ async def update_quota_baseline( # Set cooldowns when quota is exhausted model_cooldowns = key_data.setdefault("model_cooldowns", {}) is_exhausted = remaining_fraction <= 0.0 + # Only mark exhausted from API if sync_mode allows it + can_mark_exhausted = sync_mode in ("force", "if_exhausted") cooldown_set_info = ( None # Will be returned if cooldown was newly set/updated ) - if is_exhausted and valid_reset_ts: + if is_exhausted and valid_reset_ts and can_mark_exhausted: # Only update cooldown if not set or differs by more than 5 minutes existing_cooldown = model_cooldowns.get(model) should_update = ( @@ -3339,19 +3405,19 @@ async def update_quota_baseline( "approx_cost": 0.0, }, ) - # Sync request tracking - other_model_data["request_count"] = used_requests + # Sync request tracking (use synced_count for consistency) + other_model_data["request_count"] = synced_count if max_requests is not None: other_model_data["quota_max_requests"] = max_requests other_model_data["quota_display"] = ( - f"{used_requests}/{max_requests}" + f"{synced_count}/{max_requests}" ) # Sync baseline fields other_model_data["baseline_remaining_fraction"] = ( remaining_fraction ) other_model_data["baseline_fetched_at"] = now_ts - other_model_data["requests_at_baseline"] = used_requests + other_model_data["requests_at_baseline"] = synced_count # Sync reset timestamp if valid if valid_reset_ts: other_model_data["quota_reset_ts"] = reset_timestamp @@ -3360,7 +3426,7 @@ async def update_quota_baseline( if window_start: other_model_data["window_start_ts"] = window_start # Sync cooldown if exhausted (with ±5 min check) - if is_exhausted and valid_reset_ts: + if is_exhausted and valid_reset_ts and can_mark_exhausted: existing_grouped = model_cooldowns.get(grouped_model) should_update_grouped = ( existing_grouped is None @@ -3381,7 +3447,7 @@ async def update_quota_baseline( lib_logger.debug( f"Updated quota baseline for {mask_credential(credential)} model={model}: " - f"remaining={remaining_fraction:.2%}, synced_request_count={used_requests}" + f"remaining={remaining_fraction:.2%}, synced_request_count={synced_count}, sync_mode={sync_mode}" ) await self._save_usage()