"""
OpenAI API compatible models for Morphik.

Pydantic schemas mirroring the OpenAI Chat Completions wire format, plus
Morphik-specific extension fields (``chat_id``, ``folder_name``, ``use_rag``,
``top_k``) on the request model.

NOTE(review): reconstructed from a git patch. The same patch also registered
the new router in ``core/api.py``::

    from core.routes.openai_compat import router as openai_router
    app.include_router(openai_router)
"""

from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field


class OpenAIMessage(BaseModel):
    """OpenAI chat completion message format."""

    role: Literal["system", "user", "assistant", "function", "tool"]
    # Plain text, or an OpenAI multimodal content-part list
    # (e.g. [{"type": "text", ...}, {"type": "image_url", ...}]).
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    name: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None


class OpenAIFunctionCall(BaseModel):
    """OpenAI function call format."""

    name: str
    # Arguments are a JSON-encoded string, per the OpenAI wire format.
    arguments: str


class OpenAIFunction(BaseModel):
    """OpenAI function definition format."""

    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class OpenAITool(BaseModel):
    """OpenAI tool definition format."""

    type: Literal["function"]
    function: OpenAIFunction


class OpenAIResponseFormat(BaseModel):
    """OpenAI response format specification."""

    type: Literal["text", "json_object", "json_schema"] = "text"
    # For type == "json_schema" this carries the OpenAI ``json_schema`` object,
    # which nests the actual JSON schema under its "schema" key.
    json_schema: Optional[Dict[str, Any]] = None


class OpenAIStreamOptions(BaseModel):
    """OpenAI streaming options."""

    include_usage: Optional[bool] = False


class OpenAIChatCompletionRequest(BaseModel):
    """OpenAI chat completion request format."""

    model: str
    messages: List[OpenAIMessage]
    frequency_penalty: Optional[float] = Field(default=0, ge=-2.0, le=2.0)
    logit_bias: Optional[Dict[str, int]] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = Field(default=None, ge=0, le=20)
    max_tokens: Optional[int] = Field(default=None, ge=1)
    max_completion_tokens: Optional[int] = Field(default=None, ge=1)
    n: Optional[int] = Field(default=1, ge=1, le=128)
    presence_penalty: Optional[float] = Field(default=0, ge=-2.0, le=2.0)
    response_format: Optional[OpenAIResponseFormat] = None
    seed: Optional[int] = None
    service_tier: Optional[Literal["auto", "default"]] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    stream_options: Optional[OpenAIStreamOptions] = None
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1.0)
    tools: Optional[List[OpenAITool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    parallel_tool_calls: Optional[bool] = None
    user: Optional[str] = None

    # Morphik-specific extensions (ignored by the OpenAI SDK, honored here).
    chat_id: Optional[str] = None
    folder_name: Optional[str] = None
    use_rag: Optional[bool] = True
    top_k: Optional[int] = 5


class OpenAIUsage(BaseModel):
    """OpenAI usage statistics."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None


class OpenAIChoice(BaseModel):
    """OpenAI completion choice."""

    index: int
    message: OpenAIMessage
    logprobs: Optional[Dict[str, Any]] = None
    # FIX: default to None — in Pydantic v2 an Optional field without a
    # default is still *required*, which made choice construction fail
    # whenever finish_reason was omitted.
    finish_reason: Optional[Literal["stop", "length", "function_call", "tool_calls", "content_filter"]] = None


class OpenAIChatCompletionResponse(BaseModel):
    """OpenAI chat completion response format."""

    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    system_fingerprint: Optional[str] = None
    choices: List[OpenAIChoice]
    usage: Optional[OpenAIUsage] = None
    service_tier: Optional[str] = None


class OpenAIStreamChoice(BaseModel):
    """OpenAI streaming completion choice."""

    index: int
    delta: OpenAIMessage
    logprobs: Optional[Dict[str, Any]] = None
    # FIX: same Optional-without-default issue as OpenAIChoice.finish_reason;
    # streamed chunks legitimately carry a null finish_reason until the end.
    finish_reason: Optional[Literal["stop", "length", "function_call", "tool_calls", "content_filter"]] = None


class OpenAIChatCompletionChunk(BaseModel):
    """OpenAI chat completion streaming chunk."""

    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    system_fingerprint: Optional[str] = None
    choices: List[OpenAIStreamChoice]
    usage: Optional[OpenAIUsage] = None
    service_tier: Optional[str] = None


class OpenAIModel(BaseModel):
    """OpenAI model information."""

    id: str
    object: Literal["model"] = "model"
    created: int
    owned_by: str


class OpenAIModelList(BaseModel):
    """OpenAI model list response."""

    object: Literal["list"] = "list"
    data: List[OpenAIModel]


class OpenAIError(BaseModel):
    """OpenAI error payload."""

    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class OpenAIErrorResponse(BaseModel):
    """OpenAI error response wrapper: ``{"error": {...}}`` at the top level."""

    error: OpenAIError
+""" + +import json +import logging +import time +import uuid +from typing import AsyncGenerator, Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse +from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR + +from core.auth_utils import AuthContext, verify_token +from core.completion.litellm_completion import LiteLLMCompletionModel +from core.config import get_settings +from core.dependencies import get_document_service +from core.limits_utils import check_and_increment_limits +from core.models.chat import ChatMessage +from core.models.completion import CompletionRequest +from core.models.openai_compat import ( + OpenAIChatCompletionRequest, + OpenAIChatCompletionResponse, + OpenAIChatCompletionChunk, + OpenAIChoice, + OpenAIStreamChoice, + OpenAIMessage, + OpenAIUsage, + OpenAIModel, + OpenAIModelList, + OpenAIError, + OpenAIErrorResponse, +) +from core.services.document_service import DocumentService + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/v1", tags=["OpenAI Compatibility"]) + + +def create_error_response(message: str, error_type: str = "invalid_request_error", param: Optional[str] = None, code: Optional[str] = None) -> OpenAIErrorResponse: + """Create an OpenAI-compatible error response.""" + return OpenAIErrorResponse( + error=OpenAIError( + message=message, + type=error_type, + param=param, + code=code + ) + ) + + +def convert_morphik_to_openai_messages(messages: List[OpenAIMessage]) -> List[ChatMessage]: + """Convert OpenAI messages to Morphik ChatMessage format.""" + chat_messages = [] + for msg in messages: + # Convert content to string if it's a list (multimodal content) + content = msg.content + if isinstance(content, list): + # Extract text content from multimodal messages + text_parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": 
+ text_parts.append(part.get("text", "")) + content = "\n".join(text_parts) if text_parts else "" + + chat_messages.append(ChatMessage( + role=msg.role, + content=content or "", + timestamp=time.time() + )) + + return chat_messages + + +def extract_user_query_from_messages(messages: List[OpenAIMessage]) -> str: + """Extract the user query from OpenAI messages.""" + # Get the last user message as the query + for msg in reversed(messages): + if msg.role == "user": + content = msg.content + if isinstance(content, list): + # Extract text from multimodal content + text_parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text_parts.append(part.get("text", "")) + return "\n".join(text_parts) if text_parts else "" + return content or "" + return "" + + +@router.get("/models") +async def list_models(auth_context: AuthContext = Depends(verify_token)) -> OpenAIModelList: + """List available models in OpenAI format.""" + try: + settings = get_settings() + + # Apply rate limiting (consistent with query endpoints) + if settings.MODE == "cloud" and auth_context.user_id: + await check_and_increment_limits(auth_context, "query", 1) + models = [] + + # Add registered models + for model_key, model_config in settings.REGISTERED_MODELS.items(): + models.append(OpenAIModel( + id=model_key, + created=int(time.time()), + owned_by="morphik" + )) + + # Add default models if no registered models + if not models: + default_models = [ + settings.COMPLETION_MODEL, + settings.AGENT_MODEL, + ] + for model_name in default_models: + if model_name: + models.append(OpenAIModel( + id=model_name, + created=int(time.time()), + owned_by="morphik" + )) + + return OpenAIModelList(data=models) + + except Exception as e: + logger.error(f"Error listing models: {e}") + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail=create_error_response( + message=f"Failed to list models: {str(e)}", + error_type="internal_server_error" + ).dict() + ) + + 
+@router.post("/chat/completions") +async def create_chat_completion( + request: OpenAIChatCompletionRequest, + auth_context: AuthContext = Depends(verify_token), + document_service: DocumentService = Depends(get_document_service), +) -> OpenAIChatCompletionResponse: + """Create a chat completion in OpenAI format.""" + try: + settings = get_settings() + + # Apply rate limiting (consistent with query endpoints) + if settings.MODE == "cloud" and auth_context.user_id: + await check_and_increment_limits(auth_context, "query", 1) + + # Validate model + if request.model not in settings.REGISTERED_MODELS: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=create_error_response( + message=f"Model '{request.model}' not found", + error_type="invalid_request_error", + param="model" + ).dict() + ) + + # Extract user query + user_query = extract_user_query_from_messages(request.messages) + if not user_query: + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail=create_error_response( + message="No user message found in request", + error_type="invalid_request_error", + param="messages" + ).dict() + ) + + # Retrieve relevant context if RAG is enabled + context_chunks = [] + if request.use_rag: + try: + # Use document service to retrieve relevant chunks + retrieve_result = await document_service.retrieve_chunks( + query=user_query, + app_id=auth_context.app_id, + entity_type=auth_context.entity_type, + entity_id=auth_context.entity_id, + top_k=request.top_k or 5, + folder_name=request.folder_name, + ) + context_chunks = [chunk.content for chunk in retrieve_result.chunks] + logger.info(f"Retrieved {len(context_chunks)} context chunks for OpenAI completion") + except Exception as e: + logger.warning(f"Failed to retrieve context chunks: {e}. 
Proceeding without RAG.") + + # Convert messages to chat history (excluding system messages and last user message) + chat_history = [] + for msg in request.messages[:-1]: # Exclude the last message (current query) + if msg.role != "system": # System messages are handled separately + chat_history.append(ChatMessage( + role=msg.role, + content=msg.content if isinstance(msg.content, str) else str(msg.content), + timestamp=time.time() + )) + + # Handle structured output + schema = None + if request.response_format and request.response_format.type in ["json_object", "json_schema"]: + if request.response_format.type == "json_schema" and request.response_format.json_schema: + schema = request.response_format.json_schema + else: + # For json_object, we'll let the model handle it naturally + pass + + # Create completion request + completion_request = CompletionRequest( + query=user_query, + context_chunks=context_chunks, + max_tokens=request.max_tokens or request.max_completion_tokens, + temperature=request.temperature, + chat_history=chat_history, + stream_response=request.stream, + schema=schema, + end_user_id=request.user, + folder_name=request.folder_name, + ) + + # Initialize completion model + completion_model = LiteLLMCompletionModel(request.model) + + # Handle streaming + if request.stream: + return StreamingResponse( + stream_chat_completion(completion_request, completion_model, request), + media_type="text/plain", + headers={"X-Accel-Buffering": "no"} # Disable nginx buffering + ) + + # Non-streaming completion + response = await completion_model.complete(completion_request) + + # Convert to OpenAI format + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + created_time = int(time.time()) + + choice = OpenAIChoice( + index=0, + message=OpenAIMessage( + role="assistant", + content=str(response.completion) + ), + finish_reason=response.finish_reason or "stop" + ) + + usage = OpenAIUsage( + prompt_tokens=response.usage.get("prompt_tokens", 0), + 
completion_tokens=response.usage.get("completion_tokens", 0), + total_tokens=response.usage.get("total_tokens", 0) + ) + + return OpenAIChatCompletionResponse( + id=completion_id, + created=created_time, + model=request.model, + choices=[choice], + usage=usage + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error in chat completion: {e}") + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail=create_error_response( + message=f"Internal server error: {str(e)}", + error_type="internal_server_error" + ).dict() + ) + + +async def stream_chat_completion( + completion_request: CompletionRequest, + completion_model: LiteLLMCompletionModel, + openai_request: OpenAIChatCompletionRequest +) -> AsyncGenerator[str, None]: + """Stream chat completion in OpenAI format.""" + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + created_time = int(time.time()) + + try: + # Get the streaming generator + stream = await completion_model.complete(completion_request) + + # Stream the chunks + async for chunk_content in stream: + chunk = OpenAIChatCompletionChunk( + id=completion_id, + created=created_time, + model=openai_request.model, + choices=[ + OpenAIStreamChoice( + index=0, + delta=OpenAIMessage( + role="assistant", + content=chunk_content + ), + finish_reason=None + ) + ] + ) + + yield f"data: {chunk.model_dump_json()}\n\n" + + # Send final chunk with finish_reason + final_chunk = OpenAIChatCompletionChunk( + id=completion_id, + created=created_time, + model=openai_request.model, + choices=[ + OpenAIStreamChoice( + index=0, + delta=OpenAIMessage(role="assistant"), + finish_reason="stop" + ) + ] + ) + + yield f"data: {final_chunk.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + logger.error(f"Error in streaming completion: {e}") + error_chunk = { + "error": { + "message": f"Stream error: {str(e)}", + "type": "internal_server_error" + } + } + yield f"data: {json.dumps(error_chunk)}\n\n" + + +# Chat 
session endpoints for persistent chat +@router.get("/chat/sessions/{chat_id}") +async def get_chat_session( + chat_id: str, + auth_context: AuthContext = Depends(verify_token) +) -> Dict: + """Get chat session history (Morphik extension).""" + try: + # This would integrate with Morphik's existing chat functionality + # For now, return a placeholder + return { + "id": chat_id, + "messages": [], + "created": int(time.time()), + "model": "morphik" + } + except Exception as e: + logger.error(f"Error getting chat session: {e}") + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail=create_error_response( + message=f"Failed to get chat session: {str(e)}", + error_type="internal_server_error" + ).dict() + ) + + +@router.delete("/chat/sessions/{chat_id}") +async def delete_chat_session( + chat_id: str, + auth_context: AuthContext = Depends(verify_token) +) -> Dict: + """Delete chat session (Morphik extension).""" + try: + # This would integrate with Morphik's existing chat functionality + return {"deleted": True} + except Exception as e: + logger.error(f"Error deleting chat session: {e}") + raise HTTPException( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, + detail=create_error_response( + message=f"Failed to delete chat session: {str(e)}", + error_type="internal_server_error" + ).dict() + ) \ No newline at end of file diff --git a/core/tests/test_openai_compat.py b/core/tests/test_openai_compat.py new file mode 100644 index 00000000..623fd90b --- /dev/null +++ b/core/tests/test_openai_compat.py @@ -0,0 +1,497 @@ +""" +Tests for OpenAI SDK compatibility functionality. 
+""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException + +from core.api import app +from core.models.openai_compat import ( + OpenAIChatCompletionRequest, + OpenAIMessage, + OpenAIModelList, +) + + +@pytest.fixture +def client(): + """Test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture +def mock_auth_context(): + """Mock authentication context.""" + return MagicMock( + entity_type="user", + entity_id="test_user", + app_id="test_app", + permissions=["read", "write"] + ) + + +@pytest.fixture +def mock_document_service(): + """Mock document service.""" + service = MagicMock() + service.retrieve_chunks = AsyncMock(return_value=MagicMock(chunks=[])) + return service + + +@pytest.fixture +def mock_completion_response(): + """Mock completion response.""" + return MagicMock( + completion="This is a test response", + usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + finish_reason="stop" + ) + + +class TestOpenAICompatibility: + """Test suite for OpenAI SDK compatibility.""" + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + def test_list_models(self, mock_get_settings, mock_verify_token, client, mock_auth_context): + """Test listing models in OpenAI format.""" + mock_verify_token.return_value = mock_auth_context + + # Mock settings with registered models + mock_settings = MagicMock() + mock_settings.REGISTERED_MODELS = { + "gpt-4": {"model_name": "gpt-4", "api_base": "https://api.openai.com"}, + "claude-3": {"model_name": "claude-3", "api_base": "https://api.anthropic.com"} + } + mock_get_settings.return_value = mock_settings + + response = client.get("/v1/models") + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) == 2 + assert any(model["id"] == "gpt-4" for model in data["data"]) + assert 
any(model["id"] == "claude-3" for model in data["data"]) + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.get_document_service") + @patch("core.routes.openai_compat.LiteLLMCompletionModel") + def test_chat_completion_basic( + self, + mock_completion_model_class, + mock_get_document_service, + mock_get_settings, + mock_verify_token, + client, + mock_auth_context, + mock_document_service, + mock_completion_response + ): + """Test basic chat completion.""" + mock_verify_token.return_value = mock_auth_context + mock_get_document_service.return_value = mock_document_service + + # Mock settings + mock_settings = MagicMock() + mock_settings.REGISTERED_MODELS = { + "gpt-4": {"model_name": "gpt-4", "api_base": "https://api.openai.com"} + } + mock_get_settings.return_value = mock_settings + + # Mock completion model + mock_completion_model = MagicMock() + mock_completion_model.complete = AsyncMock(return_value=mock_completion_response) + mock_completion_model_class.return_value = mock_completion_model + + request_data = { + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Hello, how are you?"} + ], + "max_tokens": 100, + "temperature": 0.7 + } + + response = client.post("/v1/chat/completions", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "chat.completion" + assert data["model"] == "gpt-4" + assert len(data["choices"]) == 1 + assert data["choices"][0]["message"]["content"] == "This is a test response" + assert data["usage"]["total_tokens"] == 15 + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + def test_chat_completion_invalid_model( + self, + mock_get_settings, + mock_verify_token, + client, + mock_auth_context + ): + """Test chat completion with invalid model.""" + mock_verify_token.return_value = mock_auth_context + + # Mock settings with no 
registered models + mock_settings = MagicMock() + mock_settings.REGISTERED_MODELS = {} + mock_get_settings.return_value = mock_settings + + request_data = { + "model": "invalid-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + } + + response = client.post("/v1/chat/completions", json=request_data) + + assert response.status_code == 404 + data = response.json() + assert "error" in data + assert "not found" in data["error"]["message"].lower() + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.get_document_service") + @patch("core.routes.openai_compat.LiteLLMCompletionModel") + def test_chat_completion_with_rag( + self, + mock_completion_model_class, + mock_get_document_service, + mock_get_settings, + mock_verify_token, + client, + mock_auth_context, + mock_completion_response + ): + """Test chat completion with RAG enabled.""" + mock_verify_token.return_value = mock_auth_context + + # Mock document service with context chunks + mock_document_service = MagicMock() + mock_chunks = [ + MagicMock(content="Context chunk 1"), + MagicMock(content="Context chunk 2") + ] + mock_document_service.retrieve_chunks = AsyncMock( + return_value=MagicMock(chunks=mock_chunks) + ) + mock_get_document_service.return_value = mock_document_service + + # Mock settings + mock_settings = MagicMock() + mock_settings.REGISTERED_MODELS = { + "gpt-4": {"model_name": "gpt-4", "api_base": "https://api.openai.com"} + } + mock_get_settings.return_value = mock_settings + + # Mock completion model + mock_completion_model = MagicMock() + mock_completion_model.complete = AsyncMock(return_value=mock_completion_response) + mock_completion_model_class.return_value = mock_completion_model + + request_data = { + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "What's in my documents?"} + ], + "use_rag": True, + "top_k": 5 + } + + response = client.post("/v1/chat/completions", 
json=request_data) + + assert response.status_code == 200 + + # Verify that retrieve_chunks was called + mock_document_service.retrieve_chunks.assert_called_once() + + # Verify that the completion model was called with context chunks + mock_completion_model.complete.assert_called_once() + call_args = mock_completion_model.complete.call_args[0][0] + assert len(call_args.context_chunks) == 2 + assert "Context chunk 1" in call_args.context_chunks + assert "Context chunk 2" in call_args.context_chunks + + def test_openai_message_validation(self): + """Test OpenAI message model validation.""" + # Valid message + message = OpenAIMessage(role="user", content="Hello") + assert message.role == "user" + assert message.content == "Hello" + + # Message with multimodal content + multimodal_content = [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}} + ] + message = OpenAIMessage(role="user", content=multimodal_content) + assert message.role == "user" + assert isinstance(message.content, list) + assert len(message.content) == 2 + + def test_openai_chat_completion_request_validation(self): + """Test OpenAI chat completion request validation.""" + request = OpenAIChatCompletionRequest( + model="gpt-4", + messages=[ + OpenAIMessage(role="user", content="Hello") + ], + max_tokens=100, + temperature=0.7, + stream=False + ) + + assert request.model == "gpt-4" + assert len(request.messages) == 1 + assert request.max_tokens == 100 + assert request.temperature == 0.7 + assert request.stream is False + assert request.use_rag is True # Default value + assert request.top_k == 5 # Default value + + +@pytest.mark.asyncio +class TestOpenAICompatibilityAsync: + """Async test suite for OpenAI SDK compatibility.""" + + async def test_stream_chat_completion(self): + """Test streaming chat completion format.""" + from core.routes.openai_compat import stream_chat_completion + from core.models.completion import 
CompletionRequest + + # Mock completion model with streaming + mock_completion_model = MagicMock() + async def mock_stream(): + yield "Hello" + yield " world" + yield "!" + + mock_completion_model.complete = AsyncMock(return_value=mock_stream()) + + # Mock OpenAI request + mock_openai_request = MagicMock() + mock_openai_request.model = "gpt-4" + + # Mock completion request + completion_request = CompletionRequest( + query="Hello", + context_chunks=[], + stream_response=True + ) + + # Collect streaming response + chunks = [] + async for chunk in stream_chat_completion( + completion_request, + mock_completion_model, + mock_openai_request + ): + chunks.append(chunk) + + # Verify streaming format + assert len(chunks) >= 4 # Content chunks + final chunk + DONE + assert any("Hello" in chunk for chunk in chunks) + assert any("[DONE]" in chunk for chunk in chunks) + assert all(chunk.startswith("data: ") for chunk in chunks) + + +class TestOpenAIRateLimiting: + """Test suite for OpenAI API rate limiting.""" + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.check_and_increment_limits") + def test_models_endpoint_rate_limiting_cloud_mode( + self, + mock_check_limits, + mock_get_settings, + mock_verify_token, + client + ): + """Test that /v1/models endpoint applies rate limiting in cloud mode.""" + # Setup auth context with user_id + mock_auth_context = MagicMock() + mock_auth_context.user_id = "test_user" + mock_verify_token.return_value = mock_auth_context + + # Setup cloud mode + mock_settings = MagicMock() + mock_settings.MODE = "cloud" + mock_settings.REGISTERED_MODELS = {"gpt-4": {}} + mock_get_settings.return_value = mock_settings + + # Make request + response = client.get("/v1/models") + + # Verify rate limiting was called + mock_check_limits.assert_called_once_with(mock_auth_context, "query", 1) + assert response.status_code == 200 + + 
@patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.check_and_increment_limits") + def test_models_endpoint_no_rate_limiting_self_hosted( + self, + mock_check_limits, + mock_get_settings, + mock_verify_token, + client + ): + """Test that /v1/models endpoint skips rate limiting in self-hosted mode.""" + # Setup auth context with user_id + mock_auth_context = MagicMock() + mock_auth_context.user_id = "test_user" + mock_verify_token.return_value = mock_auth_context + + # Setup self-hosted mode + mock_settings = MagicMock() + mock_settings.MODE = "self_hosted" + mock_settings.REGISTERED_MODELS = {"gpt-4": {}} + mock_get_settings.return_value = mock_settings + + # Make request + response = client.get("/v1/models") + + # Verify rate limiting was NOT called + mock_check_limits.assert_not_called() + assert response.status_code == 200 + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.check_and_increment_limits") + @patch("core.routes.openai_compat.get_document_service") + @patch("core.routes.openai_compat.LiteLLMCompletionModel") + def test_chat_completions_rate_limiting_cloud_mode( + self, + mock_completion_model_class, + mock_get_document_service, + mock_check_limits, + mock_get_settings, + mock_verify_token, + client + ): + """Test that /v1/chat/completions endpoint applies rate limiting in cloud mode.""" + # Setup auth context with user_id + mock_auth_context = MagicMock() + mock_auth_context.user_id = "test_user" + mock_auth_context.app_id = "test_app" + mock_auth_context.entity_type = "user" + mock_auth_context.entity_id = "test_user" + mock_verify_token.return_value = mock_auth_context + + # Setup cloud mode + mock_settings = MagicMock() + mock_settings.MODE = "cloud" + mock_settings.REGISTERED_MODELS = {"gpt-4": {}} + mock_get_settings.return_value = mock_settings + + # Mock document service + 
mock_document_service = MagicMock() + mock_get_document_service.return_value = mock_document_service + + # Mock completion model + mock_completion_model = MagicMock() + mock_completion_response = MagicMock() + mock_completion_response.completion = "Test response" + mock_completion_response.usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + mock_completion_response.finish_reason = "stop" + mock_completion_model.complete = AsyncMock(return_value=mock_completion_response) + mock_completion_model_class.return_value = mock_completion_model + + request_data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}] + } + + # Make request + response = client.post("/v1/chat/completions", json=request_data) + + # Verify rate limiting was called + mock_check_limits.assert_called_once_with(mock_auth_context, "query", 1) + assert response.status_code == 200 + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.check_and_increment_limits") + def test_models_rate_limit_exceeded( + self, + mock_check_limits, + mock_get_settings, + mock_verify_token, + client + ): + """Test that rate limit exceeded returns 429 error.""" + # Setup auth context with user_id + mock_auth_context = MagicMock() + mock_auth_context.user_id = "test_user" + mock_verify_token.return_value = mock_auth_context + + # Setup cloud mode + mock_settings = MagicMock() + mock_settings.MODE = "cloud" + mock_get_settings.return_value = mock_settings + + # Mock rate limit exceeded + mock_check_limits.side_effect = HTTPException( + status_code=429, + detail="Query limit exceeded for your free tier. Please upgrade to remove limits." 
+ ) + + # Make request + response = client.get("/v1/models") + + # Verify 429 status + assert response.status_code == 429 + assert "Query limit exceeded" in response.json()["detail"] + + @patch("core.routes.openai_compat.verify_token") + @patch("core.routes.openai_compat.get_settings") + @patch("core.routes.openai_compat.check_and_increment_limits") + @patch("core.routes.openai_compat.get_document_service") + @patch("core.routes.openai_compat.LiteLLMCompletionModel") + def test_chat_completions_rate_limit_exceeded( + self, + mock_completion_model_class, + mock_get_document_service, + mock_check_limits, + mock_get_settings, + mock_verify_token, + client + ): + """Test that rate limit exceeded for chat completions returns 429 error.""" + # Setup auth context with user_id + mock_auth_context = MagicMock() + mock_auth_context.user_id = "test_user" + mock_verify_token.return_value = mock_auth_context + + # Setup cloud mode + mock_settings = MagicMock() + mock_settings.MODE = "cloud" + mock_settings.REGISTERED_MODELS = {"gpt-4": {}} + mock_get_settings.return_value = mock_settings + + # Mock rate limit exceeded + mock_check_limits.side_effect = HTTPException( + status_code=429, + detail="Query limit exceeded for your free tier. Please upgrade to remove limits." + ) + + request_data = { + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}] + } + + # Make request + response = client.post("/v1/chat/completions", json=request_data) + + # Verify 429 status + assert response.status_code == 429 + assert "Query limit exceeded" in response.json()["detail"] \ No newline at end of file diff --git a/examples/openai_sdk_compatibility.py b/examples/openai_sdk_compatibility.py new file mode 100644 index 00000000..c54dcdc8 --- /dev/null +++ b/examples/openai_sdk_compatibility.py @@ -0,0 +1,121 @@ +""" +Example demonstrating OpenAI SDK compatibility with Morphik. 
+ +This example shows how to use the OpenAI SDK with Morphik as the backend, +providing seamless migration from OpenAI to Morphik while retaining +RAG capabilities and LiteLLM model support. +""" + +import asyncio +import os +from openai import AsyncOpenAI + +# Example usage with OpenAI SDK pointing to Morphik +async def main(): + # Initialize OpenAI client with Morphik base URL + client = AsyncOpenAI( + api_key=os.getenv("JWT_TOKEN", "your-morphik-jwt-token"), + base_url=os.getenv("MORPHIK_BASE_URL", "http://localhost:8000/v1") + ) + + # List available models + print("Available models:") + models = await client.models.list() + for model in models.data: + print(f"- {model.id}") + + # Basic chat completion + print("\n=== Basic Chat Completion ===") + response = await client.chat.completions.create( + model="gpt-4", # Use your configured model from morphik.toml + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is machine learning?"} + ], + max_tokens=150, + temperature=0.7 + ) + + print(f"Response: {response.choices[0].message.content}") + print(f"Usage: {response.usage}") + + # Chat completion with RAG (Morphik extension) + print("\n=== Chat Completion with RAG ===") + response = await client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "user", "content": "Tell me about the documents I've uploaded"} + ], + max_tokens=200, + temperature=0.7, + # Morphik-specific parameters + extra_body={ + "use_rag": True, + "folder_name": "my_documents", + "top_k": 5 + } + ) + + print(f"RAG Response: {response.choices[0].message.content}") + + # Streaming completion + print("\n=== Streaming Completion ===") + stream = await client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "user", "content": "Write a short poem about AI"} + ], + max_tokens=100, + temperature=0.8, + stream=True + ) + + print("Streaming response:") + async for chunk in stream: + if 
chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="", flush=True) + print("\n") + + # Persistent chat session (Morphik extension) + print("\n=== Persistent Chat Session ===") + response = await client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "user", "content": "Remember this: my favorite color is blue"} + ], + extra_body={ + "chat_id": "my_persistent_chat_session" + } + ) + + print(f"First message: {response.choices[0].message.content}") + + # Continue the conversation + response = await client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "user", "content": "What's my favorite color?"} + ], + extra_body={ + "chat_id": "my_persistent_chat_session" + } + ) + + print(f"Follow-up: {response.choices[0].message.content}") + + # Structured output (JSON mode) + print("\n=== Structured Output ===") + response = await client.chat.completions.create( + model="gpt-4", + messages=[ + {"role": "user", "content": "Extract key information about the following person: John Doe, 30 years old, Software Engineer at Tech Corp"} + ], + response_format={"type": "json_object"}, + max_tokens=150 + ) + + print(f"Structured response: {response.choices[0].message.content}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file