Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/docs/snippets/cli_options.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<!-- --8<------ [start:options] -->

```console
Usage: readmeai [OPTIONS]

Expand All @@ -9,7 +10,7 @@ Options:
-a, --align [center|left|right]
align for the README.md file header
sections.
--api [anthropic|gemini|ollama|openai|offline]
--api [anthropic|gemini|ollama|openai|azure|offline]
LLM API service provider to power the README
file generation.
-bc, --badge-color TEXT Primary color (hex code or name) to use for
Expand Down Expand Up @@ -50,4 +51,5 @@ Options:
generated for the README file.
--help Show this message and exit.
```

<!-- --8<------ [end:options] -->
84 changes: 84 additions & 0 deletions readmeai/models/azure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Azure OpenAI model handler for ReadmeAI."""
import os
from typing import Any

import aiohttp
import openai
from openai import AsyncAzureOpenAI

from readmeai.config.settings import ConfigLoader
from readmeai.extractors.models import RepositoryContext
from readmeai.models.openai_ import OpenAIHandler
from readmeai.models.tokens import token_handler


class AzureOpenAIHandler(OpenAIHandler):
    """Handler for Azure OpenAI models.

    Reuses the OpenAI payload/response logic from :class:`OpenAIHandler`
    but swaps in an ``AsyncAzureOpenAI`` client configured from the
    environment variables ``AZURE_ENDPOINT``, ``AZURE_API_KEY``,
    ``AZURE_API_VERSION`` and ``AZURE_MODEL``.
    """

    def __init__(self, config_loader: ConfigLoader, context: RepositoryContext) -> None:
        super().__init__(config_loader, context)
        # Re-apply settings after the parent constructor so the Azure
        # client replaces any client the base class configured.
        self._model_settings()

    def _model_settings(self) -> None:
        """Configure Azure-specific settings and build the async client."""
        # NOTE(review): if AZURE_ENDPOINT / AZURE_API_KEY / AZURE_API_VERSION
        # are unset, client construction or the first request will fail —
        # consider validating up front. Left as-is to preserve behavior.
        self.host_name = os.getenv("AZURE_ENDPOINT")
        self.max_tokens = self.config.llm.tokens
        self.model = os.getenv("AZURE_MODEL")
        self.top_p = self.config.llm.top_p

        self.client = AsyncAzureOpenAI(
            azure_endpoint=self.host_name,
            api_key=os.getenv("AZURE_API_KEY"),
            api_version=os.getenv("AZURE_API_VERSION"),
        )

    async def _make_request(
        self,
        index: str | None,
        prompt: str | None,
        tokens: int | None,
        repo_files: Any,
    ) -> tuple[str | None, str]:
        """Process a request to the Azure OpenAI API, with error handling.

        Args:
            index: Prompt category key (e.g. ``"file_summary"``); echoed
                back in the return value so callers can match responses.
            prompt: Prompt text to send; must not be ``None``.
            tokens: Token budget forwarded to ``token_handler``.
            repo_files: Unused here; kept for signature compatibility with
                the base handler.

        Returns:
            ``(index, content)`` — the generated text, or ``(index,
            self.placeholder)`` on an unexpected error.

        Raises:
            aiohttp.ClientError, openai.OpenAIError: re-raised so the
                retry decorator on the base class can retry the request.
        """
        try:
            if prompt is None:
                raise ValueError("Prompt cannot be None")

            # Truncate/validate the prompt against the model's token budget.
            prompt = await token_handler(
                config=self.config,
                index=index,
                prompt=prompt,
                tokens=tokens,
            )
            if not prompt:
                raise ValueError("Token handler returned empty prompt")

            # NOTE(review): this mutates instance state, so every request
            # after a "file_summary" also uses max_tokens=100 — confirm
            # whether that is intended (same pattern as the base handler?).
            if index == "file_summary":
                self.max_tokens = 100

            parameters = await self._build_payload(prompt)

            # Await the SDK call directly; the client manages its own HTTP
            # session, so no aiohttp session is needed here.
            response = await self.client.chat.completions.create(**parameters)
            content = response.choices[0].message.content

            if not content:
                raise ValueError("Empty response from API")

            self._logger.info(
                f"Response from {self.config.llm.api.capitalize()} for '{index}': {content}",
            )
            return index, content

        # ClientResponseError and ClientConnectorError are subclasses of
        # ClientError, so listing ClientError alone covers all three.
        except (
            aiohttp.ClientError,
            openai.OpenAIError,
        ) as e:
            self._logger.error(f"Error processing request for '{index}': {e!r}")
            raise  # Re-raise for retry decorator

        except Exception as e:
            self._logger.error(f"Unexpected error for '{index}': {e!r}")
            return index, self.placeholder
2 changes: 2 additions & 0 deletions readmeai/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class LLMAuthKeys(str, Enum):
GOOGLE_API_KEY = "GOOGLE_API_KEY"
OLLAMA_HOST = "OLLAMA_HOST"
OPENAI_API_KEY = "OPENAI_API_KEY"
AZURE_API_KEY = "AZURE_API_KEY"


class LLMProviders(str, Enum):
Expand All @@ -22,6 +23,7 @@ class LLMProviders(str, Enum):
OLLAMA = "ollama"
OPENAI = "openai"
OFFLINE = "offline"
AZURE = "azure"


class AnthropicModels(str, Enum):
Expand Down
12 changes: 5 additions & 7 deletions readmeai/models/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from readmeai.core.errors import UnsupportedServiceError
from readmeai.extractors.models import RepositoryContext
from readmeai.models.anthropic import AnthropicHandler
from readmeai.models.azure import AzureOpenAIHandler
from readmeai.models.base import BaseModelHandler
from readmeai.models.enums import LLMProviders
from readmeai.models.gemini import GeminiHandler
from readmeai.models.offline import OfflineHandler
from readmeai.models.openai import OpenAIHandler
from readmeai.models.openai_ import OpenAIHandler


class ModelFactory:
Expand All @@ -22,18 +23,15 @@ class ModelFactory:
LLMProviders.OLLAMA.value: OpenAIHandler,
LLMProviders.OPENAI.value: OpenAIHandler,
LLMProviders.OFFLINE.value: OfflineHandler,
LLMProviders.AZURE.value: AzureOpenAIHandler,
}

@staticmethod
def get_backend(
config: ConfigLoader, context: RepositoryContext
) -> BaseModelHandler:
def get_backend(config: ConfigLoader, context: RepositoryContext) -> BaseModelHandler:
"""Retrieves configured LLM API handler instance."""
llm_service = ModelFactory._model_map.get(config.config.llm.api)

if llm_service is None:
raise UnsupportedServiceError(
f"Unsupported LLM provider: {config.config.llm.api}"
)
raise UnsupportedServiceError(f"Unsupported LLM provider: {config.config.llm.api}")

return llm_service(config, context)
File renamed without changes.
77 changes: 39 additions & 38 deletions tests/cli/test_main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from unittest.mock import MagicMock, patch

import pytest
from _pytest._py.path import LocalPath
from click.testing import CliRunner
from unittest.mock import MagicMock, patch

from readmeai.cli.main import main
from readmeai.config.settings import ConfigLoader

Expand All @@ -24,13 +24,13 @@ def mock_readme_agent():


def test_main_command_basic(
cli_runner: CliRunner,
mock_config_loader: ConfigLoader,
output_file_path: str,
cli_runner: CliRunner,
mock_config_loader: ConfigLoader,
output_file_path: str,
):
with patch(
"readmeai.config.settings.ConfigLoader",
return_value=mock_config_loader,
"readmeai.config.settings.ConfigLoader",
return_value=mock_config_loader,
):
result = cli_runner.invoke(
main,
Expand All @@ -46,9 +46,9 @@ def test_main_command_basic(


def test_main_command_all_options(
cli_runner: CliRunner,
mock_config_loader: ConfigLoader,
output_file_path: str,
cli_runner: CliRunner,
mock_config_loader: ConfigLoader,
output_file_path: str,
):
mock_config = mock_config_loader
mock_config.config.git.repository = "https://github.com/eli64s/readme-ai-streamlit"
Expand Down Expand Up @@ -108,61 +108,62 @@ def test_version_option(cli_runner: CliRunner):
[
("--align", "invalid", "Invalid value for '-a' / '--align'"),
(
"--api",
"invalid",
"Invalid value for '--api': 'invalid' is not one of 'anthropic', 'gemini', 'ollama', 'openai', 'offline'.",
"--api",
"invalid",
"Invalid value for '--api': 'invalid' is not one of 'anthropic', 'gemini', 'ollama', 'openai', 'offline', "
"'azure'.",
),
# (
# "--badge-color",
# "invalid",
# "Invalid value for '-bc' / '--badge-color'",
# ),
(
"--badge-style",
"invalid",
"Invalid value for '-bs' / '--badge-style'",
"--badge-style",
"invalid",
"Invalid value for '-bs' / '--badge-style'",
),
(
"--context-window",
"invalid",
"Invalid value for '-cw' / '--context-window'",
"--context-window",
"invalid",
"Invalid value for '-cw' / '--context-window'",
),
(
"--header-style",
"invalid",
"Invalid value for '-hs' / '--header-style'",
"--header-style",
"invalid",
"Invalid value for '-hs' / '--header-style'",
),
("--logo", "invalid", "Invalid value for '-l' / '--logo'"),
# ("--model", "invalid", "Invalid value for '-m' / '--model'"),
(
"--rate-limit",
"invalid",
"Invalid value for '-rl' / '--rate-limit'",
"--rate-limit",
"invalid",
"Invalid value for '-rl' / '--rate-limit'",
),
(
"--temperature",
"invalid",
"Invalid value for '-t' / '--temperature'",
"--temperature",
"invalid",
"Invalid value for '-t' / '--temperature'",
),
(
"--navigation-style",
"invalid",
"Invalid value for '-ns' / '--navigation-style'",
"--navigation-style",
"invalid",
"Invalid value for '-ns' / '--navigation-style'",
),
(
"--top-p",
"invalid",
"Invalid value for '-tp' / '--top-p'",
"--top-p",
"invalid",
"Invalid value for '-tp' / '--top-p'",
),
(
"--tree-max-depth",
"invalid",
"Invalid value for '-td' / '--tree-max-depth'",
"--tree-max-depth",
"invalid",
"Invalid value for '-td' / '--tree-max-depth'",
),
],
)
def test_invalid_option_values(
temp_dir: LocalPath, cli_runner: CliRunner, option, value, expected
temp_dir: LocalPath, cli_runner: CliRunner, option, value, expected
):
result = cli_runner.invoke(main, ["--repository", str(temp_dir), option, value])
assert result.exit_code != 0
Expand Down
52 changes: 38 additions & 14 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import aiohttp
import pytest

from readmeai.config.settings import ConfigLoader
from readmeai.extractors.models import RepositoryContext
from readmeai.models.anthropic import ANTHROPIC_AVAILABLE, AnthropicHandler
from readmeai.models.azure import AzureOpenAIHandler
from readmeai.models.enums import GeminiModels
from readmeai.models.gemini import GeminiHandler
from readmeai.models.openai import OpenAIHandler
from readmeai.models.openai_ import OpenAIHandler


@pytest.fixture
Expand Down Expand Up @@ -46,9 +48,7 @@ def mock_aiohttp_session():


@pytest.fixture
def anthropic_handler(
mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext
):
def anthropic_handler(mock_config_loader: ConfigLoader, mock_repository_context: RepositoryContext):
if not ANTHROPIC_AVAILABLE:
pytest.skip("Anthropic library is not available")
context = mock_repository_context
Expand All @@ -60,20 +60,18 @@ def anthropic_handler(

@pytest.fixture
def anthropic_handler_with_mock_session(
anthropic_handler: AnthropicHandler, monkeypatch: pytest.MonkeyPatch
anthropic_handler: AnthropicHandler, monkeypatch: pytest.MonkeyPatch
):
monkeypatch.setenv("ANTHROPIC_API_KEY", "test_api_key")
mock_create = AsyncMock(
return_value=MagicMock(content=[MagicMock(text="test_response")])
)
mock_create = AsyncMock(return_value=MagicMock(content=[MagicMock(text="test_response")]))
anthropic_handler.client.messages.create = mock_create
return anthropic_handler


@pytest.fixture
def gemini_handler(
mock_config_loader: ConfigLoader,
mock_repository_context: RepositoryContext,
mock_config_loader: ConfigLoader,
mock_repository_context: RepositoryContext,
):
"""Fixture to provide a GeminiHandler instance."""
mock_config_loader.config.llm.model = GeminiModels.GEMINI_FLASH.value
Expand All @@ -86,9 +84,9 @@ def gemini_handler(

@pytest.fixture
def openai_handler(
mock_config_loader: ConfigLoader,
mock_repository_context: RepositoryContext,
monkeypatch: pytest.MonkeyPatch,
mock_config_loader: ConfigLoader,
mock_repository_context: RepositoryContext,
monkeypatch: pytest.MonkeyPatch,
):
"""Fixture to provide an OpenAIHandler instance with a mocked API key."""
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
Expand All @@ -100,13 +98,39 @@ def openai_handler(

@pytest.fixture
def openai_handler_with_mock_session(
openai_handler: OpenAIHandler, mock_aiohttp_session: MagicMock
openai_handler: OpenAIHandler, mock_aiohttp_session: MagicMock
):
"""Fixture to provide an OpenAIHandler with a mocked session."""
openai_handler._session = mock_aiohttp_session
return openai_handler


@pytest.fixture
def azure_openai_handler(
    mock_config_loader: ConfigLoader,
    mock_repository_context: RepositoryContext,
):
    """Fixture to provide an AzureOpenAIHandler instance."""
    # Patch the Azure env vars the handler reads in _model_settings so the
    # test never depends on the host environment. AZURE_MODEL is left
    # unset, so the handler's `model` attribute will be None here.
    # Construction must happen *inside* the patch context, since the
    # handler reads os.environ during __init__.
    with patch.dict("os.environ", {
        "AZURE_ENDPOINT": "https://test.azure.com",
        "AZURE_API_KEY": "test_azure_api_key",
        "AZURE_API_VERSION": "2023-05-15",
    }):
        return AzureOpenAIHandler(
            config_loader=mock_config_loader,
            context=mock_repository_context,
        )


@pytest.fixture
def azure_openai_handler_with_mock_session(
    azure_openai_handler: AzureOpenAIHandler, mock_aiohttp_session: MagicMock
):
    """Fixture to provide an AzureOpenAIHandler with a mocked session."""
    # Mirrors openai_handler_with_mock_session: swap the handler's HTTP
    # session for a mock so tests make no real network calls.
    azure_openai_handler._session = mock_aiohttp_session
    return azure_openai_handler


@pytest.fixture
def ollama_localhost():
"""Fixture to provide a localhost URL for Ollama."""
Expand Down
Loading