Skip to content

Commit ac958b4

Browse files
authored
Feat/faster cache loading from disk (#241)
* implement threadpool loading * fix typing
1 parent 03c2fbc commit ac958b4

File tree

1 file changed

+60
-11
lines changed

1 file changed

+60
-11
lines changed

autointent/generation/_cache.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import logging
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
56
from pathlib import Path
67
from typing import Any, TypeVar
78

@@ -41,14 +42,19 @@ def _get_structured_output_cache_path(dirname: str) -> Path:
4142
class StructuredOutputCache:
4243
"""Cache for structured output results."""
4344

44-
def __init__(self, use_cache: bool = True) -> None:
45+
def __init__(self, use_cache: bool = True, max_workers: int | None = None, batch_size: int = 100) -> None:
    """Initialize the cache.

    Args:
        use_cache: Whether to use caching.
        max_workers: Maximum number of worker threads for parallel loading.
            If None, uses the ThreadPoolExecutor default of min(32, os.cpu_count() + 4).
        batch_size: Number of cache files to process in each batch. Must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # _load_cache_batch iterates with range(0, n, batch_size): a step of 0
        # raises ValueError there, and a negative step silently loads nothing.
        # Fail fast here with a clear message instead.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    self.use_cache = use_cache
    # In-memory view of the on-disk cache, keyed by cache-file name.
    self._memory_cache: dict[str, BaseModel] = {}
    self.max_workers = max_workers
    self.batch_size = batch_size

    if self.use_cache:
        self._load_existing_cache()
@@ -60,16 +66,59 @@ def _load_existing_cache(self) -> None:
6066
if not cache_dir.exists():
6167
return
6268

63-
for cache_file in cache_dir.iterdir():
64-
if cache_file.is_file():
65-
try:
66-
cached_data = PydanticModelDumper.load(cache_file)
67-
if isinstance(cached_data, BaseModel):
68-
self._memory_cache[cache_file.name] = cached_data
69-
logger.debug("Loaded cached item into memory: %s", cache_file.name)
70-
except (ValidationError, ImportError) as e:
71-
logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
72-
cache_file.unlink(missing_ok=True)
69+
# Get all cache files to process
70+
cache_files = [f for f in cache_dir.iterdir() if f.is_file()]
71+
72+
if not cache_files:
73+
return
74+
75+
logger.debug("Loading %d cache files in batches of %d", len(cache_files), self.batch_size)
76+
77+
# Process cache files in batches to avoid resource exhaustion
78+
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
79+
self._load_cache_batch(executor, cache_files)
80+
81+
logger.debug("Finished loading cache, %d items in memory", len(self._memory_cache))
82+
83+
def _load_cache_batch(self, executor: ThreadPoolExecutor, cache_files: list[Path]) -> None:
    """Load cache files in batches using the provided executor.

    Args:
        executor: ThreadPoolExecutor to use for parallel processing.
        cache_files: List of cache files to load.
    """
    step = self.batch_size
    for start in range(0, len(cache_files), step):
        chunk = cache_files[start : start + step]

        # Fan out one loading task per file in this chunk only, so at most
        # `step` futures are outstanding at a time.
        pending = [executor.submit(self._load_single_cache_file, path) for path in chunk]

        # Drain the chunk as workers finish; failed loads come back as None.
        for done in as_completed(pending):
            loaded = done.result()
            if loaded is None:
                continue
            name, model = loaded
            self._memory_cache[name] = model
            logger.debug("Loaded cached item into memory: %s", name)
103+
104+
def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | None:
    """Load a single cache file and return the result.

    Args:
        cache_file: Path to the cache file to load.

    Returns:
        Tuple of (filename, cached_data) if successful, None if failed
        or if the file did not contain a pydantic model.
    """
    try:
        cached_data = PydanticModelDumper.load(cache_file)
    except (ValidationError, ImportError) as e:
        # Corrupt or stale entries are dropped so they are not retried
        # on every startup.
        logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
        cache_file.unlink(missing_ok=True)
        return None

    # Guard restored from the pre-threadpool implementation: only cache
    # pydantic models, keeping the declared return type honest.
    if isinstance(cached_data, BaseModel):
        return cache_file.name, cached_data
    return None
73122

74123
def _get_cache_key(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> str:
75124
"""Generate a cache key for the given parameters.

0 commit comments

Comments
 (0)