Skip to content

Commit 65e54a1

Browse files
committed
feat(ai): improve robustness of AI client
Signed-off-by: Amine <[email protected]>
1 parent b9c1921 commit 65e54a1

File tree

11 files changed

+344
-189
lines changed

11 files changed

+344
-189
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"cryptography >=44.0.0,<45.0.0",
4040
"semgrep == 1.113.0",
4141
"email-validator >=2.2.0,<3.0.0",
42+
"pydantic >= 2.11.5,<2.12.0",
4243
]
4344
keywords = []
4445
# https://pypi.org/classifiers/

src/macaron/ai.py

Lines changed: 0 additions & 175 deletions
This file was deleted.

src/macaron/ai/README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Macaron AI Module
2+
3+
This module provides the foundation for interacting with Large Language Models (LLMs) in a provider-agnostic way. It includes an abstract client definition, provider-specific client implementations, a client factory, and utility functions for processing responses.
4+
5+
## Module Components
6+
7+
- **ai_client.py**
8+
Defines the abstract [`AIClient`](./ai_client.py) class. This class handles the initialization of LLM configuration from the defaults and serves as the base for all specific AI client implementations.
9+
10+
- **openai_client.py**
11+
Implements the [`OpenAiClient`](./openai_client.py) class, a concrete subclass of [`AIClient`](./ai_client.py). This client interacts with OpenAI-like APIs by sending requests using HTTP and processing the responses. It also validates and structures responses using the tools provided.
12+
13+
- **ai_factory.py**
14+
Contains the [`AIClientFactory`](./ai_factory.py) class, which is responsible for reading provider configuration from the defaults and creating the correct AI client instance.
15+
16+
- **ai_tools.py**
17+
Offers utility functions such as `structure_response` to assist with parsing and validating the JSON response returned by an LLM. These functions ensure that responses conform to a given Pydantic model for easier downstream processing.
18+
19+
## Usage
20+
21+
1. **Configuration:**
22+
The module reads the LLM configuration from the application defaults (using the `defaults` module). Make sure that the `llm` section in your configuration includes valid settings such as `enabled`, `api_key`, `api_endpoint`, `model`, and `context_window`.
23+
24+
2. **Creating a Client:**
25+
Use the [`AIClientFactory`](./ai_factory.py) to create an AI client instance. The factory checks the configured provider and returns a client (e.g., an instance of [`OpenAiClient`](./openai_client.py)) that can be used to invoke the LLM.
26+
27+
Example:
28+
```py
29+
from macaron.ai.ai_factory import AIClientFactory
30+
31+
factory = AIClientFactory()
32+
client = factory.create_client(system_prompt="You are a helpful assistant.")
33+
response = client.invoke("Hello, how can you assist me?")
34+
print(response)
35+
```
36+
37+
3. **Response Processing:**
38+
When a structured response is required, pass a Pydantic model class to the `invoke` method. The [`ai_tools.py`](./ai_tools.py) module takes care of parsing and validating the response to ensure it meets the expected structure.
39+
40+
## Logging and Error Handling
41+
42+
- The module uses Python's logging framework to report important events, such as token usage and warnings when prompts exceed the allowed context window.
43+
- Configuration errors (e.g., a missing API key or endpoint) are handled by raising descriptive exceptions, such as [`ConfigurationError`](../errors.py).
44+
45+
## Extensibility
46+
47+
The design of the AI module is provider-agnostic. To add support for additional LLM providers:
48+
- Implement a new client by subclassing [`AIClient`](./ai_client.py).
49+
- Add the new client to the [`PROVIDER_MAPPING`](./ai_factory.py).
50+
- Update the configuration defaults accordingly.

src/macaron/ai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

src/macaron/ai/ai_client.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
3+
4+
"""This module defines the abstract AIClient class for implementing AI clients."""
5+
6+
import logging
7+
from abc import ABC, abstractmethod
8+
from typing import Any, TypeVar
9+
10+
from pydantic import BaseModel
11+
12+
T = TypeVar("T", bound=BaseModel)
13+
14+
logger: logging.Logger = logging.getLogger(__name__)
15+
16+
17+
class AIClient(ABC):
    """Abstract base class for AI clients.

    Stores the system prompt and the LLM configuration; concrete
    provider clients subclass this and implement :meth:`invoke`.
    """

    def __init__(self, system_prompt: str, defaults: dict) -> None:
        """
        Initialize the AI client.

        The LLM configuration is read from defaults.
        """
        self.system_prompt = system_prompt
        self.defaults = defaults

    @abstractmethod
    def invoke(
        self,
        user_prompt: str,
        temperature: float = 0.2,
        structured_output: type[T] | None = None,
    ) -> Any:
        """
        Invoke the LLM and optionally validate its response.

        Parameters
        ----------
        user_prompt: str
            The user prompt to send to the LLM.
        temperature: float
            The sampling temperature used for the LLM response.
        structured_output: Optional[Type[T]]
            The Pydantic model to validate the response against. If provided,
            the response will be parsed and validated against it.

        Returns
        -------
        Optional[T | str]
            The validated Pydantic model instance if `structured_output` is
            provided, or the raw string response if not.
        """

src/macaron/ai/ai_factory.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) 2025 - 2025, Oracle and/or its affiliates. All rights reserved.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.
3+
4+
"""This module defines the AIClientFactory class for creating AI clients based on provider configuration."""
5+
6+
import logging
7+
8+
from macaron.ai.ai_client import AIClient
9+
from macaron.ai.openai_client import OpenAiClient
10+
from macaron.config.defaults import defaults
11+
from macaron.errors import ConfigurationError
12+
13+
logger: logging.Logger = logging.getLogger(__name__)
14+
15+
16+
class AIClientFactory:
    """Factory to create AI clients based on provider configuration."""

    PROVIDER_MAPPING: dict[str, type[AIClient]] = {"openai": OpenAiClient}

    def __init__(self) -> None:
        """
        Initialize the AI client factory.

        The LLM configuration is read from defaults.
        """
        self.defaults = self._load_defaults()

    def _load_defaults(self) -> dict:
        """Read and validate the ``llm`` section of the application defaults.

        Returns
        -------
        dict
            The resolved configuration values.

        Raises
        ------
        ConfigurationError
            If the AI client is enabled but a required setting is empty.
        """
        section_name = "llm"
        default_values: dict = {
            "enabled": False,
            "provider": "",
            "api_key": "",
            "api_endpoint": "",
            "model": "",
            "context_window": 10000,
        }

        if defaults.has_section(section_name):
            section = defaults[section_name]
            default_values["enabled"] = section.getboolean("enabled", default_values["enabled"])
            # API keys, endpoints, and model names are case-sensitive for most
            # providers, so only strip surrounding whitespace — do NOT lowercase.
            default_values["api_key"] = str(section.get("api_key", default_values["api_key"])).strip()
            default_values["api_endpoint"] = str(section.get("api_endpoint", default_values["api_endpoint"])).strip()
            default_values["model"] = str(section.get("model", default_values["model"])).strip()
            # The provider is only used as a lookup key in PROVIDER_MAPPING,
            # so normalizing it to lowercase is safe and intended.
            default_values["provider"] = str(section.get("provider", default_values["provider"])).strip().lower()
            # Reuse the declared default instead of repeating the literal.
            default_values["context_window"] = section.getint("context_window", default_values["context_window"])

        if default_values["enabled"]:
            # When the client is enabled, every setting must be non-empty.
            for key, value in default_values.items():
                if not value:
                    raise ConfigurationError(
                        f"AI client configuration '{key}' is required but not set in the defaults."
                    )

        return default_values

    def create_client(self, system_prompt: str) -> AIClient | None:
        """Create an AI client based on the configured provider.

        Parameters
        ----------
        system_prompt: str
            The system prompt to pass to the new client.

        Returns
        -------
        AIClient | None
            A client instance, or ``None`` when the configured provider
            is not in ``PROVIDER_MAPPING``.
        """
        client_class = self.PROVIDER_MAPPING.get(self.defaults["provider"])
        if client_class is None:
            logger.error("Provider '%s' is not supported.", self.defaults["provider"])
            return None
        return client_class(system_prompt, self.defaults)

    def list_available_providers(self) -> list[str]:
        """List all registered provider names."""
        return list(self.PROVIDER_MAPPING.keys())

0 commit comments

Comments
 (0)