 
 import json
 import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Any, TypeVar
 
@@ -41,14 +42,19 @@ def _get_structured_output_cache_path(dirname: str) -> Path:
 class StructuredOutputCache:
     """Cache for structured output results."""
 
-    def __init__(self, use_cache: bool = True) -> None:
+    def __init__(self, use_cache: bool = True, max_workers: int | None = None, batch_size: int = 100) -> None:
         """Initialize the cache.
 
         Args:
             use_cache: Whether to use caching.
+            max_workers: Maximum number of worker threads for parallel loading.
+                If None, uses min(32, os.cpu_count() + 4).
+            batch_size: Number of cache files to process in each batch.
         """
         self.use_cache = use_cache
         self._memory_cache: dict[str, BaseModel] = {}
+        self.max_workers = max_workers
+        self.batch_size = batch_size
 
         if self.use_cache:
             self._load_existing_cache()
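For reference, a minimal usage sketch of the two new constructor knobs. The module path is hypothetical; the `min(32, os.cpu_count() + 4)` default is simply what `ThreadPoolExecutor` picks on its own when `max_workers=None`:

```python
from structured_output_cache import StructuredOutputCache  # hypothetical import path

# Defaults: the executor picks the worker count (min(32, os.cpu_count() + 4))
# and cache files are loaded in batches of 100.
cache = StructuredOutputCache(use_cache=True)

# Tuned: cap the pool at 8 threads and submit 500 files per batch,
# trading higher per-batch memory for fewer batch round-trips.
tuned_cache = StructuredOutputCache(use_cache=True, max_workers=8, batch_size=500)
```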
@@ -60,16 +66,59 @@ def _load_existing_cache(self) -> None:
         if not cache_dir.exists():
             return
 
-        for cache_file in cache_dir.iterdir():
-            if cache_file.is_file():
-                try:
-                    cached_data = PydanticModelDumper.load(cache_file)
-                    if isinstance(cached_data, BaseModel):
-                        self._memory_cache[cache_file.name] = cached_data
-                        logger.debug("Loaded cached item into memory: %s", cache_file.name)
-                except (ValidationError, ImportError) as e:
-                    logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
-                    cache_file.unlink(missing_ok=True)
+        # Get all cache files to process
+        cache_files = [f for f in cache_dir.iterdir() if f.is_file()]
+
+        if not cache_files:
+            return
+
+        logger.debug("Loading %d cache files in batches of %d", len(cache_files), self.batch_size)
+
+        # Batch the submissions so pending futures never exceed batch_size at once
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            self._load_cache_batch(executor, cache_files)
+
+        logger.debug("Finished loading cache, %d items in memory", len(self._memory_cache))
+
+    def _load_cache_batch(self, executor: ThreadPoolExecutor, cache_files: list[Path]) -> None:
+        """Load cache files in batches using the provided executor.
+
+        Args:
+            executor: ThreadPoolExecutor to use for parallel processing.
+            cache_files: List of cache files to load.
+        """
+        for i in range(0, len(cache_files), self.batch_size):
+            batch = cache_files[i : i + self.batch_size]
+
+            # Submit batch of cache loading tasks
+            futures = [executor.submit(self._load_single_cache_file, cache_file) for cache_file in batch]
+
+            # Process completed tasks in this batch
+            for future in as_completed(futures):
+                result = future.result()
+                if result is not None:
+                    filename, cached_data = result
+                    self._memory_cache[filename] = cached_data
+                    logger.debug("Loaded cached item into memory: %s", filename)
+
+    def _load_single_cache_file(self, cache_file: Path) -> tuple[str, BaseModel] | None:
+        """Load a single cache file and return the result.
+
+        Args:
+            cache_file: Path to the cache file to load.
+
+        Returns:
+            Tuple of (filename, cached_data) if successful, None if failed.
+        """
+        try:
+            cached_data = PydanticModelDumper.load(cache_file)
+            if isinstance(cached_data, BaseModel):
+                return cache_file.name, cached_data
+        except (ValidationError, ImportError) as e:
+            logger.warning("Failed to load cached item %s: %s", cache_file.name, e)
+            cache_file.unlink(missing_ok=True)
+
+        return None
 
     def _get_cache_key(self, messages: list[Message], output_model: type[T], generation_params: dict[str, Any]) -> str:
         """Generate a cache key for the given parameters.
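The batching pattern this commit introduces — submit one batch of futures, drain it with `as_completed`, then move on — is straightforward to exercise in isolation. Below is a minimal self-contained sketch of the same technique, assuming nothing from this codebase (the file-reading payload and all names are illustrative only):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path


def load_in_batches(files: list[Path], batch_size: int = 100, max_workers: int | None = None) -> dict[str, bytes]:
    """Read files in parallel while keeping at most batch_size futures pending."""
    results: dict[str, bytes] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in range(0, len(files), batch_size):
            batch = files[i : i + batch_size]
            # Map each future back to its source file so failures can be attributed.
            futures = {executor.submit(f.read_bytes): f for f in batch}
            for future in as_completed(futures):
                path = futures[future]
                try:
                    # result() re-raises any exception from the worker thread, so
                    # either catch it here or, as the commit does, catch the
                    # expected errors inside the task and return None instead.
                    results[path.name] = future.result()
                except OSError as exc:
                    print(f"skipping {path.name}: {exc}")
    return results
```

Draining each batch before submitting the next bounds the memory held by queued work items; the executor itself already bounds concurrency through `max_workers`.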