diff --git a/docs/docs/snippets/cli_options.md b/docs/docs/snippets/cli_options.md
index 2b837a3b..ad92328c 100644
--- a/docs/docs/snippets/cli_options.md
+++ b/docs/docs/snippets/cli_options.md
@@ -1,4 +1,5 @@
+
 ```console
 Usage: readmeai [OPTIONS]
 
@@ -9,7 +10,7 @@ Options:
   -a, --align [center|left|right]
                                   align for the README.md file header
                                   sections.
-  --api [anthropic|gemini|ollama|openai|offline]
+  --api [anthropic|gemini|ollama|openai|offline|azure]
                                   LLM API service provider to power the README
                                   file generation.
   -bc, --badge-color TEXT         Primary color (hex code or name) to use for
@@ -50,4 +51,5 @@ Options:
                                   generated for the README file.
   --help                          Show this message and exit.
 ```
+
diff --git a/readmeai/models/azure.py b/readmeai/models/azure.py
new file mode 100644
index 00000000..4b4550e0
--- /dev/null
+++ b/readmeai/models/azure.py
@@ -0,0 +1,84 @@
+"""Azure OpenAI model handler for ReadmeAI."""
+
+import os
+from typing import Any
+
+import aiohttp
+import openai
+from openai import AsyncAzureOpenAI
+
+from readmeai.config.settings import ConfigLoader
+from readmeai.extractors.models import RepositoryContext
+from readmeai.models.openai_ import OpenAIHandler
+from readmeai.models.tokens import token_handler
+
+
+class AzureOpenAIHandler(OpenAIHandler):
+    """Handler for Azure OpenAI models."""
+
+    def __init__(self, config_loader: ConfigLoader, context: RepositoryContext) -> None:
+        super().__init__(config_loader, context)
+        self._model_settings()
+
+    def _model_settings(self) -> None:
+        self.host_name = os.getenv("AZURE_ENDPOINT")
+        self.max_tokens = self.config.llm.tokens
+        self.model = os.getenv("AZURE_MODEL")
+        self.top_p = self.config.llm.top_p
+
+        self.client = AsyncAzureOpenAI(
+            azure_endpoint=self.host_name,
+            api_key=os.getenv("AZURE_API_KEY"),
+            api_version=os.getenv("AZURE_API_VERSION"),
+        )
+
+    async def _make_request(
+        self,
+        index: str | None,
+        prompt: str | None,
+        tokens: int | None,
+        repo_files: Any,
+    ):
+        """Process requests to the Azure OpenAI API, with retries and error handling."""
+
+        try:
+            if prompt is None:
+                raise ValueError("Prompt cannot be None")
+
+            prompt = await token_handler(
+                config=self.config,
+                index=index,
+                prompt=prompt,
+                tokens=tokens,
+            )
+            if not prompt:
+                raise ValueError("Token handler returned empty prompt")
+
+            if index == "file_summary":
+                self.max_tokens = 100
+
+            parameters = await self._build_payload(prompt)
+
+            # Call the chat completions endpoint via the async Azure client.
+            response = await self.client.chat.completions.create(**parameters)
+            content = response.choices[0].message.content
+
+            if not content:
+                raise ValueError("Empty response from API")
+
+            self._logger.info(
+                f"Response from {self.config.llm.api.capitalize()} for '{index}': {content}",
+            )
+            return index, content
+
+        except (
+            aiohttp.ClientError,
+            aiohttp.ClientResponseError,
+            aiohttp.ClientConnectorError,
+            openai.OpenAIError,
+        ) as e:
+            self._logger.error(f"Error processing request for '{index}': {e!r}")
+            raise  # Re-raise so the retry decorator can handle it.
+
+        except Exception as e:
+            self._logger.error(f"Unexpected error for '{index}': {e!r}")
+            return index, self.placeholder
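
The handler above resolves all of its connection settings from environment variables at construction time, with no defaults. As a quick orientation (not part of the diff), a minimal sketch of the variables a user would export before selecting `--api azure`; every value below is a placeholder, not a real endpoint or credential:

```python
import os

# Placeholder values -- these are the four variables that
# AzureOpenAIHandler._model_settings() reads via os.getenv().
os.environ["AZURE_ENDPOINT"] = "https://my-resource.openai.azure.com"  # assumed endpoint shape
os.environ["AZURE_API_KEY"] = "<your-azure-openai-key>"
os.environ["AZURE_API_VERSION"] = "2023-05-15"  # version also used by the test fixtures below
os.environ["AZURE_MODEL"] = "<your-deployment-name>"
```

Note that `AZURE_MODEL` feeds `self.model`, which for Azure OpenAI is typically a deployment name rather than a raw model identifier.
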
"openai" OFFLINE = "offline" + AZURE = "azure" class AnthropicModels(str, Enum): diff --git a/readmeai/models/factory.py b/readmeai/models/factory.py index c1b70cfc..fec77f3e 100644 --- a/readmeai/models/factory.py +++ b/readmeai/models/factory.py @@ -4,11 +4,12 @@ from readmeai.core.errors import UnsupportedServiceError from readmeai.extractors.models import RepositoryContext from readmeai.models.anthropic import AnthropicHandler +from readmeai.models.azure import AzureOpenAIHandler from readmeai.models.base import BaseModelHandler from readmeai.models.enums import LLMProviders from readmeai.models.gemini import GeminiHandler from readmeai.models.offline import OfflineHandler -from readmeai.models.openai import OpenAIHandler +from readmeai.models.openai_ import OpenAIHandler class ModelFactory: @@ -22,18 +23,15 @@ class ModelFactory: LLMProviders.OLLAMA.value: OpenAIHandler, LLMProviders.OPENAI.value: OpenAIHandler, LLMProviders.OFFLINE.value: OfflineHandler, + LLMProviders.AZURE.value: AzureOpenAIHandler, } @staticmethod - def get_backend( - config: ConfigLoader, context: RepositoryContext - ) -> BaseModelHandler: + def get_backend(config: ConfigLoader, context: RepositoryContext) -> BaseModelHandler: """Retrieves configured LLM API handler instance.""" llm_service = ModelFactory._model_map.get(config.config.llm.api) if llm_service is None: - raise UnsupportedServiceError( - f"Unsupported LLM provider: {config.config.llm.api}" - ) + raise UnsupportedServiceError(f"Unsupported LLM provider: {config.config.llm.api}") return llm_service(config, context) diff --git a/readmeai/models/openai.py b/readmeai/models/openai_.py similarity index 100% rename from readmeai/models/openai.py rename to readmeai/models/openai_.py diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py index 8896d270..42707411 100644 --- a/tests/cli/test_main.py +++ b/tests/cli/test_main.py @@ -1,8 +1,8 @@ -from unittest.mock import MagicMock, patch - import pytest from _pytest._py.path import LocalPath from click.testing import CliRunner +from unittest.mock import MagicMock, patch + from readmeai.cli.main import main from readmeai.config.settings import ConfigLoader @@ -24,13 +24,13 @@ def mock_readme_agent(): def test_main_command_basic( - cli_runner: CliRunner, - mock_config_loader: ConfigLoader, - output_file_path: str, + cli_runner: CliRunner, + mock_config_loader: ConfigLoader, + output_file_path: str, ): with patch( - "readmeai.config.settings.ConfigLoader", - return_value=mock_config_loader, + "readmeai.config.settings.ConfigLoader", + return_value=mock_config_loader, ): result = cli_runner.invoke( main, @@ -46,9 +46,9 @@ def test_main_command_basic( def test_main_command_all_options( - cli_runner: CliRunner, - mock_config_loader: ConfigLoader, - output_file_path: str, + cli_runner: CliRunner, + mock_config_loader: ConfigLoader, + output_file_path: str, ): mock_config = mock_config_loader mock_config.config.git.repository = "https://github.com/eli64s/readme-ai-streamlit" @@ -108,9 +108,10 @@ def test_version_option(cli_runner: CliRunner): [ ("--align", "invalid", "Invalid value for '-a' / '--align'"), ( - "--api", - "invalid", - "Invalid value for '--api': 'invalid' is not one of 'anthropic', 'gemini', 'ollama', 'openai', 'offline'.", + "--api", + "invalid", + "Invalid value for '--api': 'invalid' is not one of 'anthropic', 'gemini', 'ollama', 'openai', 'offline', " + "'azure'.", ), # ( # "--badge-color", @@ -118,51 +119,51 @@ def test_version_option(cli_runner: CliRunner): # "Invalid value for '-bc' / 
'--badge-color'", # ), ( - "--badge-style", - "invalid", - "Invalid value for '-bs' / '--badge-style'", + "--badge-style", + "invalid", + "Invalid value for '-bs' / '--badge-style'", ), ( - "--context-window", - "invalid", - "Invalid value for '-cw' / '--context-window'", + "--context-window", + "invalid", + "Invalid value for '-cw' / '--context-window'", ), ( - "--header-style", - "invalid", - "Invalid value for '-hs' / '--header-style'", + "--header-style", + "invalid", + "Invalid value for '-hs' / '--header-style'", ), ("--logo", "invalid", "Invalid value for '-l' / '--logo'"), # ("--model", "invalid", "Invalid value for '-m' / '--model'"), ( - "--rate-limit", - "invalid", - "Invalid value for '-rl' / '--rate-limit'", + "--rate-limit", + "invalid", + "Invalid value for '-rl' / '--rate-limit'", ), ( - "--temperature", - "invalid", - "Invalid value for '-t' / '--temperature'", + "--temperature", + "invalid", + "Invalid value for '-t' / '--temperature'", ), ( - "--navigation-style", - "invalid", - "Invalid value for '-ns' / '--navigation-style'", + "--navigation-style", + "invalid", + "Invalid value for '-ns' / '--navigation-style'", ), ( - "--top-p", - "invalid", - "Invalid value for '-tp' / '--top-p'", + "--top-p", + "invalid", + "Invalid value for '-tp' / '--top-p'", ), ( - "--tree-max-depth", - "invalid", - "Invalid value for '-td' / '--tree-max-depth'", + "--tree-max-depth", + "invalid", + "Invalid value for '-td' / '--tree-max-depth'", ), ], ) def test_invalid_option_values( - temp_dir: LocalPath, cli_runner: CliRunner, option, value, expected + temp_dir: LocalPath, cli_runner: CliRunner, option, value, expected ): result = cli_runner.invoke(main, ["--repository", str(temp_dir), option, value]) assert result.exit_code != 0 diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 7bfea39e..c48c21ed 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -4,12 +4,14 @@ import aiohttp import pytest + from readmeai.config.settings import ConfigLoader from readmeai.extractors.models import RepositoryContext from readmeai.models.anthropic import ANTHROPIC_AVAILABLE, AnthropicHandler +from readmeai.models.azure import AzureOpenAIHandler from readmeai.models.enums import GeminiModels from readmeai.models.gemini import GeminiHandler -from readmeai.models.openai import OpenAIHandler +from readmeai.models.openai_ import OpenAIHandler @pytest.fixture @@ -46,9 +48,7 @@ def mock_aiohttp_session(): @pytest.fixture -def anthropic_handler( - mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext -): +def anthropic_handler(mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext): if not ANTHROPIC_AVAILABLE: pytest.skip("Anthropic library is not available") context = mock_repository_context @@ -60,20 +60,18 @@ def anthropic_handler( @pytest.fixture def anthropic_handler_with_mock_session( - anthropic_handler: AnthropicHandler, monkeypatch: pytest.MonkeyPatch + anthropic_handler: AnthropicHandler, monkeypatch: pytest.MonkeyPatch ): monkeypatch.setenv("ANTHROPIC_API_KEY", "test_api_key") - mock_create = AsyncMock( - return_value=MagicMock(content=[MagicMock(text="test_response")]) - ) + mock_create = AsyncMock(return_value=MagicMock(content=[MagicMock(text="test_response")])) anthropic_handler.client.messages.create = mock_create return anthropic_handler @pytest.fixture def gemini_handler( - mock_config_loader: ConfigLoader, - mock_repository_context: RepositoryContext, + mock_config_loader: ConfigLoader, + mock_repository_context: 
diff --git a/tests/models/conftest.py b/tests/models/conftest.py
index 7bfea39e..c48c21ed 100644
--- a/tests/models/conftest.py
+++ b/tests/models/conftest.py
@@ -4,12 +4,14 @@
 
 import aiohttp
 import pytest
+
 from readmeai.config.settings import ConfigLoader
 from readmeai.extractors.models import RepositoryContext
 from readmeai.models.anthropic import ANTHROPIC_AVAILABLE, AnthropicHandler
+from readmeai.models.azure import AzureOpenAIHandler
 from readmeai.models.enums import GeminiModels
 from readmeai.models.gemini import GeminiHandler
-from readmeai.models.openai import OpenAIHandler
+from readmeai.models.openai_ import OpenAIHandler
 
 
 @pytest.fixture
@@ -46,9 +48,7 @@ def mock_aiohttp_session():
 
 
 @pytest.fixture
-def anthropic_handler(
-    mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext
-):
+def anthropic_handler(mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext):
     if not ANTHROPIC_AVAILABLE:
         pytest.skip("Anthropic library is not available")
     context = mock_repository_context
@@ -60,20 +60,18 @@ def anthropic_handler(
 
 @pytest.fixture
 def anthropic_handler_with_mock_session(
-    anthropic_handler: AnthropicHandler, monkeypatch: pytest.MonkeyPatch
+        anthropic_handler: AnthropicHandler, monkeypatch: pytest.MonkeyPatch
 ):
     monkeypatch.setenv("ANTHROPIC_API_KEY", "test_api_key")
-    mock_create = AsyncMock(
-        return_value=MagicMock(content=[MagicMock(text="test_response")])
-    )
+    mock_create = AsyncMock(return_value=MagicMock(content=[MagicMock(text="test_response")]))
     anthropic_handler.client.messages.create = mock_create
     return anthropic_handler
 
 
 @pytest.fixture
 def gemini_handler(
-    mock_config_loader: ConfigLoader,
-    mock_repository_context: RepositoryContext,
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
 ):
     """Fixture to provide a GeminiHandler instance."""
     mock_config_loader.config.llm.model = GeminiModels.GEMINI_FLASH.value
@@ -86,9 +84,9 @@ def gemini_handler(
 
 @pytest.fixture
 def openai_handler(
-    mock_config_loader: ConfigLoader,
-    mock_repository_context: RepositoryContext,
-    monkeypatch: pytest.MonkeyPatch,
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
+        monkeypatch: pytest.MonkeyPatch,
 ):
     """Fixture to provide an OpenAIHandler instance with a mocked API key."""
     monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
@@ -100,13 +98,39 @@ def openai_handler(
 
 @pytest.fixture
 def openai_handler_with_mock_session(
-    openai_handler: OpenAIHandler, mock_aiohttp_session: MagicMock
+        openai_handler: OpenAIHandler, mock_aiohttp_session: MagicMock
 ):
     """Fixture to provide an OpenAIHandler with a mocked session."""
     openai_handler._session = mock_aiohttp_session
     return openai_handler
 
 
+@pytest.fixture
+def azure_openai_handler(
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
+):
+    """Fixture to provide an AzureOpenAIHandler instance."""
+    with patch.dict("os.environ", {
+        "AZURE_ENDPOINT": "https://test.azure.com",
+        "AZURE_API_KEY": "test_azure_api_key",
+        "AZURE_API_VERSION": "2023-05-15",
+    }):
+        return AzureOpenAIHandler(
+            config_loader=mock_config_loader,
+            context=mock_repository_context,
+        )
+
+
+@pytest.fixture
+def azure_openai_handler_with_mock_session(
+        azure_openai_handler: AzureOpenAIHandler, mock_aiohttp_session: MagicMock
+):
+    """Fixture to provide an AzureOpenAIHandler with a mocked session."""
+    azure_openai_handler._session = mock_aiohttp_session
+    return azure_openai_handler
+
+
 @pytest.fixture
 def ollama_localhost():
     """Fixture to provide a localhost URL for Ollama."""
diff --git a/tests/models/test_azure.py b/tests/models/test_azure.py
new file mode 100644
index 00000000..c6c79217
--- /dev/null
+++ b/tests/models/test_azure.py
@@ -0,0 +1,10 @@
+from readmeai.models.azure import AzureOpenAIHandler
+
+
+def test_azure_openai_handler_sets_attributes(azure_openai_handler: AzureOpenAIHandler):
+    """Test that the Azure OpenAI handler sets the correct attributes."""
+    assert hasattr(azure_openai_handler, "host_name")
+    assert hasattr(azure_openai_handler, "model")
+    assert hasattr(azure_openai_handler, "max_tokens")
+    assert hasattr(azure_openai_handler, "top_p")
+    assert hasattr(azure_openai_handler, "client")
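
The new test only checks attribute presence. A deeper request-path test could stub the async client the same way the Anthropic fixture does; a sketch (not part of this diff) that assumes the `azure_openai_handler` fixture above and stubs `token_handler` and `_build_payload` so only the Azure call path is exercised:

```python
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


@pytest.mark.asyncio
async def test_azure_make_request_returns_content(azure_openai_handler):
    """Sketch: _make_request should return (index, content) from the stubbed client."""
    response = MagicMock()
    response.choices = [MagicMock(message=MagicMock(content="test_response"))]
    azure_openai_handler.client.chat.completions.create = AsyncMock(return_value=response)
    # Stub collaborators so only the request/response handling is under test;
    # "readmeai.models.azure.token_handler" is the import path used by azure.py.
    azure_openai_handler._build_payload = AsyncMock(return_value={"messages": []})
    with patch("readmeai.models.azure.token_handler", AsyncMock(return_value="prompt")):
        index, content = await azure_openai_handler._make_request(
            index="overview", prompt="Summarize the repo.", tokens=100, repo_files=None,
        )
    assert (index, content) == ("overview", "test_response")
```
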
diff --git a/tests/models/test_factory.py b/tests/models/test_factory.py
index 2ac8437e..a37327af 100644
--- a/tests/models/test_factory.py
+++ b/tests/models/test_factory.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch
 
 import pytest
+
 from readmeai.config.settings import ConfigLoader
 from readmeai.core.errors import UnsupportedServiceError
 from readmeai.extractors.models import RepositoryContext
@@ -9,9 +10,9 @@
 
 
 def test_get_backend_openai(
-    mock_config_loader: ConfigLoader,
-    mock_repository_context: RepositoryContext,
-    monkeypatch: pytest.MonkeyPatch,
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
+        monkeypatch: pytest.MonkeyPatch,
 ):
     """Test getting OpenAI backend with proper environment setup."""
     mock_config_loader.config.llm.api = LLMProviders.OPENAI.value
@@ -22,9 +23,9 @@ def test_get_backend_openai(
 
 
 def test_get_backend_anthropic(
-    mock_config_loader: ConfigLoader,
-    mock_repository_context: RepositoryContext,
-    monkeypatch: pytest.MonkeyPatch,
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
+        monkeypatch: pytest.MonkeyPatch,
 ):
     """Test getting Anthropic backend."""
     mock_config_loader.config.llm.api = LLMProviders.ANTHROPIC.value
@@ -35,8 +36,8 @@ def test_get_backend_anthropic(
 
 
 def test_get_backend_gemini(
-    mock_config_loader: ConfigLoader,
-    mock_repository_context: RepositoryContext,
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
 ):
     """Test getting Gemini backend."""
     mock_config_loader.config.llm.api = LLMProviders.GEMINI.value
@@ -47,8 +48,8 @@ def test_get_backend_gemini(
 
 
 def test_get_backend_offline(
-    mock_config_loader: ConfigLoader,
-    mock_repository_context: RepositoryContext,
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
 ):
     """Test getting Offline backend."""
     mock_config_loader.config.llm.api = LLMProviders.OFFLINE.value
@@ -57,8 +58,23 @@ def test_get_backend_offline(
     assert handler.__class__.__name__ == "OfflineHandler"
 
 
+def test_get_backend_azure(
+        mock_config_loader: ConfigLoader,
+        mock_repository_context: RepositoryContext,
+        monkeypatch: pytest.MonkeyPatch,
+):
+    """Test getting Azure OpenAI backend with proper environment setup."""
+    mock_config_loader.config.llm.api = LLMProviders.AZURE.value
+    monkeypatch.setenv("AZURE_API_KEY", "test_key")
+    monkeypatch.setenv("AZURE_ENDPOINT", "https://test.azure.com")
+    monkeypatch.setenv("AZURE_API_VERSION", "2023-05-15")
+    handler = ModelFactory.get_backend(mock_config_loader, mock_repository_context)
+    assert handler is not None
+    assert handler.__class__.__name__ == "AzureOpenAIHandler"
+
+
 def test_get_backend_unsupported_service(
-    mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext
+        mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext
 ):
     """Test getting a backend with an unsupported service."""
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index bdf6f09d..22edeade 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -3,9 +3,10 @@
 
 import aiohttp
 import pytest
 import tenacity
+
 from readmeai.config.settings import ConfigLoader, Settings
 from readmeai.models.enums import LLMProviders
-from readmeai.models.openai import OpenAIHandler
+from readmeai.models.openai_ import OpenAIHandler
 
 
 def test_openai_handler_sets_attributes(openai_handler: OpenAIHandler):
@@ -16,8 +17,8 @@ def test_openai_handler_sets_attributes(openai_handler: OpenAIHandler):
 
 
 def test_openai_endpoint_configuration_for_openai(
-    mock_config_loader: ConfigLoader,
-    openai_handler: OpenAIHandler,
+        mock_config_loader: ConfigLoader,
+        openai_handler: OpenAIHandler,
 ):
     """Test that the correct endpoint is set for OpenAI API."""
     mock_config_loader.config.llm.api = LLMProviders.OPENAI.value
@@ -25,16 +26,13 @@ def test_openai_endpoint_configuration_for_openai(
 
 
 def test_openai_endpoint_configuration_for_ollama(
-    mock_config_loader: ConfigLoader,
-    ollama_localhost: str,
+        mock_config_loader: ConfigLoader,
+        ollama_localhost: str,
 ):
     """Test that the correct endpoint is set for OLLAMA."""
     mock_config_loader.config.llm.api = LLMProviders.OLLAMA.value
     mock_config_loader.config.llm.localhost = ollama_localhost
-    assert (
-        "v1/chat/completions"
-        in f"{mock_config_loader.config.llm.localhost}v1/chat/completions"
-    )
+    assert "v1/chat/completions" in f"{mock_config_loader.config.llm.localhost}v1/chat/completions"
 
 
 @pytest.mark.asyncio
@@ -96,9 +94,7 @@ async def test_openai_make_request_without_context(openai_handler: OpenAIHandler
 
 
 @pytest.mark.asyncio
-async def test_make_request_error_handling(
-    mock_config: Settings, openai_handler: OpenAIHandler
-):
+async def test_make_request_error_handling(mock_config: Settings, openai_handler: OpenAIHandler):
     """Test error handling in _make_request."""
 
     async def run_test(error):