6 changes: 6 additions & 0 deletions elroy/core/ctx.py
@@ -229,6 +229,12 @@ def db(self) -> DbSession:
def db_manager(self) -> DbManager:
assert self.params.database_url, "Database URL not set"
return get_db_manager(self.params.database_url)

@cached_property
def llm_client(self):
"""Get the LLM client instance. Can be overridden in subclasses for testing."""
from ..llm.llm_client import LLMClient
return LLMClient()

@allow_unused
def is_db_connected(self) -> bool:
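For orientation, a minimal sketch of how a call site might use the new ctx.llm_client property instead of importing the module-level LLM functions directly; the helper name and the chat_model argument below are illustrative, not part of this diff:

from elroy.core.ctx import ElroyContext


def summarize_via_ctx(ctx: ElroyContext, chat_model) -> str:
    # Hypothetical helper: `chat_model` is assumed to be a ChatModel instance;
    # how callers obtain one is not shown in this diff.
    return ctx.llm_client.query_llm(
        model=chat_model,
        prompt="Summarize the last conversation.",
        system="You are a concise assistant.",
    )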
210 changes: 210 additions & 0 deletions elroy/llm/cached_llm_client.py
@@ -0,0 +1,210 @@
"""
Cached LLM Client for Tests Only.

WARNING: This client is ONLY intended for use in tests to cache LLM responses
to disk for reproducible test runs and cost savings. It should NEVER be used
in production code as it would create unwanted caching behavior.

The caching mechanism writes responses to JSON files in tests/fixtures/llm_cache/
which can be checked into version control so remote test runs use cached data
instead of making real API calls.
"""
import hashlib
import json
import os
from pathlib import Path
from typing import List, Type, TypeVar, Any, Dict
from pydantic import BaseModel
from ..config.llm import ChatModel, EmbeddingModel
from .llm_client import LLMClient

T = TypeVar("T", bound=BaseModel)


class CachedLLMClient(LLMClient):
"""
TEST-ONLY LLM client that caches responses to disk.

⚠️ WARNING: This client is ONLY for tests! It caches all LLM responses
to disk files in tests/fixtures/llm_cache/. This is useful for:

1. Making tests deterministic and reproducible
2. Reducing API costs during test development
3. Enabling offline test execution
4. Consistent CI/CD test runs

NEVER use this client in production - it would cache all LLM interactions
to disk which is not desired behavior for a production application.

Cache files are organized by content hash to ensure deterministic behavior
regardless of when the test is run.
"""

def __init__(self, cache_dir: str = "tests/fixtures/llm_cache"):
"""
Initialize the cached LLM client.

Args:
cache_dir: Directory to store cache files (relative to project root)
"""
# This is TEST-ONLY - we assume we're running from project root during tests
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)

def _get_cache_key(self, data: Dict[str, Any]) -> str:
"""
Generate a deterministic cache key based on input parameters.

Uses content-based hashing to ensure the same inputs always produce
the same cache key, making tests reproducible.
"""
# Create a sorted, stable representation of the input data
cache_data = json.dumps(data, sort_keys=True, default=str)
return hashlib.sha256(cache_data.encode()).hexdigest()

def _get_cache_path(self, method: str, cache_key: str) -> Path:
"""Get the file path for a cached response."""
return self.cache_dir / f"{method}_{cache_key}.json"

def _load_from_cache(self, cache_path: Path) -> Any:
"""Load a response from cache file."""
if cache_path.exists():
try:
with open(cache_path, 'r') as f:
return json.load(f)
except (json.JSONDecodeError, FileNotFoundError):
# If cache file is corrupted, ignore it and re-generate
pass
return None

def _save_to_cache(self, cache_path: Path, data: Any) -> None:
"""Save a response to cache file."""
try:
with open(cache_path, 'w') as f:
json.dump(data, f, indent=2, default=str)
except Exception as e:
# If we can't save to cache, log but don't fail the test
print(f"Warning: Could not save to cache {cache_path}: {e}")

def query_llm(self, model: ChatModel, prompt: str, system: str) -> str:
"""
Query LLM with caching for tests.

Caches based on model name, prompt, and system message content.
Falls back to real API call if cache miss, then saves response.
"""
cache_data = {
"method": "query_llm",
"model": model.name,
"prompt": prompt,
"system": system
}

cache_key = self._get_cache_key(cache_data)
cache_path = self._get_cache_path("query_llm", cache_key)

# Try to load from cache first
cached_response = self._load_from_cache(cache_path)
if cached_response is not None:
return cached_response["response"]

# Cache miss - make real API call
response = super().query_llm(model=model, prompt=prompt, system=system)

# Save to cache for future test runs
self._save_to_cache(cache_path, {"response": response})

return response

def query_llm_with_response_format(self, model: ChatModel, prompt: str, system: str, response_format: Type[T]) -> T:
"""
Query LLM with response format, using caching for tests.

Caches based on model, prompt, system message, and response format class name.
"""
cache_data = {
"method": "query_llm_with_response_format",
"model": model.name,
"prompt": prompt,
"system": system,
"response_format": response_format.__name__
}

cache_key = self._get_cache_key(cache_data)
cache_path = self._get_cache_path("query_llm_with_response_format", cache_key)

# Try to load from cache first
cached_response = self._load_from_cache(cache_path)
if cached_response is not None:
# Reconstruct the Pydantic model from cached JSON
return response_format.model_validate(cached_response["response"])

# Cache miss - make real API call
response = super().query_llm_with_response_format(
model=model, prompt=prompt, system=system, response_format=response_format
)

# Save to cache for future test runs
self._save_to_cache(cache_path, {"response": response.model_dump()})

return response

def query_llm_with_word_limit(self, model: ChatModel, prompt: str, system: str, word_limit: int) -> str:
"""
Query LLM with word limit, using caching for tests.

Caches based on model, prompt, system message, and word limit.
"""
cache_data = {
"method": "query_llm_with_word_limit",
"model": model.name,
"prompt": prompt,
"system": system,
"word_limit": word_limit
}

cache_key = self._get_cache_key(cache_data)
cache_path = self._get_cache_path("query_llm_with_word_limit", cache_key)

# Try to load from cache first
cached_response = self._load_from_cache(cache_path)
if cached_response is not None:
return cached_response["response"]

# Cache miss - make real API call
response = super().query_llm_with_word_limit(
model=model, prompt=prompt, system=system, word_limit=word_limit
)

# Save to cache for future test runs
self._save_to_cache(cache_path, {"response": response})

return response

def get_embedding(self, model: EmbeddingModel, text: str) -> List[float]:
"""
Get embedding with caching for tests.

Caches based on model name and text content.
"""
cache_data = {
"method": "get_embedding",
"model": model.name,
"text": text
}

cache_key = self._get_cache_key(cache_data)
cache_path = self._get_cache_path("get_embedding", cache_key)

# Try to load from cache first
cached_response = self._load_from_cache(cache_path)
if cached_response is not None:
return cached_response["response"]

# Cache miss - make real API call
response = super().get_embedding(model=model, text=text)

# Save to cache for future test runs
self._save_to_cache(cache_path, {"response": response})

return response
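As a rough illustration of the cache layout this class produces (the values below are made up; only the hashing and file-naming scheme mirror _get_cache_key and _get_cache_path above):

import hashlib
import json

# Content-addressed key: sha256 over the sorted-JSON form of the request parameters.
cache_data = {
    "method": "query_llm",
    "model": "example-model",  # hypothetical model name
    "prompt": "Hello",
    "system": "You are a test assistant.",
}
key = hashlib.sha256(json.dumps(cache_data, sort_keys=True, default=str).encode()).hexdigest()
cache_file = f"tests/fixtures/llm_cache/query_llm_{key}.json"  # body on disk: {"response": "..."}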
44 changes: 44 additions & 0 deletions elroy/llm/llm_client.py
@@ -0,0 +1,44 @@
"""
Base LLM Client class that wraps existing LLM functions for class-based interface.
"""
from typing import List, Type, TypeVar
from pydantic import BaseModel
from ..config.llm import ChatModel, EmbeddingModel
from .client import query_llm, query_llm_with_response_format, query_llm_with_word_limit, get_embedding

T = TypeVar("T", bound=BaseModel)


class LLMClient:
"""
Base class that provides a class-based interface to LLM functions.

This wraps the existing standalone LLM functions to provide a consistent
interface that can be easily extended or mocked in tests.
"""

def query_llm(self, model: ChatModel, prompt: str, system: str) -> str:
"""Query the LLM with a prompt and system message."""
return query_llm(model=model, prompt=prompt, system=system)

def query_llm_with_response_format(self, model: ChatModel, prompt: str, system: str, response_format: Type[T]) -> T:
"""Query the LLM with a specific response format."""
return query_llm_with_response_format(
model=model,
prompt=prompt,
system=system,
response_format=response_format
)

def query_llm_with_word_limit(self, model: ChatModel, prompt: str, system: str, word_limit: int) -> str:
"""Query the LLM with a word limit constraint."""
return query_llm_with_word_limit(
model=model,
prompt=prompt,
system=system,
word_limit=word_limit
)

def get_embedding(self, model: EmbeddingModel, text: str) -> List[float]:
"""Generate an embedding for the given text."""
return get_embedding(model=model, text=text)
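Because the functions are now reachable through an instance, tests that do not want disk caching can swap in a trivial stub. A minimal sketch (the class name and canned value are hypothetical):

from elroy.config.llm import ChatModel
from elroy.llm.llm_client import LLMClient


class CannedLLMClient(LLMClient):
    """Hypothetical test stub: returns a fixed reply instead of calling any API."""

    def query_llm(self, model: ChatModel, prompt: str, system: str) -> str:
        return "canned reply"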
18 changes: 14 additions & 4 deletions tests/conftest.py
@@ -128,7 +128,7 @@ def io(rich_formatter: RichFormatter) -> Generator[MockCliIO, Any, None]:


@pytest.fixture(scope="function")
def george_ctx(ctx: ElroyContext) -> Generator[ElroyContext, Any, None]:
def george_ctx(ctx: TestElroyContext) -> Generator[TestElroyContext, Any, None]:
messages = [
ContextMessage(
role=USER,
@@ -208,12 +208,22 @@ def user_token() -> Generator[str, None, None]:
yield str(uuid.uuid4())


class TestElroyContext(ElroyContext):
"""Test-specific ElroyContext that uses cached LLM client."""

@cached_property
def llm_client(self):
"""Override to use cached LLM client in tests."""
from elroy.llm.cached_llm_client import CachedLLMClient
return CachedLLMClient()


@pytest.fixture(scope="function")
def ctx(db_manager: DbManager, db_session: DbSession, user_token, chat_model_name: str) -> Generator[ElroyContext, None, None]:
"""Create an ElroyContext for testing, using the same defaults as the CLI"""
def ctx(db_manager: DbManager, db_session: DbSession, user_token, chat_model_name: str) -> Generator[TestElroyContext, None, None]:
"""Create a TestElroyContext for testing with cached LLM client"""

# Create new context with all parameters
ctx = ElroyContext.init(
ctx = TestElroyContext.init(
user_token=user_token,
database_url=db_manager.url,
chat_model=chat_model_name,
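With the ctx fixture now yielding a TestElroyContext, any test that goes through ctx.llm_client transparently uses CachedLLMClient: the first run hits the real API and writes tests/fixtures/llm_cache/<method>_<hash>.json, and later runs (including CI) read the cached JSON instead. A minimal sketch of such a test; the fixture wiring is taken from above, but the chat-model attribute and the assertion are assumptions:

def test_llm_call_is_cached(ctx):
    reply = ctx.llm_client.query_llm(
        model=ctx.params.chat_model,  # assumed attribute; not shown in this diff
        prompt="Say hello.",
        system="You are a test assistant.",
    )
    assert isinstance(reply, str)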