Skip to content

Commit 05fee79

Browse files
authored
Feat/llm cache (#235)
* first version of cache * add proper dumping and loading * upd hasher and test * minor change * try to track cached structured outputs * minor change * decompose structured output tests * upd basics test * refactor caching tracking a little bit * upd caching test * upd retries test * code formatter
1 parent 169dcd9 commit 05fee79

File tree

8 files changed

+406
-216
lines changed

8 files changed

+406
-216
lines changed

autointent/_hash.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ class Hasher:
1515
hashing embeddings from :py:class:`autointent.Embedder`.
1616
"""
1717

18-
def __init__(self) -> None:
18+
def __init__(self, strict: bool = False) -> None:
1919
"""Initialize the Hasher instance and sets up the internal xxhash state.
2020
2121
This state will be used for progressively hashing values using the
2222
`update` method and obtaining the final digest using `hexdigest`.
2323
"""
2424
self._state = xxhash.xxh64()
25+
self.strict = strict
2526

26-
@classmethod
27-
def hash(cls, value: Any) -> int: # noqa: ANN401
27+
def hash(self, value: Any) -> int: # noqa: ANN401
2828
"""Generate a hash for the given value using xxhash.
2929
3030
Args:
@@ -35,6 +35,9 @@ def hash(cls, value: Any) -> int: # noqa: ANN401
3535
"""
3636
if hasattr(value, "__hash__") and value.__hash__ not in {None, object.__hash__}:
3737
return hash(value)
38+
if self.strict:
39+
msg = "Object is not hashable."
40+
raise ValueError(msg)
3841
return xxhash.xxh64(pickle.dumps(value)).intdigest()
3942

4043
def update(self, value: Any) -> None: # noqa: ANN401

autointent/generation/_cache.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Helpers for caching structured outputs from LLM."""
2+
3+
import json
4+
import logging
5+
from pathlib import Path
6+
from typing import Any, TypeVar
7+
8+
from appdirs import user_cache_dir
9+
from dotenv import load_dotenv
10+
from pydantic import BaseModel, ValidationError
11+
12+
from autointent._dump_tools import PydanticModelDumper
13+
from autointent._hash import Hasher
14+
from autointent.generation.chat_templates import Message
15+
16+
logger = logging.getLogger(__name__)
17+
18+
load_dotenv()
19+
20+
T = TypeVar("T", bound=BaseModel)
21+
"""Type variable for Pydantic models used in structured output generation."""
22+
23+
24+
def _get_structured_output_cache_path(dirname: str) -> Path:
    """Get the path to a structured-output cache entry.

    Builds the full path to a per-entry location under the user's
    application cache directory (as reported by ``appdirs``), namespaced
    as ``autointent/structured_outputs/<dirname>``.

    Args:
        dirname: Name of the cache entry (typically a hexadecimal hash key).

    Returns:
        The full path to the cache entry.
    """
    return Path(user_cache_dir("autointent")) / "structured_outputs" / dirname
39+
40+
41+
class StructuredOutputCache:
    """Disk-backed cache for structured LLM output results.

    Entries are keyed by a hash of the request (messages, output schema,
    backend and generation parameters) and persisted via
    :py:class:`autointent._dump_tools.PydanticModelDumper`.
    """

    def __init__(self, use_cache: bool = True) -> None:
        """Initialize the cache.

        Args:
            use_cache: Whether to use caching. When False, ``get`` always
                misses and ``set`` is a no-op.
        """
        self.use_cache = use_cache

    def _get_cache_key(
        self, messages: list[Message], output_model: type[T], backend: str, generation_params: dict[str, Any]
    ) -> str:
        """Generate a deterministic cache key for the given parameters.

        Args:
            messages: List of messages to send to the model.
            output_model: Pydantic model class to parse the response into.
            backend: Backend to use for structured output.
            generation_params: Generation parameters.

        Returns:
            Cache key as a hexadecimal string.
        """
        hasher = Hasher(strict=True)
        # sort_keys makes the key independent of dict insertion order, so
        # logically identical requests always map to the same cache entry.
        hasher.update(json.dumps(messages, sort_keys=True))
        hasher.update(json.dumps(output_model.model_json_schema(), sort_keys=True))
        hasher.update(backend)
        hasher.update(json.dumps(generation_params, sort_keys=True))
        return hasher.hexdigest()

    def get(
        self, messages: list[Message], output_model: type[T], backend: str, generation_params: dict[str, Any]
    ) -> T | None:
        """Get cached result if available.

        Args:
            messages: List of messages to send to the model.
            output_model: Pydantic model class to parse the response into.
            backend: Backend to use for structured output.
            generation_params: Generation parameters.

        Returns:
            Cached result if available, None otherwise.
        """
        if not self.use_cache:
            return None

        cache_key = self._get_cache_key(messages, output_model, backend, generation_params)
        cache_path = _get_structured_output_cache_path(cache_key)

        if cache_path.exists():
            try:
                cached_data = PydanticModelDumper.load(cache_path)

                if isinstance(cached_data, output_model):
                    logger.debug("Using cached structured output for key: %s", cache_key)
                    return cached_data

                logger.warning("Cached data type mismatch, removing invalid cache")
                # missing_ok guards against a concurrent process removing the
                # entry between the exists() check and the unlink.
                cache_path.unlink(missing_ok=True)
            # OSError covers unreadable/corrupt entries on top of the
            # model-validation and import failures the dumper can raise.
            except (ValidationError, ImportError, OSError) as e:
                logger.warning("Failed to load cached structured output: %s", e)
                cache_path.unlink(missing_ok=True)

        return None

    def set(
        self, messages: list[Message], output_model: type[T], backend: str, generation_params: dict[str, Any], result: T
    ) -> None:
        """Cache the result.

        Args:
            messages: List of messages to send to the model.
            output_model: Pydantic model class to parse the response into.
            backend: Backend to use for structured output.
            generation_params: Generation parameters.
            result: The result to cache.
        """
        if not self.use_cache:
            return

        cache_key = self._get_cache_key(messages, output_model, backend, generation_params)
        cache_path = _get_structured_output_cache_path(cache_key)

        cache_path.parent.mkdir(parents=True, exist_ok=True)
        PydanticModelDumper.dump(result, cache_path, exists_ok=True)
        logger.debug("Cached structured output for key: %s", cache_key)

autointent/generation/_generator.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from autointent.generation.chat_templates import Message, Role
1414

15+
from ._cache import StructuredOutputCache
16+
1517
logger = logging.getLogger(__name__)
1618

1719
load_dotenv()
@@ -38,12 +40,19 @@ class Generator:
3840
}
3941
"""Default generation parameters for API requests."""
4042

41-
def __init__(self, base_url: str | None = None, model_name: str | None = None, **generation_params: Any) -> None: # noqa: ANN401
43+
def __init__(
44+
self,
45+
base_url: str | None = None,
46+
model_name: str | None = None,
47+
use_cache: bool = True,
48+
**generation_params: Any, # noqa: ANN401
49+
) -> None:
4250
"""Initialize the Generator with API configuration.
4351
4452
Args:
4553
base_url: OpenAI API compatible server URL.
4654
model_name: Name of the language model to use.
55+
use_cache: Whether to use caching for structured outputs.
4756
**generation_params: Additional generation parameters to override defaults passed to OpenAI completions API.
4857
"""
4958
base_url = base_url or os.getenv("OPENAI_BASE_URL")
@@ -58,6 +67,7 @@ def __init__(self, base_url: str | None = None, model_name: str | None = None, *
5867
**self._default_generation_params,
5968
**generation_params,
6069
} # https://stackoverflow.com/a/65539348
70+
self.cache = StructuredOutputCache(use_cache=use_cache)
6171

6272
def get_chat_completion(self, messages: list[Message]) -> str:
6373
"""Prompt LLM and return its answer.
@@ -92,15 +102,15 @@ def _create_retry_messages(self, error_message: str, raw: str | None) -> list[Me
92102
res.append({"role": Role.ASSISTANT, "content": raw})
93103
res.append(
94104
{
95-
"role": "user",
105+
"role": Role.USER,
96106
"content": dedent(
97107
f"""The previous response failed validation with the following error: {error_message}
98108
99-
Make sure to:
100-
1. Follow the exact schema structure
101-
2. Use the correct data types for each field
102-
3. Include all required fields
103-
4. Ensure the response is valid JSON"""
109+
Make sure to:
110+
1. Follow the exact schema structure
111+
2. Use the correct data types for each field
112+
3. Include all required fields
113+
4. Ensure the response is valid JSON"""
104114
),
105115
}
106116
)
@@ -184,6 +194,11 @@ async def get_structured_output_async(
184194
Returns:
185195
Parsed response as an instance of the provided Pydantic model.
186196
"""
197+
# Check cache first
198+
cached_result = self.cache.get(messages, output_model, backend, self.generation_params)
199+
if cached_result is not None:
200+
return cached_result
201+
187202
current_messages = messages.copy()
188203
res: T | None = None
189204

@@ -213,6 +228,9 @@ async def get_structured_output_async(
213228
logger.exception(msg)
214229
raise RuntimeError(msg)
215230

231+
# Cache the successful result
232+
self.cache.set(messages, output_model, backend, self.generation_params, res)
233+
216234
return res
217235

218236
def _get_structured_output_openai_sync(
@@ -293,6 +311,11 @@ def get_structured_output_sync(
293311
Returns:
294312
Parsed response as an instance of the provided Pydantic model.
295313
"""
314+
# Check cache first
315+
cached_result = self.cache.get(messages, output_model, backend, self.generation_params)
316+
if cached_result is not None:
317+
return cached_result
318+
296319
current_messages = messages.copy()
297320
res: T | None = None
298321

@@ -322,4 +345,7 @@ def get_structured_output_sync(
322345
logger.exception(msg)
323346
raise RuntimeError(msg)
324347

348+
# Cache the successful result
349+
self.cache.set(messages, output_model, backend, self.generation_params, res)
350+
325351
return res

tests/generation/structured_output/__init__.py

Whitespace-only changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Tests for structured output functionality."""
2+
3+
import os
4+
from typing import Literal
5+
6+
import pytest
7+
from pydantic import BaseModel, Field
8+
9+
from autointent.generation import Generator
10+
from autointent.generation.chat_templates import Role
11+
12+
13+
# Rich example schema used to exercise structured-output parsing: mixes
# plain strings, a constrained int, a bool with a default, a Literal enum
# and a list field. Deliberately has NO class docstring — pydantic would
# surface one as the JSON-schema "description", changing prompts/cache keys.
class Person(BaseModel):
    # Free-form chain-of-thought slot placed first so the model plans values.
    reasoning: str = Field(description="Some preliminary reasoning to plan fields' values")
    name: str = Field(description="The person's full name")
    # ge/le bounds exercise numeric constraint validation (0–150 years).
    age: int = Field(description="The person's age in years", ge=0, le=150)
    email: str = Field(description="The person's email address")
    occupation: str = Field(description="The person's job or profession")
    # Optional in practice: defaults to True when the model omits it.
    is_active: bool = Field(description="Whether the person is currently active", default=True)
    # Closed set of values — exercises enum-style schema constraints.
    status: Literal["active", "inactive", "pending"] = Field(description="Current status of the person")
    hobbies: list[str] = Field(description="List of the person's hobbies and interests")
23+
24+
@pytest.fixture
def generator():
    """Provide a Generator with caching disabled so tests hit the API."""
    gen = Generator(max_tokens=1000, use_cache=False)
    return gen
28+
29+
30+
@pytest.mark.skipif(
    not os.getenv("OPENAI_API_KEY") or not os.getenv("OPENAI_MODEL_NAME"),
    reason="OPENAI_API_KEY and OPENAI_MODEL_NAME environment variables are required for this test",
)
class TestStructuredOutput:
    """Test structured output functionality for different backends."""

    def test_basic_chat_completion(self, generator):
        """Test basic chat completion functionality."""
        response = generator.get_chat_completion(messages=[{"role": Role.USER, "content": "hi! tell me a joke"}])
        assert isinstance(response, str)
        assert len(response) > 0

    @pytest.mark.asyncio
    async def test_async_chat_completion(self, generator):
        """Test async chat completion functionality."""
        response = await generator.get_chat_completion_async(
            messages=[{"role": Role.USER, "content": "hi! tell me a joke"}]
        )
        assert isinstance(response, str)
        assert len(response) > 0

    def test_structured_output(self, generator):
        """Test that synchronous structured output works without failing."""
        result = generator.get_structured_output_sync(
            messages=[{"role": Role.USER, "content": "Create a person"}],
            output_model=Person,
            max_retries=5,
        )

        assert isinstance(result, Person)

    @pytest.mark.asyncio
    async def test_structured_output_async(self, generator):
        """Test that async structured output works without failing."""
        result = await generator.get_structured_output_async(
            messages=[{"role": Role.USER, "content": "Create a person"}],
            output_model=Person,
            max_retries=5,
        )

        assert isinstance(result, Person)

0 commit comments

Comments
 (0)