Skip to content

Commit eb0a441

Browse files
authored
Feat/efficient caching (#239)
* fix retries test * remove vllm backend-related functionality * fix typing errors * ram-stored cache * add async disk I/O * try to fix typing
1 parent 43b671c commit eb0a441

File tree

7 files changed

+218
-105
lines changed

7 files changed

+218
-105
lines changed

autointent/_dump_tools/unit_dumpers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import Any, TypeVar
66

7+
import aiofiles
78
import joblib
89
import numpy as np
910
import numpy.typing as npt
@@ -163,6 +164,15 @@ def dump(obj: BaseModel, path: Path, exists_ok: bool) -> None:
163164
with (path / "model_dump.json").open("w", encoding="utf-8") as file:
164165
json.dump(obj.model_dump(), file, ensure_ascii=False, indent=4)
165166

167+
@staticmethod
168+
async def dump_async(obj: BaseModel, path: Path, exists_ok: bool) -> None:
169+
class_info = {"name": obj.__class__.__name__, "module": obj.__class__.__module__}
170+
path.mkdir(parents=True, exist_ok=exists_ok)
171+
async with aiofiles.open(path / "class_info.json", mode="w", encoding="utf-8") as file:
172+
await file.write(json.dumps(class_info, ensure_ascii=False, indent=4))
173+
async with aiofiles.open(path / "model_dump.json", mode="w", encoding="utf-8") as file:
174+
await file.write(json.dumps(obj.model_dump(), ensure_ascii=False, indent=4))
175+
166176
@staticmethod
167177
def load(path: Path, **kwargs: Any) -> BaseModel: # noqa: ANN401, ARG004
168178
with (path / "model_dump.json").open("r", encoding="utf-8") as file:
@@ -175,6 +185,20 @@ def load(path: Path, **kwargs: Any) -> BaseModel: # noqa: ANN401, ARG004
175185
model_type = getattr(model_type, class_info["name"])
176186
return model_type.model_validate(content) # type: ignore[no-any-return]
177187

188+
@staticmethod
189+
async def load_async(path: Path, **kwargs: Any) -> BaseModel: # noqa: ANN401, ARG004
190+
async with aiofiles.open(path / "model_dump.json", encoding="utf-8") as file:
191+
content_str = await file.read()
192+
content = json.loads(content_str)
193+
194+
async with aiofiles.open(path / "class_info.json", encoding="utf-8") as file:
195+
class_info_str = await file.read()
196+
class_info = json.loads(class_info_str)
197+
198+
model_type = importlib.import_module(class_info["module"])
199+
model_type = getattr(model_type, class_info["name"])
200+
return model_type.model_validate(content) # type: ignore[no-any-return]
201+
178202
@classmethod
179203
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
180204
return isinstance(obj, BaseModel)

autointent/generation/_cache.py

Lines changed: 180 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,35 @@ def __init__(self, use_cache: bool = True) -> None:
4848
use_cache: Whether to use caching.
4949
"""
5050
self.use_cache = use_cache
51+
self._memory_cache: dict[str, BaseModel] = {}
5152

52-
def _get_cache_key(
53-
self, messages: list[Message], output_model: type[T], backend: str, generation_params: dict[str, Any]
54-
) -> str:
53+
if self.use_cache:
54+
self._load_existing_cache()
55+
56+
def _load_existing_cache(self) -> None:
57+
"""Load all existing cache items from disk into memory."""
58+
cache_dir = Path(user_cache_dir("autointent")) / "structured_outputs"
59+
60+
if not cache_dir.exists():
61+
return
62+
63+
for cache_file in cache_dir.iterdir():
64+
if cache_file.is_file():
65+
try:
66+
cached_data = PydanticModelDumper.load(cache_file)
67+
if isinstance(cached_data, BaseModel):
68+
self._memory_cache[cache_file.name] = cached_data
69+
logger.debug("Loaded cached item into memory: %s", cache_file.name)
70+
except (ValidationError, ImportError) as e:
71+
logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
72+
cache_file.unlink(missing_ok=True)
73+
74+
def _get_cache_key(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> str:
5575
"""Generate a cache key for the given parameters.
5676
5777
Args:
5878
messages: List of messages to send to the model.
5979
output_model: Pydantic model class to parse the response into.
60-
backend: Backend to use for structured output.
6180
generation_params: Generation parameters.
6281
6382
Returns:
@@ -66,19 +85,76 @@ def _get_cache_key(
6685
hasher = Hasher(strict=True)
6786
hasher.update(json.dumps(messages))
6887
hasher.update(json.dumps(output_model.model_json_schema()))
69-
hasher.update(backend)
7088
hasher.update(json.dumps(generation_params))
7189
return hasher.hexdigest()
7290

73-
def get(
74-
self, messages: list[Message], output_model: type[T], backend: str, generation_params: dict[str, Any]
75-
) -> T | None:
91+
def _check_memory_cache(self, cache_key: str, output_model: type[T]) -> T | None:
92+
"""Check if the result is available in memory cache.
93+
94+
Args:
95+
cache_key: The cache key to look up.
96+
output_model: Pydantic model class to parse the response into.
97+
98+
Returns:
99+
Cached result if available and valid, None otherwise.
100+
"""
101+
if cache_key in self._memory_cache:
102+
cached_data = self._memory_cache[cache_key]
103+
if isinstance(cached_data, output_model):
104+
logger.debug("Using cached structured output from memory for key: %s", cache_key)
105+
return cached_data
106+
# Type mismatch, remove from memory cache
107+
del self._memory_cache[cache_key]
108+
logger.warning("Cached data type mismatch in memory, removing invalid cache")
109+
return None
110+
111+
def _load_from_disk(self, cache_key: str, output_model: type[T]) -> T | None:
112+
"""Load cached result from disk.
113+
114+
Args:
115+
cache_key: The cache key to look up.
116+
output_model: Pydantic model class to parse the response into.
117+
118+
Returns:
119+
Cached result if available and valid, None otherwise.
120+
"""
121+
cache_path = _get_structured_output_cache_path(cache_key)
122+
123+
if cache_path.exists():
124+
try:
125+
cached_data = PydanticModelDumper.load(cache_path)
126+
127+
if isinstance(cached_data, output_model):
128+
logger.debug("Using cached structured output from disk for key: %s", cache_key)
129+
# Add to memory cache for future access
130+
self._memory_cache[cache_key] = cached_data
131+
return cached_data
132+
133+
logger.warning("Cached data type mismatch on disk, removing invalid cache")
134+
cache_path.unlink()
135+
except (ValidationError, ImportError) as e:
136+
logger.warning("Failed to load cached structured output from disk: %s", e)
137+
cache_path.unlink(missing_ok=True)
138+
139+
return None
140+
141+
def _save_to_disk(self, cache_key: str, result: T) -> None:
142+
"""Save result to disk cache.
143+
144+
Args:
145+
cache_key: The cache key to use.
146+
result: The result to cache.
147+
"""
148+
cache_path = _get_structured_output_cache_path(cache_key)
149+
cache_path.parent.mkdir(parents=True, exist_ok=True)
150+
PydanticModelDumper.dump(result, cache_path, exists_ok=True)
151+
152+
def get(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> T | None:
76153
"""Get cached result if available.
77154
78155
Args:
79156
messages: List of messages to send to the model.
80157
output_model: Pydantic model class to parse the response into.
81-
backend: Backend to use for structured output.
82158
generation_params: Generation parameters.
83159
84160
Returns:
@@ -87,29 +163,109 @@ def get(
87163
if not self.use_cache:
88164
return None
89165

90-
cache_key = self._get_cache_key(messages, output_model, backend, generation_params)
166+
cache_key = self._get_cache_key(messages, output_model, generation_params)
167+
168+
# First check in-memory cache
169+
memory_result = self._check_memory_cache(cache_key, output_model)
170+
if memory_result is not None:
171+
return memory_result
172+
173+
# Fallback to disk cache
174+
return self._load_from_disk(cache_key, output_model)
175+
176+
def set(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T) -> None:
177+
"""Cache the result.
178+
179+
Args:
180+
messages: List of messages to send to the model.
181+
output_model: Pydantic model class to parse the response into.
182+
backend: Backend to use for structured output.
183+
generation_params: Generation parameters.
184+
result: The result to cache.
185+
"""
186+
if not self.use_cache:
187+
return
188+
189+
cache_key = self._get_cache_key(messages, output_model, generation_params)
190+
191+
# Store in memory cache
192+
self._memory_cache[cache_key] = result
193+
194+
# Store in disk cache
195+
self._save_to_disk(cache_key, result)
196+
logger.debug("Cached structured output for key: %s (memory and disk)", cache_key)
197+
198+
async def _load_from_disk_async(self, cache_key: str, output_model: type[T]) -> T | None:
199+
"""Load cached result from disk asynchronously.
200+
201+
Args:
202+
cache_key: The cache key to look up.
203+
output_model: Pydantic model class to parse the response into.
204+
205+
Returns:
206+
Cached result if available and valid, None otherwise.
207+
"""
91208
cache_path = _get_structured_output_cache_path(cache_key)
92209

93210
if cache_path.exists():
94211
try:
95-
cached_data = PydanticModelDumper.load(cache_path)
212+
cached_data = await PydanticModelDumper.load_async(cache_path)
96213

97214
if isinstance(cached_data, output_model):
98-
logger.debug("Using cached structured output for key: %s", cache_key)
215+
logger.debug("Using cached structured output from disk for key: %s", cache_key)
216+
# Add to memory cache for future access
217+
self._memory_cache[cache_key] = cached_data
99218
return cached_data
100219

101-
logger.warning("Cached data type mismatch, removing invalid cache")
220+
logger.warning("Cached data type mismatch on disk, removing invalid cache")
102221
cache_path.unlink()
103222
except (ValidationError, ImportError) as e:
104-
logger.warning("Failed to load cached structured output: %s", e)
223+
logger.warning("Failed to load cached structured output from disk: %s", e)
105224
cache_path.unlink(missing_ok=True)
106225

107226
return None
108227

109-
def set(
110-
self, messages: list[Message], output_model: type[T], backend: str, generation_params: dict[str, Any], result: T
228+
async def _save_to_disk_async(self, cache_key: str, result: T) -> None:
229+
"""Save result to disk cache asynchronously.
230+
231+
Args:
232+
cache_key: The cache key to use.
233+
result: The result to cache.
234+
"""
235+
cache_path = _get_structured_output_cache_path(cache_key)
236+
cache_path.parent.mkdir(parents=True, exist_ok=True)
237+
await PydanticModelDumper.dump_async(result, cache_path, exists_ok=True)
238+
239+
async def get_async(
240+
self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]
241+
) -> T | None:
242+
"""Get cached result if available (async version).
243+
244+
Args:
245+
messages: List of messages to send to the model.
246+
output_model: Pydantic model class to parse the response into.
247+
generation_params: Generation parameters.
248+
249+
Returns:
250+
Cached result if available, None otherwise.
251+
"""
252+
if not self.use_cache:
253+
return None
254+
255+
cache_key = self._get_cache_key(messages, output_model, generation_params)
256+
257+
# First check in-memory cache
258+
memory_result = self._check_memory_cache(cache_key, output_model)
259+
if memory_result is not None:
260+
return memory_result
261+
262+
# Fallback to disk cache
263+
return await self._load_from_disk_async(cache_key, output_model)
264+
265+
async def set_async(
266+
self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T
111267
) -> None:
112-
"""Cache the result.
268+
"""Cache the result (async version).
113269
114270
Args:
115271
messages: List of messages to send to the model.
@@ -121,9 +277,11 @@ def set(
121277
if not self.use_cache:
122278
return
123279

124-
cache_key = self._get_cache_key(messages, output_model, backend, generation_params)
125-
cache_path = _get_structured_output_cache_path(cache_key)
280+
cache_key = self._get_cache_key(messages, output_model, generation_params)
126281

127-
cache_path.parent.mkdir(parents=True, exist_ok=True)
128-
PydanticModelDumper.dump(result, cache_path, exists_ok=True)
129-
logger.debug("Cached structured output for key: %s", cache_key)
282+
# Store in memory cache
283+
self._memory_cache[cache_key] = result
284+
285+
# Store in disk cache
286+
await self._save_to_disk_async(cache_key, result)
287+
logger.debug("Cached structured output for key: %s (memory and disk)", cache_key)

0 commit comments

Comments
 (0)