|
20 | 20 | from .config import COLORS, AVAILABLE_MODELS, DEFAULT_MODEL, get_system_prompt, ModelProvider |
21 | 21 | from .clients import ClientManager, StreamChunk, ToolCall, ExecutedTool |
22 | 22 | from .lib.prompts import mode_manager |
23 | | -from .api_key_manager import api_key_manager, is_rate_limit_error, is_credit_error |
| 23 | +from .api_key_manager import ( |
| 24 | + api_key_manager, is_rate_limit_error, is_credit_error, |
| 25 | + set_rotation_callbacks, model_fallback_manager |
| 26 | +) |
24 | 27 | from .tools import TOOL_DEFINITIONS, execute_tool, TOOLS, get_all_tool_definitions |
25 | | -from .ui import console, display_tool_call, display_tool_result, display_executed_tool, display_code_execution_result, display_info, display_warning |
| 28 | +from .ui import ( |
| 29 | + console, display_tool_call, display_tool_result, display_executed_tool, |
| 30 | + display_code_execution_result, display_info, display_warning, |
| 31 | + display_key_rotation_notice, display_model_fallback_notice, display_provider_exhausted_notice |
| 32 | +) |
26 | 33 | from .logger import log_error, log_api_error, log_tool_error, log_debug |
27 | 34 | from .history import history_manager |
28 | 35 | from .name_detector import detect_and_save_name |
@@ -202,11 +209,30 @@ def __init__(self, model_key: str = DEFAULT_MODEL): |
202 | 209 | self._init_system_prompt() |
203 | 210 | # Start a new conversation |
204 | 211 | history_manager.start_new_conversation() |
| 212 | + # Set up rotation and fallback callbacks for user notifications |
| 213 | + self._setup_rotation_callbacks() |
205 | 214 |
|
206 | 215 | def set_status_callback(self, callback: StatusCallback): |
207 | 216 | """Set a callback function for status updates""" |
208 | 217 | self._status_callback = callback |
209 | 218 |
|
| 219 | + def _setup_rotation_callbacks(self): |
| 220 | + """Set up callbacks for API key rotation and model fallback notifications""" |
| 221 | + def on_key_rotated(provider: str, old_key: str, new_key: str): |
| 222 | + display_key_rotation_notice(provider, "rate limit or quota exceeded") |
| 223 | + |
| 224 | + def on_model_fallback(provider: str, old_model: str, new_model: str): |
| 225 | + display_model_fallback_notice(provider, old_model, new_model) |
| 226 | + |
| 227 | + def on_provider_exhausted(provider: str): |
| 228 | + display_provider_exhausted_notice(provider) |
| 229 | + |
| 230 | + set_rotation_callbacks( |
| 231 | + on_key_rotated=on_key_rotated, |
| 232 | + on_model_fallback=on_model_fallback, |
| 233 | + on_provider_exhausted=on_provider_exhausted |
| 234 | + ) |
| 235 | + |
210 | 236 | def _update_status(self, status: str, detail: str = ""): |
211 | 237 | """Send a status update if callback is set""" |
212 | 238 | if self._status_callback: |
@@ -815,11 +841,31 @@ def chat(self, user_input: str, _retry_count: int = 0) -> str: |
815 | 841 | self.messages.append({"role": "user", "content": user_input}) |
816 | 842 | return self.chat(user_input, _retry_count=_retry_count + 1) |
817 | 843 |
|
818 | | - # Check if this is a quota/rate limit error - try to switch provider |
| 844 | + # Check if this is a quota/rate limit error - try model fallback first, then provider switch |
819 | 845 | current_provider = AVAILABLE_MODELS[self.model_key].provider.value |
820 | 846 | if is_quota_or_rate_error(error_str): |
821 | 847 | log_debug(f"Quota/rate error detected for {current_provider}") |
822 | 848 |
|
| 849 | + # First, try to fallback to a simpler model within the same provider (if enabled) |
| 850 | + if model_fallback_manager.is_enabled() and _retry_count < 2: |
| 851 | + fallback_model_id = model_fallback_manager.get_fallback_model(current_provider, model_id) |
| 852 | + if fallback_model_id: |
| 853 | + # Find the model key for this fallback model |
| 854 | + for key, config in AVAILABLE_MODELS.items(): |
| 855 | + if config.id == fallback_model_id and config.provider.value == current_provider: |
| 856 | + # Activate the fallback with notification |
| 857 | + model_fallback_manager.activate_fallback( |
| 858 | + current_provider, |
| 859 | + model_id, |
| 860 | + fallback_model_id, |
| 861 | + duration_minutes=5 |
| 862 | + ) |
| 863 | + old_model_key = self.model_key |
| 864 | + self.model_key = key |
| 865 | + log_debug(f"Model fallback: {model_id} -> {fallback_model_id}") |
| 866 | + # Retry with simpler model |
| 867 | + return self.chat(user_input, _retry_count=_retry_count + 1) |
| 868 | + |
823 | 869 | # Show friendly message |
824 | 870 | friendly_msg = get_friendly_quota_message(current_provider) |
825 | 871 | console.print() |
|
0 commit comments