Skip to content

Commit f18d9eb

Browse files
committed
add async disk I/O
1 parent 183fe8f commit f18d9eb

File tree

4 files changed

+167
-15
lines changed

4 files changed

+167
-15
lines changed

autointent/_dump_tools/unit_dumpers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import Any, TypeVar
66

7+
import aiofiles
78
import joblib
89
import numpy as np
910
import numpy.typing as npt
@@ -163,6 +164,15 @@ def dump(obj: BaseModel, path: Path, exists_ok: bool) -> None:
163164
with (path / "model_dump.json").open("w", encoding="utf-8") as file:
164165
json.dump(obj.model_dump(), file, ensure_ascii=False, indent=4)
165166

167+
@staticmethod
async def dump_async(obj: BaseModel, path: Path, exists_ok: bool) -> None:
    """Persist a pydantic model to *path* without blocking the event loop.

    Writes two JSON files into the target directory: ``class_info.json``
    (the model's class name and module, for later re-import) and
    ``model_dump.json`` (the serialized model fields).

    Args:
        obj: Pydantic model instance to persist.
        path: Directory to write the files into (created if absent).
        exists_ok: Passed to ``Path.mkdir`` — whether an existing directory is allowed.
    """
    path.mkdir(parents=True, exist_ok=exists_ok)
    payloads = (
        ("class_info.json", {"name": obj.__class__.__name__, "module": obj.__class__.__module__}),
        ("model_dump.json", obj.model_dump()),
    )
    for filename, data in payloads:
        async with aiofiles.open(path / filename, mode="w", encoding="utf-8") as file:
            await file.write(json.dumps(data, ensure_ascii=False, indent=4))
175+
166176
@staticmethod
167177
def load(path: Path, **kwargs: Any) -> BaseModel: # noqa: ANN401, ARG004
168178
with (path / "model_dump.json").open("r", encoding="utf-8") as file:
@@ -175,6 +185,20 @@ def load(path: Path, **kwargs: Any) -> BaseModel: # noqa: ANN401, ARG004
175185
model_type = getattr(model_type, class_info["name"])
176186
return model_type.model_validate(content) # type: ignore[no-any-return]
177187

188+
@staticmethod
async def load_async(path: Path, **kwargs: Any) -> BaseModel:  # noqa: ANN401, ARG004
    """Restore a pydantic model previously written by ``dump_async``.

    Reads ``model_dump.json`` for the field values and ``class_info.json``
    to locate the concrete model class, then validates the payload.

    Args:
        path: Directory containing the two JSON files.
        **kwargs: Ignored; kept for signature compatibility with other loaders.

    Returns:
        The reconstructed pydantic model instance.
    """
    async with aiofiles.open(path / "model_dump.json", encoding="utf-8") as file:
        payload = json.loads(await file.read())

    async with aiofiles.open(path / "class_info.json", encoding="utf-8") as file:
        info = json.loads(await file.read())

    # Resolve the concrete class from its recorded module and name.
    module = importlib.import_module(info["module"])
    model_cls = getattr(module, info["name"])
    return model_cls.model_validate(payload)  # type: ignore[no-any-return]
201+
178202
@classmethod
179203
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
180204
return isinstance(obj, BaseModel)

autointent/generation/_cache.py

Lines changed: 140 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,23 +88,16 @@ def _get_cache_key(self, messages: list[Message], output_model: type[T], generat
8888
hasher.update(json.dumps(generation_params))
8989
return hasher.hexdigest()
9090

91-
def get(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> T | None:
92-
"""Get cached result if available.
91+
def _check_memory_cache(self, cache_key: str, output_model: type[T]) -> T | None:
92+
"""Check if the result is available in memory cache.
9393
9494
Args:
95-
messages: List of messages to send to the model.
95+
cache_key: The cache key to look up.
9696
output_model: Pydantic model class to parse the response into.
97-
generation_params: Generation parameters.
9897
9998
Returns:
100-
Cached result if available, None otherwise.
99+
Cached result if available and valid, None otherwise.
101100
"""
102-
if not self.use_cache:
103-
return None
104-
105-
cache_key = self._get_cache_key(messages, output_model, generation_params)
106-
107-
# First check in-memory cache
108101
if cache_key in self._memory_cache:
109102
cached_data = self._memory_cache[cache_key]
110103
if isinstance(cached_data, output_model):
@@ -113,8 +106,18 @@ def get(self, messages: list[Message], output_model: type[T], generation_params:
113106
# Type mismatch, remove from memory cache
114107
del self._memory_cache[cache_key]
115108
logger.warning("Cached data type mismatch in memory, removing invalid cache")
109+
return None
116110

117-
# Fallback to disk cache
111+
def _load_from_disk(self, cache_key: str, output_model: type[T]) -> T | None:
112+
"""Load cached result from disk.
113+
114+
Args:
115+
cache_key: The cache key to look up.
116+
output_model: Pydantic model class to parse the response into.
117+
118+
Returns:
119+
Cached result if available and valid, None otherwise.
120+
"""
118121
cache_path = _get_structured_output_cache_path(cache_key)
119122

120123
if cache_path.exists():
@@ -135,6 +138,41 @@ def get(self, messages: list[Message], output_model: type[T], generation_params:
135138

136139
return None
137140

141+
def _save_to_disk(self, cache_key: str, result: T) -> None:
    """Persist *result* at the disk-cache location derived from *cache_key*.

    Args:
        cache_key: The cache key to use.
        result: The result to cache.
    """
    target = _get_structured_output_cache_path(cache_key)
    # Make sure the cache directory exists before dumping.
    target.parent.mkdir(parents=True, exist_ok=True)
    PydanticModelDumper.dump(result, target, exists_ok=True)
151+
152+
def get(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> T | None:
    """Get cached result if available.

    Checks the in-memory cache first, then falls back to the disk cache.

    Args:
        messages: List of messages to send to the model.
        output_model: Pydantic model class to parse the response into.
        generation_params: Generation parameters.

    Returns:
        Cached result if available, None otherwise.
    """
    if not self.use_cache:
        return None

    key = self._get_cache_key(messages, output_model, generation_params)

    # An in-memory hit wins; otherwise try the slower disk cache.
    hit = self._check_memory_cache(key, output_model)
    return hit if hit is not None else self._load_from_disk(key, output_model)
175+
138176
def set(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T) -> None:
139177
"""Cache the result.
140178
@@ -154,7 +192,96 @@ def set(self, messages: list[Message], output_model: type[T], generation_params:
154192
self._memory_cache[cache_key] = result
155193

156194
# Store in disk cache
195+
self._save_to_disk(cache_key, result)
196+
logger.debug("Cached structured output for key: %s (memory and disk)", cache_key)
197+
198+
async def _load_from_disk_async(self, cache_key: str, output_model: type[T]) -> T | None:
    """Load cached result from disk asynchronously.

    Args:
        cache_key: The cache key to look up.
        output_model: Pydantic model class to parse the response into.

    Returns:
        Cached result if available and valid, None otherwise.
    """
    cache_path = _get_structured_output_cache_path(cache_key)
    if not cache_path.exists():
        return None

    try:
        cached_data = await PydanticModelDumper.load_async(cache_path)
    except (ValidationError, ImportError) as e:
        # Corrupt or stale entry — drop it so it is regenerated next time.
        logger.warning("Failed to load cached structured output from disk: %s", e)
        cache_path.unlink(missing_ok=True)
        return None

    if isinstance(cached_data, output_model):
        logger.debug("Using cached structured output from disk for key: %s", cache_key)
        # Promote to the in-memory cache so later lookups skip the disk.
        self._memory_cache[cache_key] = cached_data
        return cached_data

    logger.warning("Cached data type mismatch on disk, removing invalid cache")
    cache_path.unlink()
    return None
227+
228+
async def _save_to_disk_async(self, cache_key: str, result: T) -> None:
    """Save result to disk cache asynchronously.

    Args:
        cache_key: The cache key to use.
        result: The result to cache.
    """
    target = _get_structured_output_cache_path(cache_key)
    # Directory creation stays synchronous; only the file writes are async.
    target.parent.mkdir(parents=True, exist_ok=True)
    await PydanticModelDumper.dump_async(result, target, exists_ok=True)
238+
239+
async def get_async(
    self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]
) -> T | None:
    """Get cached result if available (async version).

    Checks the in-memory cache first, then falls back to the disk cache.

    Args:
        messages: List of messages to send to the model.
        output_model: Pydantic model class to parse the response into.
        generation_params: Generation parameters.

    Returns:
        Cached result if available, None otherwise.
    """
    if not self.use_cache:
        return None

    key = self._get_cache_key(messages, output_model, generation_params)

    # An in-memory hit wins; otherwise fall back to the async disk lookup.
    hit = self._check_memory_cache(key, output_model)
    if hit is not None:
        return hit
    return await self._load_from_disk_async(key, output_model)
264+
265+
async def set_async(
    self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any], result: T
) -> None:
    """Cache the result (async version).

    Stores the result in the in-memory cache and persists it to disk.

    Args:
        messages: List of messages to send to the model.
        output_model: Pydantic model class to parse the response into.
        generation_params: Generation parameters.
        result: The result to cache.
    """
    if not self.use_cache:
        return

    cache_key = self._get_cache_key(messages, output_model, generation_params)

    # Store in memory cache
    self._memory_cache[cache_key] = result

    # Store in disk cache
    await self._save_to_disk_async(cache_key, result)
    logger.debug("Cached structured output for key: %s (memory and disk)", cache_key)

autointent/generation/_generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ async def get_structured_output_async(
192192
Parsed response as an instance of the provided Pydantic model.
193193
"""
194194
# Check cache first
195-
cached_result = self.cache.get(messages, output_model, self.generation_params)
195+
cached_result = await self.cache.get_async(messages, output_model, self.generation_params)
196196
if cached_result is not None:
197197
return cached_result
198198

@@ -221,7 +221,7 @@ async def get_structured_output_async(
221221
raise RetriesExceededError(max_retries=max_retries, messages=current_messages)
222222

223223
# Cache the successful result
224-
self.cache.set(messages, output_model, self.generation_params, res)
224+
await self.cache.set_async(messages, output_model, self.generation_params, res)
225225

226226
return res
227227

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)",
4848
"catboost (>=1.2.8,<2.0.0)",
4949
"aiometer (>=1.0.0,<2.0.0)",
50+
"aiofiles (>=24.1.0,<25.0.0)",
5051
]
5152

5253
[project.optional-dependencies]

0 commit comments

Comments
 (0)