Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,11 @@ packages = ["src"]
python_version = "3.9"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_untyped_defs = true


[tool.ruff]
line-length = 100

[tool.ruff.format]
quote-style = "double"
70 changes: 38 additions & 32 deletions src/api/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from fastapi import APIRouter, HTTPException, Request, Header, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from datetime import datetime
import uuid
from datetime import datetime
from typing import Optional

from src.core.config import config
from src.core.logging import logger
from src.core.client import OpenAIClient
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

from src.conversion.request_converter import convert_claude_to_openai
from src.conversion.response_converter import (
convert_openai_to_claude_response,
convert_openai_streaming_to_claude_with_cancellation,
convert_openai_to_claude_response,
)
from src.core.client import OpenAIClient
from src.core.config import config
from src.core.logging import logger
from src.core.model_manager import model_manager
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest

router = APIRouter()

Expand All @@ -24,34 +25,40 @@
api_version=config.azure_api_version,
)

async def validate_api_key(x_api_key: Optional[str] = Header(None), authorization: Optional[str] = Header(None)):

async def validate_api_key(
x_api_key: Optional[str] = Header(None), authorization: Optional[str] = Header(None)
):
"""Validate the client's API key from either x-api-key header or Authorization header."""
client_api_key = None

# Extract API key from headers
if x_api_key:
client_api_key = x_api_key
elif authorization and authorization.startswith("Bearer "):
client_api_key = authorization.replace("Bearer ", "")

# Skip validation if ANTHROPIC_API_KEY is not set in the environment
if not config.anthropic_api_key:
return

# Validate the client API key
if not client_api_key or not config.validate_client_api_key(client_api_key):
logger.warning(f"Invalid API key provided by client")
logger.warning("Invalid API key provided by client")
raise HTTPException(
status_code=401,
detail="Invalid API key. Please provide a valid Anthropic API key."
detail="Invalid API key. Please provide a valid Anthropic API key.",
)


@router.post("/v1/messages")
async def create_message(request: ClaudeMessagesRequest, http_request: Request, _: None = Depends(validate_api_key)):
async def create_message(
request: ClaudeMessagesRequest,
http_request: Request,
_: None = Depends(validate_api_key),
):
try:
logger.debug(
f"Processing Claude request: model={request.model}, stream={request.stream}"
)
logger.info(f"Processing Claude request: model={request.model}, stream={request.stream}")

# Generate unique request ID for cancellation tracking
request_id = str(uuid.uuid4())
Expand Down Expand Up @@ -100,12 +107,8 @@ async def create_message(request: ClaudeMessagesRequest, http_request: Request,
return JSONResponse(status_code=e.status_code, content=error_response)
else:
# Non-streaming response
openai_response = await openai_client.create_chat_completion(
openai_request, request_id
)
claude_response = convert_openai_to_claude_response(
openai_response, request
)
openai_response = await openai_client.create_chat_completion(openai_request, request_id)
claude_response = convert_openai_to_claude_response(openai_response, request)
return claude_response
except HTTPException:
raise
Expand All @@ -123,7 +126,6 @@ async def count_tokens(request: ClaudeTokenCountRequest, _: None = Depends(valid
try:
# For token counting, we'll use a simple estimation
# In a real implementation, you might want to use tiktoken or similar

total_chars = 0

# Count system message characters
Expand Down Expand Up @@ -173,13 +175,17 @@ async def test_connection():
"""Test API connectivity to OpenAI"""
try:
# Simple test request to verify API connectivity
test_response = await openai_client.create_chat_completion(
{
"model": config.small_model,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 5,
}
)
openai_request = {
"model": config.small_model,
"messages": [{"role": "user", "content": "Hello"}],
}
if model_manager.is_o3_model(config.small_model):
openai_request["max_completion_tokens"] = 200
openai_request["temperature"] = 1
else:
openai_request["max_tokens"] = 5
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is max_tokens 5?


test_response = await openai_client.create_chat_completion(openai_request)

return {
"status": "success",
Expand Down
40 changes: 22 additions & 18 deletions src/conversion/request_converter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
from typing import Dict, Any, List
import logging
from typing import Any, Dict, List
from venv import logger
from src.core.constants import Constants
from src.models.claude import ClaudeMessagesRequest, ClaudeMessage

from src.core.config import config
import logging
from src.core.constants import Constants
from src.models.claude import ClaudeMessage, ClaudeMessagesRequest

logger = logging.getLogger(__name__)

Expand All @@ -30,17 +31,12 @@ def convert_claude_to_openai(
for block in claude_request.system:
if hasattr(block, "type") and block.type == Constants.CONTENT_TEXT:
text_parts.append(block.text)
elif (
isinstance(block, dict)
and block.get("type") == Constants.CONTENT_TEXT
):
elif isinstance(block, dict) and block.get("type") == Constants.CONTENT_TEXT:
text_parts.append(block.get("text", ""))
system_text = "\n\n".join(text_parts)

if system_text.strip():
openai_messages.append(
{"role": Constants.ROLE_SYSTEM, "content": system_text.strip()}
)
openai_messages.append({"role": Constants.ROLE_SYSTEM, "content": system_text.strip()})

# Process Claude messages
i = 0
Expand Down Expand Up @@ -77,13 +73,21 @@ def convert_claude_to_openai(
openai_request = {
"model": openai_model,
"messages": openai_messages,
"max_tokens": min(
max(claude_request.max_tokens, config.min_tokens_limit),
config.max_tokens_limit,
),
"temperature": claude_request.temperature,
"stream": claude_request.stream,
}

# Handle max tokens based on model type
max_tokens_value = min(
max(claude_request.max_tokens, config.min_tokens_limit),
config.max_tokens_limit,
)
if model_manager.is_o3_model(openai_model):
openai_request["max_completion_tokens"] = max_tokens_value
openai_request["temperature"] = 1
else:
openai_request["max_tokens"] = max_tokens_value
openai_request["temperature"] = claude_request.temperature

logger.debug(
f"Converted Claude request to OpenAI format: {json.dumps(openai_request, indent=2, ensure_ascii=False)}"
)
Expand Down Expand Up @@ -133,7 +137,7 @@ def convert_claude_user_message(msg: ClaudeMessage) -> Dict[str, Any]:
"""Convert Claude user message to OpenAI format."""
if msg.content is None:
return {"role": Constants.ROLE_USER, "content": ""}

if isinstance(msg.content, str):
return {"role": Constants.ROLE_USER, "content": msg.content}

Expand Down Expand Up @@ -172,7 +176,7 @@ def convert_claude_assistant_message(msg: ClaudeMessage) -> Dict[str, Any]:

if msg.content is None:
return {"role": Constants.ROLE_ASSISTANT, "content": None}

if isinstance(msg.content, str):
return {"role": Constants.ROLE_ASSISTANT, "content": msg.content}

Expand Down
78 changes: 42 additions & 36 deletions src/conversion/response_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import uuid

from fastapi import HTTPException, Request

from src.core.constants import Constants
from src.models.claude import ClaudeMessagesRequest

Expand Down Expand Up @@ -69,9 +71,7 @@ def convert_openai_to_claude_response(
"stop_sequence": None,
"usage": {
"input_tokens": openai_response.get("usage", {}).get("prompt_tokens", 0),
"output_tokens": openai_response.get("usage", {}).get(
"completion_tokens", 0
),
"output_tokens": openai_response.get("usage", {}).get("completion_tokens", 0),
},
}

Expand Down Expand Up @@ -112,9 +112,7 @@ async def convert_openai_streaming_to_claude(
if not choices:
continue
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse chunk: {chunk_data}, error: {e}"
)
logger.warning(f"Failed to parse chunk: {chunk_data}, error: {e}")
continue

choice = choices[0]
Expand All @@ -129,7 +127,7 @@ async def convert_openai_streaming_to_claude(
if "tool_calls" in delta:
for tc_delta in delta["tool_calls"]:
tc_index = tc_delta.get("index", 0)

# Initialize tool call tracking by index if not exists
if tc_index not in current_tool_calls:
current_tool_calls[tc_index] = {
Expand All @@ -138,33 +136,37 @@ async def convert_openai_streaming_to_claude(
"args_buffer": "",
"json_sent": False,
"claude_index": None,
"started": False
"started": False,
}

tool_call = current_tool_calls[tc_index]

# Update tool call ID if provided
if tc_delta.get("id"):
tool_call["id"] = tc_delta["id"]

# Update function name and start content block if we have both id and name
function_data = tc_delta.get(Constants.TOOL_FUNCTION, {})
if function_data.get("name"):
tool_call["name"] = function_data["name"]

# Start content block when we have complete initial data
if (tool_call["id"] and tool_call["name"] and not tool_call["started"]):
if tool_call["id"] and tool_call["name"] and not tool_call["started"]:
tool_block_counter += 1
claude_index = text_block_index + tool_block_counter
tool_call["claude_index"] = claude_index
tool_call["started"] = True

yield f"event: {Constants.EVENT_CONTENT_BLOCK_START}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_START, 'index': claude_index, 'content_block': {'type': Constants.CONTENT_TOOL_USE, 'id': tool_call['id'], 'name': tool_call['name'], 'input': {}}}, ensure_ascii=False)}\n\n"

# Handle function arguments
if "arguments" in function_data and tool_call["started"] and function_data["arguments"] is not None:
if (
"arguments" in function_data
and tool_call["started"]
and function_data["arguments"] is not None
):
tool_call["args_buffer"] += function_data["arguments"]

# Try to parse complete JSON and send delta when we have valid JSON
try:
json.loads(tool_call["args_buffer"])
Expand Down Expand Up @@ -259,21 +261,21 @@ async def convert_openai_streaming_to_claude_with_cancellation(
usage = chunk.get("usage", None)
if usage:
cache_read_input_tokens = 0
prompt_tokens_details = usage.get('prompt_tokens_details', {})
prompt_tokens_details = usage.get("prompt_tokens_details", {})
if prompt_tokens_details:
cache_read_input_tokens = prompt_tokens_details.get('cached_tokens', 0)
cache_read_input_tokens = prompt_tokens_details.get(
"cached_tokens", 0
)
usage_data = {
'input_tokens': usage.get('prompt_tokens', 0),
'output_tokens': usage.get('completion_tokens', 0),
'cache_read_input_tokens': cache_read_input_tokens
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_read_input_tokens": cache_read_input_tokens,
}
choices = chunk.get("choices", [])
if not choices:
continue
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse chunk: {chunk_data}, error: {e}"
)
logger.warning(f"Failed to parse chunk: {chunk_data}, error: {e}")
continue

choice = choices[0]
Expand All @@ -288,7 +290,7 @@ async def convert_openai_streaming_to_claude_with_cancellation(
if "tool_calls" in delta and delta["tool_calls"]:
for tc_delta in delta["tool_calls"]:
tc_index = tc_delta.get("index", 0)

# Initialize tool call tracking by index if not exists
if tc_index not in current_tool_calls:
current_tool_calls[tc_index] = {
Expand All @@ -297,33 +299,37 @@ async def convert_openai_streaming_to_claude_with_cancellation(
"args_buffer": "",
"json_sent": False,
"claude_index": None,
"started": False
"started": False,
}

tool_call = current_tool_calls[tc_index]

# Update tool call ID if provided
if tc_delta.get("id"):
tool_call["id"] = tc_delta["id"]

# Update function name and start content block if we have both id and name
function_data = tc_delta.get(Constants.TOOL_FUNCTION, {})
if function_data.get("name"):
tool_call["name"] = function_data["name"]

# Start content block when we have complete initial data
if (tool_call["id"] and tool_call["name"] and not tool_call["started"]):
if tool_call["id"] and tool_call["name"] and not tool_call["started"]:
tool_block_counter += 1
claude_index = text_block_index + tool_block_counter
tool_call["claude_index"] = claude_index
tool_call["started"] = True

yield f"event: {Constants.EVENT_CONTENT_BLOCK_START}\ndata: {json.dumps({'type': Constants.EVENT_CONTENT_BLOCK_START, 'index': claude_index, 'content_block': {'type': Constants.CONTENT_TOOL_USE, 'id': tool_call['id'], 'name': tool_call['name'], 'input': {}}}, ensure_ascii=False)}\n\n"

# Handle function arguments
if "arguments" in function_data and tool_call["started"] and function_data["arguments"] is not None:
if (
"arguments" in function_data
and tool_call["started"]
and function_data["arguments"] is not None
):
tool_call["args_buffer"] += function_data["arguments"]

# Try to parse complete JSON and send delta when we have valid JSON
try:
json.loads(tool_call["args_buffer"])
Expand Down
Loading