diff --git a/pyproject.toml b/pyproject.toml index 5f422f05..d7d3c53b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,4 +69,11 @@ packages = ["src"] python_version = "3.9" warn_return_any = true warn_unused_configs = true -disallow_untyped_defs = true \ No newline at end of file +disallow_untyped_defs = true + + +[tool.ruff] line-length = 100 + +[tool.ruff.format] quote-style = "double" \ No newline at end of file
diff --git a/src/api/endpoints.py b/src/api/endpoints.py index 94ab8b1c..3324ebc4 100644 --- a/src/api/endpoints.py +++ b/src/api/endpoints.py @@ -1,19 +1,20 @@ -from fastapi import APIRouter, HTTPException, Request, Header, Depends -from fastapi.responses import JSONResponse, StreamingResponse -from datetime import datetime import uuid +from datetime import datetime from typing import Optional -from src.core.config import config -from src.core.logging import logger -from src.core.client import OpenAIClient -from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest +from fastapi import APIRouter, Depends, Header, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + from src.conversion.request_converter import convert_claude_to_openai from src.conversion.response_converter import ( - convert_openai_to_claude_response, convert_openai_streaming_to_claude_with_cancellation, + convert_openai_to_claude_response, ) +from src.core.client import OpenAIClient +from src.core.config import config +from src.core.logging import logger from src.core.model_manager import model_manager +from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest router = APIRouter() @@ -24,34 +25,40 @@ api_version=config.azure_api_version, ) -async def validate_api_key(x_api_key: Optional[str] = Header(None), authorization: Optional[str] = Header(None)): + +async def validate_api_key( + x_api_key: Optional[str] = Header(None), authorization: Optional[str] = Header(None) +): """Validate the client's API key from either x-api-key header or Authorization header.""" client_api_key = None - + # Extract API key from headers if x_api_key: client_api_key = x_api_key elif authorization and authorization.startswith("Bearer "): client_api_key = authorization.replace("Bearer ", "") - + # Skip validation if ANTHROPIC_API_KEY is not set in the environment if not config.anthropic_api_key: return - + # Validate the client API key if not client_api_key or not config.validate_client_api_key(client_api_key): - logger.warning(f"Invalid API key provided by client") + logger.warning("Invalid API key provided by client") raise HTTPException( status_code=401, - detail="Invalid API key. Please provide a valid Anthropic API key." + detail="Invalid API key. Please provide a valid Anthropic API key.", )
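For reference, a minimal client-side sketch (not part of the diff) of the two header forms validate_api_key accepts, assuming the proxy runs on the default port 8082 and "my-shared-key" is a placeholder for whatever ANTHROPIC_API_KEY the server was configured with:

import asyncio

import httpx


async def demo() -> None:
    payload = {
        "model": "claude-3-5-haiku-20241022",
        "max_tokens": 32,
        "messages": [{"role": "user", "content": "ping"}],
    }
    async with httpx.AsyncClient(timeout=30) as client:
        # Anthropic-style header, checked first by validate_api_key
        r1 = await client.post(
            "http://localhost:8082/v1/messages",
            json=payload,
            headers={"x-api-key": "my-shared-key"},
        )
        # Bearer token fallback, accepted by the same dependency
        r2 = await client.post(
            "http://localhost:8082/v1/messages",
            json=payload,
            headers={"Authorization": "Bearer my-shared-key"},
        )
        print(r1.status_code, r2.status_code)  # both 200 when the key matches


asyncio.run(demo())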
+ @router.post("/v1/messages") -async def create_message(request: ClaudeMessagesRequest, http_request: Request, _: None = Depends(validate_api_key)): +async def create_message( + request: ClaudeMessagesRequest, + http_request: Request, + _: None = Depends(validate_api_key), +): try: - logger.debug( - f"Processing Claude request: model={request.model}, stream={request.stream}" - ) + logger.info(f"Processing Claude request: model={request.model}, stream={request.stream}") # Generate unique request ID for cancellation tracking request_id = str(uuid.uuid4()) @@ -100,12 +107,8 @@ async def create_message(request: ClaudeMessagesRequest, http_request: Request, return JSONResponse(status_code=e.status_code, content=error_response) else: # Non-streaming response - openai_response = await openai_client.create_chat_completion( - openai_request, request_id - ) - claude_response = convert_openai_to_claude_response( - openai_response, request - ) + openai_response = await openai_client.create_chat_completion(openai_request, request_id) + claude_response = convert_openai_to_claude_response(openai_response, request) return claude_response except HTTPException: raise @@ -123,7 +126,6 @@ async def count_tokens(request: ClaudeTokenCountRequest, _: None = Depends(valid try: # For token counting, we'll use a simple estimation # In a real implementation, you might want to use tiktoken or similar - total_chars = 0 # Count system message characters @@ -173,13 +175,17 @@ async def test_connection(): """Test API connectivity to OpenAI""" try: # Simple test request to verify API connectivity - test_response = await openai_client.create_chat_completion( - { - "model": config.small_model, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 5, - } - ) + openai_request = { + "model": config.small_model, + "messages": [{"role": "user", "content": "Hello"}], + } + if model_manager.is_o3_model(config.small_model): + openai_request["max_completion_tokens"] = 200 + openai_request["temperature"] = 1 + else: + openai_request["max_tokens"] = 5 + + test_response = await openai_client.create_chat_completion(openai_request) return { "status": "success",
diff --git a/src/conversion/request_converter.py b/src/conversion/request_converter.py index f341c709..b1f0a975 100644 --- a/src/conversion/request_converter.py +++ b/src/conversion/request_converter.py @@ -1,10 +1,11 @@ import json -from typing import Dict, Any, List +import logging +from typing import Any, Dict, List -from venv import logger -from src.core.constants import Constants -from src.models.claude import ClaudeMessagesRequest, ClaudeMessage + from src.core.config import config -import logging +from src.core.constants import Constants +from src.core.model_manager import model_manager +from src.models.claude import ClaudeMessage, ClaudeMessagesRequest logger = logging.getLogger(__name__) @@ -30,17 +31,12 @@ def convert_claude_to_openai( for block in claude_request.system: if hasattr(block, "type") and block.type == Constants.CONTENT_TEXT: text_parts.append(block.text) - elif ( - isinstance(block, dict) - and block.get("type") == Constants.CONTENT_TEXT - ): + elif isinstance(block, dict) and block.get("type") == Constants.CONTENT_TEXT: text_parts.append(block.get("text", "")) system_text = "\n\n".join(text_parts) if system_text.strip(): - openai_messages.append( - {"role": Constants.ROLE_SYSTEM, "content": system_text.strip()} - ) + openai_messages.append({"role": Constants.ROLE_SYSTEM, "content": system_text.strip()}) # Process Claude messages i = 0 while i < len(claude_request.messages):
@@ -77,13 +73,21 @@ def convert_claude_to_openai( openai_request = { "model": openai_model, "messages": openai_messages, - "max_tokens": min( - max(claude_request.max_tokens, config.min_tokens_limit), - config.max_tokens_limit, - ), - "temperature": claude_request.temperature, "stream": claude_request.stream, } + + # Handle max tokens based on model type + max_tokens_value = min( + max(claude_request.max_tokens, config.min_tokens_limit), + config.max_tokens_limit, + ) + if model_manager.is_o3_model(openai_model): + openai_request["max_completion_tokens"] = max_tokens_value + openai_request["temperature"] = 1 + else: + openai_request["max_tokens"] = max_tokens_value + openai_request["temperature"] = claude_request.temperature + logger.debug( f"Converted Claude request to OpenAI format: {json.dumps(openai_request, indent=2, ensure_ascii=False)}" ) @@ -133,7 +137,7 @@ def convert_claude_user_message(msg: ClaudeMessage) -> Dict[str, Any]: """Convert Claude user message to OpenAI format.""" if msg.content is None: return {"role": Constants.ROLE_USER, "content": ""} - + if isinstance(msg.content, str): return {"role": Constants.ROLE_USER, "content": msg.content} @@ -172,7 +176,7 @@ def convert_claude_assistant_message(msg: ClaudeMessage) -> Dict[str, Any]: if msg.content is None: return {"role": Constants.ROLE_ASSISTANT, "content": None} - + if isinstance(msg.content, str): return {"role": Constants.ROLE_ASSISTANT, "content": msg.content}
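The token-field branching above is the core behavioral change of this patch: o3-family models take max_completion_tokens and only support the default temperature of 1, while other models keep max_tokens and the requested temperature. A standalone sketch of its effect (not part of the diff), using the default MIN_TOKENS_LIMIT/MAX_TOKENS_LIMIT of 100/4096 from src/core/config.py; model names are illustrative:

from typing import Any, Dict


def clamp_and_assign(
    openai_request: Dict[str, Any],
    requested_max: int,
    lo: int = 100,          # MIN_TOKENS_LIMIT default
    hi: int = 4096,         # MAX_TOKENS_LIMIT default
    temperature: float = 0.7,
) -> Dict[str, Any]:
    """Mirror of the branch above: clamp the requested budget, then pick the
    field name and temperature based on whether the target is an o3 model."""
    value = min(max(requested_max, lo), hi)
    if openai_request["model"].startswith("o3"):
        openai_request["max_completion_tokens"] = value
        openai_request["temperature"] = 1
    else:
        openai_request["max_tokens"] = value
        openai_request["temperature"] = temperature
    return openai_request


print(clamp_and_assign({"model": "o3-mini"}, 50000))
# -> {'model': 'o3-mini', 'max_completion_tokens': 4096, 'temperature': 1}
print(clamp_and_assign({"model": "gpt-4o"}, 50000))
# -> {'model': 'gpt-4o', 'max_tokens': 4096, 'temperature': 0.7}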
diff --git a/src/conversion/response_converter.py b/src/conversion/response_converter.py index 2980c37d..42fec32c 100644 --- a/src/conversion/response_converter.py +++ b/src/conversion/response_converter.py @@ -1,6 +1,8 @@ import json import uuid + from fastapi import HTTPException, Request + from src.core.constants import Constants from src.models.claude import ClaudeMessagesRequest @@ -69,9 +71,7 @@ def convert_openai_to_claude_response( "stop_sequence": None, "usage": { "input_tokens": openai_response.get("usage", {}).get("prompt_tokens", 0), - "output_tokens": openai_response.get("usage", {}).get( - "completion_tokens", 0 - ), + "output_tokens": openai_response.get("usage", {}).get("completion_tokens", 0), }, } @@ -112,9 +112,7 @@ async def convert_openai_streaming_to_claude( if not choices: continue except json.JSONDecodeError as e: - logger.warning( - f"Failed to parse chunk: {chunk_data}, error: {e}" - ) + logger.warning(f"Failed to parse chunk: {chunk_data}, error: {e}") continue choice = choices[0] @@ -129,7 +127,7 @@ async def convert_openai_streaming_to_claude( if "tool_calls" in delta: for tc_delta in delta["tool_calls"]: tc_index = tc_delta.get("index", 0) - + # Initialize tool call tracking by index if not exists if tc_index not in current_tool_calls: current_tool_calls[tc_index] = { @@ -138,33 +136,37 @@ async def convert_openai_streaming_to_claude( "args_buffer": "", "json_sent": False, "claude_index": None, - "started": False + "started": False, } - + tool_call = current_tool_calls[tc_index] - + # Update tool call ID if provided if tc_delta.get("id"): tool_call["id"] = tc_delta["id"] - + # Update function name and start content block if we have both id and name function_data = tc_delta.get(Constants.TOOL_FUNCTION, {}) if function_data.get("name"): tool_call["name"] = function_data["name"] - + # Start content block when we have complete initial data - if (tool_call["id"] and tool_call["name"] and not tool_call["started"]): + if tool_call["id"] and tool_call["name"] and not tool_call["started"]: tool_block_counter += 1 claude_index = text_block_index + tool_block_counter tool_call["claude_index"] = claude_index tool_call["started"] = True - + yield f"event: {Constants.EVENT_CONTENT_BLOCK_START}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_START, 'index': claude_index, 'content_block': {'type': Constants.CONTENT_TOOL_USE, 'id': tool_call['id'], 'name': tool_call['name'], 'input': {}}}, ensure_ascii=False)}\n\n" - + # Handle function arguments - if "arguments" in function_data and tool_call["started"] and function_data["arguments"] is not None: + if ( + "arguments" in function_data + and tool_call["started"] + and function_data["arguments"] is not None + ): tool_call["args_buffer"] += function_data["arguments"] - + # Try to parse complete JSON and send delta when we have valid JSON try: json.loads(tool_call["args_buffer"]) @@ -259,21 +261,21 @@ async def convert_openai_streaming_to_claude_with_cancellation( usage = chunk.get("usage", None) if usage: cache_read_input_tokens = 0 - prompt_tokens_details = usage.get('prompt_tokens_details', {}) + prompt_tokens_details = usage.get("prompt_tokens_details", {}) if prompt_tokens_details: - cache_read_input_tokens = prompt_tokens_details.get('cached_tokens', 0) + cache_read_input_tokens = prompt_tokens_details.get( + "cached_tokens", 0 + ) usage_data = { - 'input_tokens': usage.get('prompt_tokens', 0), - 'output_tokens': usage.get('completion_tokens', 0), - 'cache_read_input_tokens': cache_read_input_tokens + "input_tokens": usage.get("prompt_tokens", 0), + "output_tokens": usage.get("completion_tokens", 0), + "cache_read_input_tokens": cache_read_input_tokens, } choices = chunk.get("choices", []) if not choices: continue except json.JSONDecodeError as e: - logger.warning( - f"Failed to parse chunk: {chunk_data}, error: {e}" - ) + logger.warning(f"Failed to parse chunk: {chunk_data}, error: {e}") continue choice = choices[0] @@ -288,7 +290,7 @@ async def convert_openai_streaming_to_claude_with_cancellation( if "tool_calls" in delta and delta["tool_calls"]: for tc_delta in delta["tool_calls"]: tc_index = tc_delta.get("index", 0) - + # Initialize tool call tracking by index if not exists if tc_index not in current_tool_calls: current_tool_calls[tc_index] = { @@ -297,33 +299,37 @@ async def convert_openai_streaming_to_claude_with_cancellation( "args_buffer": "", "json_sent": False, "claude_index": None, - "started": False + "started": False, } - + tool_call = current_tool_calls[tc_index] - + # Update tool call ID if provided if tc_delta.get("id"): tool_call["id"] = tc_delta["id"] - + # Update function name and start content block if we have both id and name function_data = tc_delta.get(Constants.TOOL_FUNCTION, {}) if function_data.get("name"): tool_call["name"] = function_data["name"] - + # Start content block when we have complete initial data - if (tool_call["id"] and tool_call["name"] and not tool_call["started"]): + if tool_call["id"] and tool_call["name"] and not tool_call["started"]: tool_block_counter += 1 claude_index = text_block_index + tool_block_counter tool_call["claude_index"] = claude_index tool_call["started"] = True - + yield f"event: {Constants.EVENT_CONTENT_BLOCK_START}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_START, 'index': claude_index, 'content_block': {'type': Constants.CONTENT_TOOL_USE, 'id': tool_call['id'], 'name': tool_call['name'], 'input': {}}}, ensure_ascii=False)}\n\n" - + # Handle function arguments - if "arguments" in function_data and tool_call["started"] and function_data["arguments"] is not None:
and tool_call["started"] + and function_data["arguments"] is not None + ): tool_call["args_buffer"] += function_data["arguments"] - + # Try to parse complete JSON and send delta when we have valid JSON try: json.loads(tool_call["args_buffer"]) diff --git a/src/core/client.py b/src/core/client.py index dabd977b..7fcbb7a6 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,56 +1,51 @@ import asyncio import json +from typing import Any, AsyncGenerator, Dict, Optional + from fastapi import HTTPException -from typing import Optional, AsyncGenerator, Dict, Any -from openai import AsyncOpenAI, AsyncAzureOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionChunk -from openai._exceptions import APIError, RateLimitError, AuthenticationError, BadRequestError +from openai import AsyncAzureOpenAI, AsyncOpenAI +from openai._exceptions import APIError, AuthenticationError, BadRequestError, RateLimitError + class OpenAIClient: """Async OpenAI client with cancellation support.""" - - def __init__(self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None): + + def __init__( + self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None + ): self.api_key = api_key self.base_url = base_url - + # Detect if using Azure and instantiate the appropriate client if api_version: self.client = AsyncAzureOpenAI( - api_key=api_key, - azure_endpoint=base_url, - api_version=api_version, - timeout=timeout + api_key=api_key, azure_endpoint=base_url, api_version=api_version, timeout=timeout ) else: - self.client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - timeout=timeout - ) + self.client = AsyncOpenAI(api_key=api_key, base_url=base_url, timeout=timeout) self.active_requests: Dict[str, asyncio.Event] = {} - - async def create_chat_completion(self, request: Dict[str, Any], request_id: Optional[str] = None) -> Dict[str, Any]: + + async def create_chat_completion( + self, request: Dict[str, Any], request_id: Optional[str] = None + ) -> Dict[str, Any]: """Send chat completion to OpenAI API with cancellation support.""" - + # Create cancellation token if request_id provided if request_id: cancel_event = asyncio.Event() self.active_requests[request_id] = cancel_event - + try: # Create task that can be cancelled - completion_task = asyncio.create_task( - self.client.chat.completions.create(**request) - ) - + completion_task = asyncio.create_task(self.client.chat.completions.create(**request)) + if request_id: # Wait for either completion or cancellation cancel_task = asyncio.create_task(cancel_event.wait()) done, pending = await asyncio.wait( - [completion_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED + [completion_task, cancel_task], return_when=asyncio.FIRST_COMPLETED ) - + # Cancel pending tasks for task in pending: task.cancel() @@ -58,19 +53,19 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti await task except asyncio.CancelledError: pass - + # Check if request was cancelled if cancel_task in done: completion_task.cancel() raise HTTPException(status_code=499, detail="Request cancelled by client") - + completion = await completion_task else: completion = await completion_task - + # Convert to dict format that matches the original interface return completion.model_dump() - + except AuthenticationError as e: raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) except RateLimitError as e: @@ -78,48 +73,50 @@ async def create_chat_completion(self, request: 
diff --git a/src/core/client.py b/src/core/client.py index dabd977b..7fcbb7a6 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,56 +1,51 @@ import asyncio import json +from typing import Any, AsyncGenerator, Dict, Optional + from fastapi import HTTPException -from typing import Optional, AsyncGenerator, Dict, Any -from openai import AsyncOpenAI, AsyncAzureOpenAI -from openai.types.chat import ChatCompletion, ChatCompletionChunk -from openai._exceptions import APIError, RateLimitError, AuthenticationError, BadRequestError +from openai import AsyncAzureOpenAI, AsyncOpenAI +from openai._exceptions import APIError, AuthenticationError, BadRequestError, RateLimitError + class OpenAIClient: """Async OpenAI client with cancellation support.""" - - def __init__(self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None): + + def __init__( + self, api_key: str, base_url: str, timeout: int = 90, api_version: Optional[str] = None + ): self.api_key = api_key self.base_url = base_url - + # Detect if using Azure and instantiate the appropriate client if api_version: self.client = AsyncAzureOpenAI( - api_key=api_key, - azure_endpoint=base_url, - api_version=api_version, - timeout=timeout + api_key=api_key, azure_endpoint=base_url, api_version=api_version, timeout=timeout ) else: - self.client = AsyncOpenAI( - api_key=api_key, - base_url=base_url, - timeout=timeout - ) + self.client = AsyncOpenAI(api_key=api_key, base_url=base_url, timeout=timeout) self.active_requests: Dict[str, asyncio.Event] = {} - - async def create_chat_completion(self, request: Dict[str, Any], request_id: Optional[str] = None) -> Dict[str, Any]: + + async def create_chat_completion( + self, request: Dict[str, Any], request_id: Optional[str] = None + ) -> Dict[str, Any]: """Send chat completion to OpenAI API with cancellation support.""" - + # Create cancellation token if request_id provided if request_id: cancel_event = asyncio.Event() self.active_requests[request_id] = cancel_event - + try: # Create task that can be cancelled - completion_task = asyncio.create_task( - self.client.chat.completions.create(**request) - ) - + completion_task = asyncio.create_task(self.client.chat.completions.create(**request)) + if request_id: # Wait for either completion or cancellation cancel_task = asyncio.create_task(cancel_event.wait()) done, pending = await asyncio.wait( - [completion_task, cancel_task], - return_when=asyncio.FIRST_COMPLETED + [completion_task, cancel_task], return_when=asyncio.FIRST_COMPLETED ) - + # Cancel pending tasks for task in pending: task.cancel() @@ -58,19 +53,19 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti await task except asyncio.CancelledError: pass - + # Check if request was cancelled if cancel_task in done: completion_task.cancel() raise HTTPException(status_code=499, detail="Request cancelled by client") - + completion = await completion_task else: completion = await completion_task - + # Convert to dict format that matches the original interface return completion.model_dump() - + except AuthenticationError as e: raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) except RateLimitError as e: @@ -78,48 +73,50 @@ async def create_chat_completion(self, request: Dict[str, Any], request_id: Opti except BadRequestError as e: raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) except APIError as e: - status_code = getattr(e, 'status_code', 500) + status_code = getattr(e, "status_code", 500) raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") - + finally: # Clean up active request tracking if request_id and request_id in self.active_requests: del self.active_requests[request_id] - - async def create_chat_completion_stream(self, request: Dict[str, Any], request_id: Optional[str] = None) -> AsyncGenerator[str, None]: + + async def create_chat_completion_stream( + self, request: Dict[str, Any], request_id: Optional[str] = None + ) -> AsyncGenerator[str, None]: """Send streaming chat completion to OpenAI API with cancellation support.""" - + # Create cancellation token if request_id provided if request_id: cancel_event = asyncio.Event() self.active_requests[request_id] = cancel_event - + try: # Ensure stream is enabled request["stream"] = True if "stream_options" not in request: request["stream_options"] = {} request["stream_options"]["include_usage"] = True - + # Create the streaming completion streaming_completion = await self.client.chat.completions.create(**request) - + async for chunk in streaming_completion: # Check for cancellation before yielding each chunk if request_id and request_id in self.active_requests: if self.active_requests[request_id].is_set(): raise HTTPException(status_code=499, detail="Request cancelled by client") - + # Convert chunk to SSE format matching original HTTP client format chunk_dict = chunk.model_dump() chunk_json = json.dumps(chunk_dict, ensure_ascii=False) yield f"data: {chunk_json}" - + # Signal end of stream yield "data: [DONE]" - + except AuthenticationError as e: raise HTTPException(status_code=401, detail=self.classify_openai_error(str(e))) except RateLimitError as e: @@ -127,11 +124,11 @@ async def create_chat_completion_stream(self, request: Dict[str, Any], request_i except BadRequestError as e: raise HTTPException(status_code=400, detail=self.classify_openai_error(str(e))) except APIError as e: - status_code = getattr(e, 'status_code', 500) + status_code = getattr(e, "status_code", 500) raise HTTPException(status_code=status_code, detail=self.classify_openai_error(str(e))) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") - + finally: # Clean up active request tracking if request_id and request_id in self.active_requests: @@ -140,33 +137,36 @@ async def create_chat_completion_stream(self, request: Dict[str, Any], request_i def classify_openai_error(self, error_detail: Any) -> str: """Provide specific error guidance for common OpenAI API issues.""" error_str = str(error_detail).lower() - + # Region/country restrictions - if "unsupported_country_region_territory" in error_str or "country, region, or territory not supported" in error_str: + if ( + "unsupported_country_region_territory" in error_str + or "country, region, or territory not supported" in error_str + ): return "OpenAI API is not available in your region. Consider using a VPN or Azure OpenAI service." - + # API key issues if "invalid_api_key" in error_str or "unauthorized" in error_str: return "Invalid API key. Please check your OPENAI_API_KEY configuration." - + # Rate limiting if "rate_limit" in error_str or "quota" in error_str: return "Rate limit exceeded. Please wait and try again, or upgrade your API plan." - + # Model not found if "model" in error_str and ("not found" in error_str or "does not exist" in error_str): return "Model not found. Please check your BIG_MODEL and SMALL_MODEL configuration." - + # Billing issues if "billing" in error_str or "payment" in error_str: return "Billing issue. Please check your OpenAI account billing status." - + # Default: return original message return str(error_detail) - + def cancel_request(self, request_id: str) -> bool: """Cancel an active request by request_id.""" if request_id in self.active_requests: self.active_requests[request_id].set() return True - return False \ No newline at end of file + return False
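A runnable sketch (not part of the diff) of the cancellation handshake create_chat_completion implements above, with the OpenAI call replaced by a stub sleep so it runs standalone; cancel_request() corresponds to setting the event from another task:

import asyncio


async def slow_call() -> str:
    await asyncio.sleep(10)  # stands in for client.chat.completions.create(...)
    return "completion"


async def demo() -> None:
    cancel_event = asyncio.Event()
    completion_task = asyncio.create_task(slow_call())
    cancel_task = asyncio.create_task(cancel_event.wait())

    # Simulate OpenAIClient.cancel_request("req-1") firing shortly after start.
    asyncio.get_running_loop().call_later(0.1, cancel_event.set)

    # Race completion against cancellation, exactly as in the hunk above.
    done, pending = await asyncio.wait(
        [completion_task, cancel_task], return_when=asyncio.FIRST_COMPLETED
    )
    for task in pending:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    print("cancelled first" if cancel_task in done else "completed first")


asyncio.run(demo())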
diff --git a/src/core/config.py b/src/core/config.py index 7254d3c6..f2c99a96 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,18 +1,19 @@ import os import sys + # Configuration class Config: def __init__(self): self.openai_api_key = os.environ.get("OPENAI_API_KEY") if not self.openai_api_key: raise ValueError("OPENAI_API_KEY not found in environment variables") - + # Add Anthropic API key for client validation self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") if not self.anthropic_api_key: print("Warning: ANTHROPIC_API_KEY not set. Client API key validation will be disabled.") - + self.openai_base_url = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") self.azure_api_version = os.environ.get("AZURE_API_VERSION") # For Azure OpenAI self.host = os.environ.get("HOST", "0.0.0.0") @@ -20,34 +21,35 @@ def __init__(self): self.log_level = os.environ.get("LOG_LEVEL", "INFO") self.max_tokens_limit = int(os.environ.get("MAX_TOKENS_LIMIT", "4096")) self.min_tokens_limit = int(os.environ.get("MIN_TOKENS_LIMIT", "100")) - + # Connection settings self.request_timeout = int(os.environ.get("REQUEST_TIMEOUT", "90")) self.max_retries = int(os.environ.get("MAX_RETRIES", "2")) - + # Model settings - BIG and SMALL models self.big_model = os.environ.get("BIG_MODEL", "gpt-4o") self.middle_model = os.environ.get("MIDDLE_MODEL", self.big_model) self.small_model = os.environ.get("SMALL_MODEL", "gpt-4o-mini") - + def validate_api_key(self): """Basic API key validation""" if not self.openai_api_key: return False # Basic format check for OpenAI API keys - if not self.openai_api_key.startswith('sk-'): + if not self.openai_api_key.startswith("sk-"): return False return True - + def validate_client_api_key(self, client_api_key): """Validate client's Anthropic API key""" # If no ANTHROPIC_API_KEY is set in the environment, skip validation if not self.anthropic_api_key: return True - + # Check if the client's API key matches the expected value return client_api_key == self.anthropic_api_key + try: config = Config() print(f" Configuration loaded: API_KEY={'*' * 20}..., BASE_URL='{config.openai_base_url}'")
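A sketch of how a deployment might drive this configuration (values are placeholders, not recommendations). Note that the module reads the environment at import time, so variables must be set before importing it:

import os

os.environ.update(
    {
        "OPENAI_API_KEY": "sk-...",            # required; must start with "sk-"
        "ANTHROPIC_API_KEY": "shared-secret",  # optional; enables client key checks
        "BIG_MODEL": "gpt-4o",                 # serves Claude opus requests
        "SMALL_MODEL": "o3-mini",              # o3 models get special token handling
        "MAX_TOKENS_LIMIT": "4096",
    }
)

from src.core.config import config  # noqa: E402  (env must be set first)

assert config.validate_client_api_key("shared-secret")
assert not config.validate_client_api_key("wrong-key")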
STOP_ERROR = "error" - + EVENT_MESSAGE_START = "message_start" EVENT_MESSAGE_STOP = "message_stop" EVENT_MESSAGE_DELTA = "message_delta" @@ -24,6 +24,6 @@ class Constants: EVENT_CONTENT_BLOCK_STOP = "content_block_stop" EVENT_CONTENT_BLOCK_DELTA = "content_block_delta" EVENT_PING = "ping" - + DELTA_TEXT = "text_delta" - DELTA_INPUT_JSON = "input_json_delta" \ No newline at end of file + DELTA_INPUT_JSON = "input_json_delta" diff --git a/src/core/logging.py b/src/core/logging.py index 87376bb9..0e6690e1 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -1,21 +1,22 @@ import logging + from src.core.config import config # Parse log level - extract just the first word to handle comments log_level = config.log_level.split()[0].upper() # Validate and set default if invalid -valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] +valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] if log_level not in valid_levels: - log_level = 'INFO' + log_level = "INFO" # Logging Configuration logging.basicConfig( level=getattr(logging, log_level), - format='%(asctime)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Configure uvicorn to be quieter for uvicorn_logger in ["uvicorn", "uvicorn.access", "uvicorn.error"]: - logging.getLogger(uvicorn_logger).setLevel(logging.WARNING) \ No newline at end of file + logging.getLogger(uvicorn_logger).setLevel(logging.WARNING) diff --git a/src/core/model_manager.py b/src/core/model_manager.py index 5495f317..ed9e5b00 100644 --- a/src/core/model_manager.py +++ b/src/core/model_manager.py @@ -1,30 +1,43 @@ from src.core.config import config + class ModelManager: def __init__(self, config): self.config = config - + def map_claude_model_to_openai(self, claude_model: str) -> str: """Map Claude model names to OpenAI model names based on BIG/SMALL pattern""" # If it's already an OpenAI model, return as-is - if claude_model.startswith("gpt-") or claude_model.startswith("o1-"): + if ( + claude_model.startswith("gpt-") + or claude_model.startswith("o1-") + or self.is_o3_model(claude_model) + ): return claude_model # If it's other supported models (ARK/Doubao/DeepSeek), return as-is - if (claude_model.startswith("ep-") or claude_model.startswith("doubao-") or - claude_model.startswith("deepseek-")): + if ( + claude_model.startswith("ep-") + or claude_model.startswith("doubao-") + or claude_model.startswith("deepseek-") + ): return claude_model - + # Map based on model naming patterns model_lower = claude_model.lower() - if 'haiku' in model_lower: + if "haiku" in model_lower: return self.config.small_model - elif 'sonnet' in model_lower: + elif "sonnet" in model_lower: return self.config.middle_model - elif 'opus' in model_lower: + elif "opus" in model_lower: return self.config.big_model else: # Default to big model for unknown models return self.config.big_model -model_manager = ModelManager(config) \ No newline at end of file + def is_o3_model(self, model: str) -> bool: + """Check if the model is an o3 model that requires special handling""" + return model.startswith("o3") + + +model_manager = ModelManager(config) diff --git a/src/main.py b/src/main.py index 8f1e0f34..69ee59da 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,9 @@ +import sys + +import uvicorn from fastapi import FastAPI + from src.api.endpoints import router as api_router -import uvicorn -import sys from src.core.config import config app = FastAPI(title="Claude-to-OpenAI API Proxy", 
version="1.0.0") @@ -21,18 +23,16 @@ def main(): print("Optional environment variables:") print(" ANTHROPIC_API_KEY - Expected Anthropic API key for client validation") print(" If set, clients must provide this exact API key") - print( - f" OPENAI_BASE_URL - OpenAI API base URL (default: https://api.openai.com/v1)" - ) - print(f" BIG_MODEL - Model for opus requests (default: gpt-4o)") - print(f" MIDDLE_MODEL - Model for sonnet requests (default: gpt-4o)") - print(f" SMALL_MODEL - Model for haiku requests (default: gpt-4o-mini)") - print(f" HOST - Server host (default: 0.0.0.0)") - print(f" PORT - Server port (default: 8082)") - print(f" LOG_LEVEL - Logging level (default: WARNING)") - print(f" MAX_TOKENS_LIMIT - Token limit (default: 4096)") - print(f" MIN_TOKENS_LIMIT - Minimum token limit (default: 100)") - print(f" REQUEST_TIMEOUT - Request timeout in seconds (default: 90)") + print(" OPENAI_BASE_URL - OpenAI API base URL (default: https://api.openai.com/v1)") + print(" BIG_MODEL - Model for opus requests (default: gpt-4o)") + print(" MIDDLE_MODEL - Model for sonnet requests (default: gpt-4o)") + print(" SMALL_MODEL - Model for haiku requests (default: gpt-4o-mini)") + print(" HOST - Server host (default: 0.0.0.0)") + print(" PORT - Server port (default: 8082)") + print(" LOG_LEVEL - Logging level (default: WARNING)") + print(" MAX_TOKENS_LIMIT - Token limit (default: 4096)") + print(" MIN_TOKENS_LIMIT - Minimum token limit (default: 100)") + print(" REQUEST_TIMEOUT - Request timeout in seconds (default: 90)") print("") print("Model mapping:") print(f" Claude haiku models -> {config.small_model}") @@ -41,7 +41,7 @@ def main(): # Configuration summary print("๐Ÿš€ Claude-to-OpenAI API Proxy v1.0.0") - print(f"โœ… Configuration loaded successfully") + print("โœ… Configuration loaded successfully") print(f" OpenAI Base URL: {config.openai_base_url}") print(f" Big Model (opus): {config.big_model}") print(f" Middle Model (sonnet): {config.middle_model}") @@ -50,15 +50,15 @@ def main(): print(f" Request Timeout: {config.request_timeout}s") print(f" Server: {config.host}:{config.port}") print(f" Client API Key Validation: {'Enabled' if config.anthropic_api_key else 'Disabled'}") - print("") + print("aaa") # Parse log level - extract just the first word to handle comments log_level = config.log_level.split()[0].lower() - + # Validate and set default if invalid - valid_levels = ['debug', 'info', 'warning', 'error', 'critical'] + valid_levels = ["debug", "info", "warning", "error", "critical"] if log_level not in valid_levels: - log_level = 'info' + log_level = "info" # Start server uvicorn.run( diff --git a/start_proxy.py b/start_proxy.py index 713708cd..4fb79508 100644 --- a/start_proxy.py +++ b/start_proxy.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 """Start Claude Code Proxy server.""" -import sys import os +import sys # Add src to Python path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) from src.main import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/test_cancellation.py b/test_cancellation.py index f4d7da7c..63b46597 100644 --- a/test_cancellation.py +++ b/test_cancellation.py @@ -5,14 +5,14 @@ """ import asyncio + import httpx -import json -import time + async def test_non_streaming_cancellation(): """Test cancellation for non-streaming requests.""" print("๐Ÿงช Testing non-streaming request cancellation...") - + async with httpx.AsyncClient(timeout=30) as 
diff --git a/test_cancellation.py b/test_cancellation.py index f4d7da7c..63b46597 100644 --- a/test_cancellation.py +++ b/test_cancellation.py @@ -5,14 +5,14 @@ """ import asyncio + import httpx -import json -import time + async def test_non_streaming_cancellation(): """Test cancellation for non-streaming requests.""" print("🧪 Testing non-streaming request cancellation...") - + async with httpx.AsyncClient(timeout=30) as client: try: # Start a long-running request @@ -23,29 +23,33 @@ async def test_non_streaming_cancellation(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 1000, "messages": [ - {"role": "user", "content": "Write a very long story about a journey through space that takes at least 500 words."} - ] - } + { + "role": "user", + "content": "Write a very long story about a journey through space that takes at least 500 words.", + } + ], + }, ) ) - + # Cancel after 2 seconds await asyncio.sleep(2) task.cancel() - + try: await task print("❌ Request should have been cancelled") except asyncio.CancelledError: print("✅ Non-streaming request cancelled successfully") - + except Exception as e: print(f"❌ Non-streaming test error: {e}") + async def test_streaming_cancellation(): """Test cancellation for streaming requests.""" print("\n🧪 Testing streaming request cancellation...") - + async with httpx.AsyncClient(timeout=30) as client: try: # Start streaming request @@ -56,37 +60,41 @@ async def test_streaming_cancellation(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 1000, "messages": [ - {"role": "user", "content": "Write a very long story about a journey through space that takes at least 500 words."} + { + "role": "user", + "content": "Write a very long story about a journey through space that takes at least 500 words.", + } ], - "stream": True - } + "stream": True, + }, ) as response: if response.status_code == 200: print("✅ Streaming request started successfully") - + # Read a few chunks then simulate client disconnect chunk_count = 0 async for line in response.aiter_lines(): if line.strip(): chunk_count += 1 print(f"📦 Received chunk {chunk_count}: {line[:100]}...") - + # Simulate client disconnect after 3 chunks if chunk_count >= 3: print("🔌 Simulating client disconnect...") break - + print("✅ Streaming request cancelled successfully") else: print(f"❌ Streaming request failed: {response.status_code}") - + except Exception as e: print(f"❌ Streaming test error: {e}") + async def test_server_running(): """Test if the server is running.""" print("🔍 Checking if server is running...") - + try: async with httpx.AsyncClient(timeout=5) as client: response = await client.get("http://localhost:8082/health") @@ -101,23 +109,24 @@ async def test_server_running(): print("💡 Make sure to start the server with: python start_proxy.py") return False + async def main(): """Main test function.""" print("🚀 Starting HTTP request cancellation tests") print("=" * 50) - + # Check if server is running if not await test_server_running(): return - + print("\n" + "=" * 50) - + # Test non-streaming cancellation await test_non_streaming_cancellation() - - # Test streaming cancellation + + # Test streaming cancellation await test_streaming_cancellation() - + print("\n" + "=" * 50) print("✅ All cancellation tests completed!") print("\n💡 Note: The actual cancellation behavior depends on:") @@ -126,5 +135,6 @@ async def main(): print(" - Server response to client disconnection") print(" - Whether the underlying OpenAI API supports cancellation") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())
"messages": [ - {"role": "user", "content": "Hello, how are you?"} - ] - } + "messages": [{"role": "user", "content": "Hello, how are you?"}], + }, ) - + print("Basic chat response:") print(json.dumps(response.json(), indent=2)) @@ -35,11 +34,9 @@ async def test_streaming_chat(): json={ "model": "claude-3-5-haiku-20241022", "max_tokens": 150, - "messages": [ - {"role": "user", "content": "Tell me a short joke"} - ], - "stream": True - } + "messages": [{"role": "user", "content": "Tell me a short joke"}], + "stream": True, + }, ) as response: print("\nStreaming response:") async for line in response.aiter_lines(): @@ -56,7 +53,10 @@ async def test_function_calling(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 200, "messages": [ - {"role": "user", "content": "What's the weather like in New York? Please use the weather function."} + { + "role": "user", + "content": "What's the weather like in New York? Please use the weather function.", + } ], "tools": [ { @@ -67,22 +67,22 @@ async def test_function_calling(): "properties": { "location": { "type": "string", - "description": "The location to get weather for" + "description": "The location to get weather for", }, "unit": { "type": "string", "enum": ["celsius", "fahrenheit"], - "description": "Temperature unit" - } + "description": "Temperature unit", + }, }, - "required": ["location"] - } + "required": ["location"], + }, } ], - "tool_choice": {"type": "auto"} - } + "tool_choice": {"type": "auto"}, + }, ) - + print("\nFunction calling response:") print(json.dumps(response.json(), indent=2)) @@ -94,14 +94,11 @@ async def test_with_system_message(): "http://localhost:8082/v1/messages", json={ "model": "claude-3-5-sonnet-20241022", - "max_tokens": 100, "system": "You are a helpful assistant that always responds in haiku format.", - "messages": [ - {"role": "user", "content": "Explain what AI is"} - ] - } + "messages": [{"role": "user", "content": "Explain what AI is"}], + }, ) - + print("\nSystem message response:") print(json.dumps(response.json(), indent=2)) @@ -111,7 +108,7 @@ async def test_multimodal(): async with httpx.AsyncClient() as client: # Sample base64 image (1x1 pixel transparent PNG) sample_image = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChAI9jU8PJAAAAASUVORK5CYII=" - + response = await client.post( "http://localhost:8082/v1/messages", json={ @@ -127,15 +124,15 @@ async def test_multimodal(): "source": { "type": "base64", "media_type": "image/png", - "data": sample_image - } - } - ] + "data": sample_image, + }, + }, + ], } - ] - } + ], + }, ) - + print("\nMultimodal response:") print(json.dumps(response.json(), indent=2)) @@ -150,7 +147,10 @@ async def test_conversation_with_tool_use(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 200, "messages": [ - {"role": "user", "content": "Calculate 25 * 4 using the calculator tool"} + { + "role": "user", + "content": "Calculate 25 * 4 using the calculator tool", + } ], "tools": [ { @@ -161,26 +161,28 @@ async def test_conversation_with_tool_use(): "properties": { "expression": { "type": "string", - "description": "Mathematical expression to calculate" + "description": "Mathematical expression to calculate", } }, - "required": ["expression"] - } + "required": ["expression"], + }, } - ] - } + ], + }, ) - + print("\nTool call response:") result1 = response1.json() print(json.dumps(result1, indent=2)) - + # Simulate tool execution and send result if result1.get("content"): - tool_use_blocks = [block for block in result1["content"] if block.get("type") 
== "tool_use"] + tool_use_blocks = [ + block for block in result1["content"] if block.get("type") == "tool_use" + ] if tool_use_blocks: tool_block = tool_use_blocks[0] - + # Second message with tool result response2 = await client.post( "http://localhost:8082/v1/messages", @@ -188,7 +190,10 @@ async def test_conversation_with_tool_use(): "model": "claude-3-5-sonnet-20241022", "max_tokens": 100, "messages": [ - {"role": "user", "content": "Calculate 25 * 4 using the calculator tool"}, + { + "role": "user", + "content": "Calculate 25 * 4 using the calculator tool", + }, {"role": "assistant", "content": result1["content"]}, { "role": "user", @@ -196,14 +201,14 @@ async def test_conversation_with_tool_use(): { "type": "tool_result", "tool_use_id": tool_block["id"], - "content": "100" + "content": "100", } - ] - } - ] - } + ], + }, + ], + }, ) - + print("\nTool result response:") print(json.dumps(response2.json(), indent=2)) @@ -216,11 +221,14 @@ async def test_token_counting(): json={ "model": "claude-3-5-sonnet-20241022", "messages": [ - {"role": "user", "content": "This is a test message for token counting."} - ] - } + { + "role": "user", + "content": "This is a test message for token counting.", + } + ], + }, ) - + print("\nToken count response:") print(json.dumps(response.json(), indent=2)) @@ -232,7 +240,7 @@ async def test_health_and_connection(): health_response = await client.get("http://localhost:8082/health") print("\nHealth check:") print(json.dumps(health_response.json(), indent=2)) - + # Connection test connection_response = await client.get("http://localhost:8082/test-connection") print("\nConnection test:") @@ -243,7 +251,7 @@ async def main(): """Run all tests.""" print("๐Ÿงช Testing Claude to OpenAI Proxy") print("=" * 50) - + try: await test_health_and_connection() await test_token_counting() @@ -253,13 +261,13 @@ async def main(): await test_multimodal() await test_function_calling() await test_conversation_with_tool_use() - + print("\nโœ… All tests completed!") - + except Exception as e: print(f"\nโŒ Test failed: {e}") print("Make sure the server is running with a valid OPENAI_API_KEY") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())