diff --git a/.env.example b/.env.example index ce2c0cc8..26ffda48 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,10 @@ -# Required: Your OpenAI API key +# Required: Your OpenAI API key(s) +# Single key: OPENAI_API_KEY="sk-your-openai-api-key-here" +# Multiple keys (comma-separated for load balancing and failover): +# OPENAI_API_KEY="sk-key1,sk-key2,sk-key3" + # Optional: Expected Anthropic API key for client validation # If set, clients must provide this exact API key to access the proxy ANTHROPIC_API_KEY="your-expected-anthropic-api-key" @@ -34,6 +38,7 @@ MAX_RETRIES="2" # For Azure OpenAI (recommended if OpenAI is not available in your region): # OPENAI_API_KEY="your-azure-api-key" +# Multiple Azure keys: OPENAI_API_KEY="azure-key1,azure-key2" # OPENAI_BASE_URL="https://your-resource-name.openai.azure.com/openai/deployments/your-deployment-name" # AZURE_API_VERSION="2024-03-01-preview" # BIG_MODEL="gpt-4" diff --git a/.gitignore b/.gitignore index 92ced80c..ad4a1f17 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,3 @@ poetry.toml pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python -n \ No newline at end of file diff --git a/README.md b/README.md index a3d12729..9dc732f3 100755 --- a/README.md +++ b/README.md @@ -59,7 +59,10 @@ ANTHROPIC_BASE_URL=http://localhost:8082 ANTHROPIC_API_KEY="exact-matching-key" **Required:** -- `OPENAI_API_KEY` - Your API key for the target provider +- `OPENAI_API_KEY` - Your API key(s) for the target provider + - Single key: `OPENAI_API_KEY="sk-your-key"` + - Multiple keys (comma-separated): `OPENAI_API_KEY="sk-key1,sk-key2,sk-key3"` + - Multiple keys support automatic load balancing and failover **Security:** @@ -134,6 +137,40 @@ SMALL_MODEL="llama3.1:8b" Any OpenAI-compatible API can be used by setting the appropriate `OPENAI_BASE_URL`. +## Multiple API Keys Support + +The proxy now supports multiple OpenAI API keys for improved reliability and load distribution: + +### Configuration + +```bash +# Multiple keys separated by commas +OPENAI_API_KEY="sk-key1,sk-key2,sk-key3" +``` + +### Features + +- **Load Balancing**: Requests are distributed across all available keys using round-robin +- **Automatic Failover**: If one key fails (rate limit, auth error), the proxy automatically tries the next key +- **Cooldown Management**: Failed keys are temporarily disabled (5 minutes by default) before being retried +- **Status Monitoring**: Check the status of all keys via `/api-keys/status` endpoint + +### Monitoring API Keys + +```bash +# Check status of all API keys +curl http://localhost:8082/api-keys/status + +# Reset all failed keys (remove from cooldown) +curl -X POST http://localhost:8082/api-keys/reset +``` + +### Benefits + +- **Higher Rate Limits**: Combine rate limits from multiple API keys +- **Better Reliability**: Service continues even if some keys fail +- **Reduced Downtime**: Automatic failover prevents service interruption + ## Usage Examples ### Basic Chat diff --git a/src/api/endpoints.py b/src/api/endpoints.py index 94ab8b1c..430cbcd7 100644 --- a/src/api/endpoints.py +++ b/src/api/endpoints.py @@ -18,7 +18,7 @@ router = APIRouter() openai_client = OpenAIClient( - config.openai_api_key, + config.openai_api_keys, config.openai_base_url, config.request_timeout, api_version=config.azure_api_version, @@ -159,10 +159,14 @@ async def count_tokens(request: ClaudeTokenCountRequest, _: None = Depends(valid @router.get("/health") async def health_check(): """Health check endpoint""" + api_key_status = openai_client.get_api_key_status() return { "status": "healthy", "timestamp": datetime.now().isoformat(), - "openai_api_configured": bool(config.openai_api_key), + "openai_api_configured": bool(config.openai_api_keys), + "api_key_count": config.get_api_key_count(), + "available_api_keys": api_key_status["available_keys"], + "failed_api_keys": api_key_status["failed_keys"], "api_key_valid": config.validate_api_key(), "client_api_key_validation": bool(config.anthropic_api_key), } @@ -216,7 +220,8 @@ async def root(): "config": { "openai_base_url": config.openai_base_url, "max_tokens_limit": config.max_tokens_limit, - "api_key_configured": bool(config.openai_api_key), + "api_key_configured": bool(config.openai_api_keys), + "api_key_count": config.get_api_key_count(), "client_api_key_validation": bool(config.anthropic_api_key), "big_model": config.big_model, "small_model": config.small_model, @@ -226,5 +231,24 @@ async def root(): "count_tokens": "/v1/messages/count_tokens", "health": "/health", "test_connection": "/test-connection", + "api_keys_status": "/api-keys/status", + "api_keys_reset": "/api-keys/reset", }, } + + +@router.get("/api-keys/status") +async def api_keys_status(): + """Get detailed status of all API keys""" + return openai_client.get_api_key_status() + + +@router.post("/api-keys/reset") +async def reset_api_keys(): + """Reset all failed API keys (remove from cooldown)""" + openai_client.reset_api_key_failures() + return { + "message": "All API key failures have been reset", + "timestamp": datetime.now().isoformat(), + "status": openai_client.get_api_key_status() + } diff --git a/src/core/api_key_manager.py b/src/core/api_key_manager.py new file mode 100644 index 00000000..7187dd3a --- /dev/null +++ b/src/core/api_key_manager.py @@ -0,0 +1,130 @@ +import asyncio +import time +from typing import List, Optional, Dict, Set +from threading import Lock +import logging + +logger = logging.getLogger(__name__) + +class APIKeyManager: + """Manages multiple OpenAI API keys with round-robin distribution and error handling.""" + + def __init__(self, api_keys: List[str], cooldown_period: int = 300): + """ + Initialize the API key manager. + + Args: + api_keys: List of OpenAI API keys + cooldown_period: Time in seconds to wait before retrying a failed key + """ + self.api_keys = api_keys + self.cooldown_period = cooldown_period + self.current_index = 0 + self.failed_keys: Dict[str, float] = {} # key -> timestamp of failure + self.lock = Lock() + + logger.info(f"Initialized API key manager with {len(api_keys)} keys") + + def get_next_key(self) -> Optional[str]: + """ + Get the next available API key using round-robin strategy. + + Returns: + Next available API key or None if all keys are in cooldown + """ + with self.lock: + current_time = time.time() + + # Clean up expired cooldowns + expired_keys = [ + key for key, fail_time in self.failed_keys.items() + if current_time - fail_time > self.cooldown_period + ] + for key in expired_keys: + del self.failed_keys[key] + logger.info(f"API key cooldown expired, key is available again") + + # Find next available key + attempts = 0 + while attempts < len(self.api_keys): + key = self.api_keys[self.current_index] + self.current_index = (self.current_index + 1) % len(self.api_keys) + + if key not in self.failed_keys: + logger.debug(f"Selected API key index {self.current_index - 1}") + return key + + attempts += 1 + + # All keys are in cooldown + logger.warning("All API keys are in cooldown period") + return None + + def mark_key_failed(self, api_key: str, error_message: str = ""): + """ + Mark an API key as failed and put it in cooldown. + + Args: + api_key: The failed API key + error_message: Optional error message for logging + """ + with self.lock: + self.failed_keys[api_key] = time.time() + key_index = self.api_keys.index(api_key) if api_key in self.api_keys else -1 + logger.warning(f"API key (index {key_index}) marked as failed: {error_message}") + + def get_available_key_count(self) -> int: + """Get the number of currently available (not in cooldown) API keys.""" + with self.lock: + current_time = time.time() + available_count = 0 + + for key in self.api_keys: + if key not in self.failed_keys: + available_count += 1 + elif current_time - self.failed_keys[key] > self.cooldown_period: + available_count += 1 + + return available_count + + def get_status(self) -> Dict: + """Get the current status of all API keys.""" + with self.lock: + current_time = time.time() + status = { + "total_keys": len(self.api_keys), + "available_keys": 0, + "failed_keys": 0, + "keys_status": [] + } + + for i, key in enumerate(self.api_keys): + key_status = { + "index": i, + "key_prefix": key[:10] + "..." if len(key) > 10 else key, + "status": "available" + } + + if key in self.failed_keys: + fail_time = self.failed_keys[key] + time_since_failure = current_time - fail_time + + if time_since_failure > self.cooldown_period: + key_status["status"] = "available" + status["available_keys"] += 1 + else: + key_status["status"] = "cooldown" + key_status["cooldown_remaining"] = int(self.cooldown_period - time_since_failure) + status["failed_keys"] += 1 + else: + status["available_keys"] += 1 + + status["keys_status"].append(key_status) + + return status + + def reset_all_failures(self): + """Reset all failed keys (remove from cooldown).""" + with self.lock: + self.failed_keys.clear() + logger.info("All API key failures have been reset") \ No newline at end of file diff --git a/src/core/client.py b/src/core/client.py index dabd977b..d001f58c 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,36 +1,47 @@ import asyncio import json from fastapi import HTTPException -from typing import Optional, AsyncGenerator, Dict, Any +from typing import Optional, AsyncGenerator, Dict, Any, List from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai._exceptions import APIError, RateLimitError, AuthenticationError, BadRequestError +from .api_key_manager import APIKeyManager class OpenAIClient: - """Async OpenAI client with cancellation support.""" + """Async OpenAI client with cancellation support and multiple API key management.""" - def __init__(self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None): - self.api_key = api_key + def __init__(self, api_keys: List[str], base_url: str, timeout: int = 90, api_version: Optional[str] = None): + # Support both single key (backward compatibility) and multiple keys + if isinstance(api_keys, str): + api_keys = [api_keys] + + self.api_key_manager = APIKeyManager(api_keys) self.base_url = base_url + self.timeout = timeout + self.api_version = api_version + self.active_requests: Dict[str, asyncio.Event] = {} - # Detect if using Azure and instantiate the appropriate client - if api_version: - self.client = AsyncAzureOpenAI( + # Keep backward compatibility + self.api_key = api_keys[0] + + def _create_client(self, api_key: str): + """Create an OpenAI client instance with the given API key.""" + if self.api_version: + return AsyncAzureOpenAI( api_key=api_key, - azure_endpoint=base_url, - api_version=api_version, - timeout=timeout + azure_endpoint=self.base_url, + api_version=self.api_version, + timeout=self.timeout ) else: - self.client = AsyncOpenAI( + return AsyncOpenAI( api_key=api_key, - base_url=base_url, - timeout=timeout + base_url=self.base_url, + timeout=self.timeout ) - self.active_requests: Dict[str, asyncio.Event] = {} async def create_chat_completion(self, request: Dict[str, Any], request_id: Optional[str] = None) -> Dict[str, Any]: - """Send chat completion to OpenAI API with cancellation support.""" + """Send chat completion to OpenAI API with cancellation support and automatic key rotation.""" # Create cancellation token if request_id provided if request_id: @@ -38,50 +49,90 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti self.active_requests[request_id] = cancel_event try: - # Create task that can be cancelled - completion_task = asyncio.create_task( - self.client.chat.completions.create(**request) - ) + last_exception = None + attempts = 0 + max_attempts = self.api_key_manager.get_available_key_count() - if request_id: - # Wait for either completion or cancellation - cancel_task = asyncio.create_task(cancel_event.wait()) - done, pending = await asyncio.wait( - [completion_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED - ) - - # Cancel pending tasks - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + while attempts < max_attempts: + # Get next available API key + api_key = self.api_key_manager.get_next_key() + if not api_key: + # All keys are in cooldown + if last_exception: + raise last_exception + raise HTTPException(status_code=503, detail="All API keys are temporarily unavailable") - # Check if request was cancelled - if cancel_task in done: - completion_task.cancel() - raise HTTPException(status_code=499, detail="Request cancelled by client") + try: + # Create client with current API key + client = self._create_client(api_key) + + # Create task that can be cancelled + completion_task = asyncio.create_task( + client.chat.completions.create(**request) + ) + + if request_id: + # Wait for either completion or cancellation + cancel_task = asyncio.create_task(cancel_event.wait()) + done, pending = await asyncio.wait( + [completion_task, cancel_task], + return_when=asyncio.FIRST_COMPLETED + ) + + # Cancel pending tasks + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Check if request was cancelled + if cancel_task in done: + completion_task.cancel() + raise HTTPException(status_code=499, detail="Request cancelled by client") + + completion = await completion_task + else: + completion = await completion_task + + # Success! Convert to dict format that matches the original interface + return completion.model_dump() - completion = await completion_task - else: - completion = await completion_task + except (AuthenticationError, RateLimitError) as e: + # These errors indicate the API key should be marked as failed + self.api_key_manager.mark_key_failed(api_key, str(e)) + last_exception = HTTPException( + status_code=401 if isinstance(e, AuthenticationError) else 429, + detail=self.classify_openai_error(str(e)) + ) + attempts += 1 + continue + + except BadRequestError as e: + # Bad request errors are not key-specific, don't retry + raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) + except APIError as e: + status_code = getattr(e, 'status_code', 500) + # For 5xx errors, we might want to retry with a different key + if status_code >= 500: + self.api_key_manager.mark_key_failed(api_key, str(e)) + last_exception = HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) + attempts += 1 + continue + else: + raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) + except Exception as e: + # For unexpected errors, try next key + self.api_key_manager.mark_key_failed(api_key, str(e)) + last_exception = HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") + attempts += 1 + continue - # Convert to dict format that matches the original interface - return completion.model_dump() - - except AuthenticationError as e: - raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) - except RateLimitError as e: - raise HTTPException(status_code=429, detail=self.classify_openai_error(str(e))) - except BadRequestError as e: - raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) - except APIError as e: - status_code = getattr(e, 'status_code', 500) - raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") + # If we get here, all attempts failed + if last_exception: + raise last_exception + raise HTTPException(status_code=503, detail="All API keys failed") finally: # Clean up active request tracking @@ -89,7 +140,7 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti del self.active_requests[request_id] async def create_chat_completion_stream(self, request: Dict[str, Any], request_id: Optional[str] = None) -> AsyncGenerator[str, None]: - """Send streaming chat completion to OpenAI API with cancellation support.""" + """Send streaming chat completion to OpenAI API with cancellation support and automatic key rotation.""" # Create cancellation token if request_id provided if request_id: @@ -97,40 +148,81 @@ async def create_chat_completion_stream(self, request: Dict[str, Any], request_i self.active_requests[request_id] = cancel_event try: - # Ensure stream is enabled - request["stream"] = True - if "stream_options" not in request: - request["stream_options"] = {} - request["stream_options"]["include_usage"] = True - - # Create the streaming completion - streaming_completion = await self.client.chat.completions.create(**request) + last_exception = None + attempts = 0 + max_attempts = self.api_key_manager.get_available_key_count() - async for chunk in streaming_completion: - # Check for cancellation before yielding each chunk - if request_id and request_id in self.active_requests: - if self.active_requests[request_id].is_set(): - raise HTTPException(status_code=499, detail="Request cancelled by client") + while attempts < max_attempts: + # Get next available API key + api_key = self.api_key_manager.get_next_key() + if not api_key: + # All keys are in cooldown + if last_exception: + raise last_exception + raise HTTPException(status_code=503, detail="All API keys are temporarily unavailable") - # Convert chunk to SSE format matching original HTTP client format - chunk_dict = chunk.model_dump() - chunk_json = json.dumps(chunk_dict, ensure_ascii=False) - yield f"data: {chunk_json}" + try: + # Ensure stream is enabled + request["stream"] = True + if "stream_options" not in request: + request["stream_options"] = {} + request["stream_options"]["include_usage"] = True + + # Create client with current API key + client = self._create_client(api_key) + + # Create the streaming completion + streaming_completion = await client.chat.completions.create(**request) + + async for chunk in streaming_completion: + # Check for cancellation before yielding each chunk + if request_id and request_id in self.active_requests: + if self.active_requests[request_id].is_set(): + raise HTTPException(status_code=499, detail="Request cancelled by client") + + # Convert chunk to SSE format matching original HTTP client format + chunk_dict = chunk.model_dump() + chunk_json = json.dumps(chunk_dict, ensure_ascii=False) + yield f"data: {chunk_json}" + + # Signal end of stream + yield "data: [DONE]" + return # Success, exit the retry loop + + except (AuthenticationError, RateLimitError) as e: + # These errors indicate the API key should be marked as failed + self.api_key_manager.mark_key_failed(api_key, str(e)) + last_exception = HTTPException( + status_code=401 if isinstance(e, AuthenticationError) else 429, + detail=self.classify_openai_error(str(e)) + ) + attempts += 1 + continue + + except BadRequestError as e: + # Bad request errors are not key-specific, don't retry + raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) + except APIError as e: + status_code = getattr(e, 'status_code', 500) + # For 5xx errors, we might want to retry with a different key + if status_code >= 500: + self.api_key_manager.mark_key_failed(api_key, str(e)) + last_exception = HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) + attempts += 1 + continue + else: + raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) + except Exception as e: + # For unexpected errors, try next key + self.api_key_manager.mark_key_failed(api_key, str(e)) + last_exception = HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") + attempts += 1 + continue - # Signal end of stream - yield "data: [DONE]" - - except AuthenticationError as e: - raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) - except RateLimitError as e: - raise HTTPException(status_code=429, detail=self.classify_openai_error(str(e))) - except BadRequestError as e: - raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) - except APIError as e: - status_code = getattr(e, 'status_code', 500) - raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) - except Exception as e: - raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") + # If we get here, all attempts failed + if last_exception: + raise last_exception + raise HTTPException(status_code=503, detail="All API keys failed") finally: # Clean up active request tracking @@ -169,4 +261,12 @@ def cancel_request(self, request_id: str) -> bool: if request_id in self.active_requests: self.active_requests[request_id].set() return True - return False \ No newline at end of file + return False + + def get_api_key_status(self) -> Dict: + """Get the current status of all API keys.""" + return self.api_key_manager.get_status() + + def reset_api_key_failures(self): + """Reset all API key failures.""" + self.api_key_manager.reset_all_failures() \ No newline at end of file diff --git a/src/core/config.py b/src/core/config.py index 7254d3c6..ac0da85c 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,13 +1,22 @@ import os import sys +from typing import List # Configuration class Config: def __init__(self): - self.openai_api_key = os.environ.get("OPENAI_API_KEY") - if not self.openai_api_key: + openai_api_key_str = os.environ.get("OPENAI_API_KEY") + if not openai_api_key_str: raise ValueError("OPENAI_API_KEY not found in environment variables") + # Support multiple API keys separated by commas + self.openai_api_keys = [key.strip() for key in openai_api_key_str.split(",") if key.strip()] + if not self.openai_api_keys: + raise ValueError("No valid OPENAI_API_KEY found in environment variables") + + # Keep backward compatibility - first key as primary + self.openai_api_key = self.openai_api_keys[0] + # Add Anthropic API key for client validation self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") if not self.anthropic_api_key: @@ -31,13 +40,18 @@ def __init__(self): self.small_model = os.environ.get("SMALL_MODEL", "gpt-4o-mini") def validate_api_key(self): - """Basic API key validation""" - if not self.openai_api_key: + """Basic API key validation for all keys""" + if not self.openai_api_keys: return False # Basic format check for OpenAI API keys - if not self.openai_api_key.startswith('sk-'): - return False + for key in self.openai_api_keys: + if not key.startswith('sk-'): + return False return True + + def get_api_key_count(self): + """Get the number of configured API keys""" + return len(self.openai_api_keys) def validate_client_api_key(self, client_api_key): """Validate client's Anthropic API key""" @@ -50,7 +64,8 @@ def validate_client_api_key(self, client_api_key): try: config = Config() - print(f" Configuration loaded: API_KEY={'*' * 20}..., BASE_URL='{config.openai_base_url}'") + key_count = config.get_api_key_count() + print(f" Configuration loaded: {key_count} API_KEY(s) configured, BASE_URL='{config.openai_base_url}'") except Exception as e: print(f"=4 Configuration Error: {e}") sys.exit(1)