Commit eb2cee3

feat: add LRU cache memory management to ServiceContainer
- Implement per-service-type LRU cache with configurable limits (default: 5)
- Add automatic eviction of oldest cached services when capacity is reached
- Support service cleanup during eviction for proper resource management
- Add cache statistics monitoring via get_cache_stats()
- Support selective cache clearing by service type
- Maintain backward compatibility with existing service container API

Enhanced test coverage:

- Add comprehensive tests for LRU eviction behavior
- Test cache access order updates and proper LRU semantics
- Verify service cleanup during eviction
- Test cache isolation between service types
- Add cache statistics and monitoring tests
- Validate selective vs. full cache clearing
1 parent 9de20f8 commit eb2cee3
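For context, here is a usage sketch of the API this commit introduces. The constructor parameter, methods, and caching behavior are taken from the diff below; the import path and the availability of `LLMService` from this module are assumptions for illustration only.

```python
# Illustrative sketch -- import locations are assumptions, not verified exports.
from quivr_core.rag.langgraph_framework.services.service_container import (
    ServiceContainer,
    LLMService,
)

# Cap each service type's cache at 3 instances (default is 5).
container = ServiceContainer(max_cache_per_service=3)

# With no config, the instance is cached under the "singleton" key for its type.
llm = container.get_service(LLMService)

# A repeat call is a cache hit; the entry is marked most recently used.
assert container.get_service(LLMService) is llm

# New monitoring hook: per-service-type occupancy, capacity, and cache keys.
print(container.get_cache_stats())
# {'LLMService': {'cached_instances': 1, 'max_capacity': 3,
#                 'cache_keys': ['singleton']}}

# New selective clearing: evict (and cleanup(), if defined) only LLMService
# instances; clear_cache() with no argument still clears everything.
container.clear_cache(LLMService)
```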

File tree

2 files changed: +343 -57 lines changed


core/quivr_core/rag/langgraph_framework/services/service_container.py

Lines changed: 103 additions & 51 deletions
```diff
@@ -1,6 +1,7 @@
 from typing import Dict, Type, TypeVar, Any, Optional
 from abc import ABC, abstractmethod
 import logging
+from collections import OrderedDict as OrderedDictImpl
 from quivr_core.rag.entities.config import LLMEndpointConfig, WorkflowConfig
 from quivr_core.rag.langgraph_framework.entities.retrieval_service_config import (
     RetrievalServiceConfig,
@@ -77,15 +78,17 @@ def get_config_type(self) -> Optional[Type]:
 
 
 class ServiceContainer:
-    """Dependency injection container for services."""
+    """Dependency injection container for services with LRU cache per service type."""
 
-    def __init__(self, vector_store=None):
-        self._services: Dict[tuple, Any] = {}  # Changed to support tuple keys
+    def __init__(self, vector_store=None, max_cache_per_service: int = 5):
+        # Use OrderedDict for LRU cache behavior per service type
+        self._services: Dict[Type, OrderedDictImpl[str, Any]] = {}
         self._factories: Dict[Type, ServiceFactory] = {
             LLMService: LLMServiceFactory(),
             ToolService: ToolServiceFactory(),
             RAGPromptService: PromptServiceFactory(),
         }
+        self._max_cache_per_service = max_cache_per_service
 
         # Register RetrieverService factory if vector_store is provided
         if vector_store:
@@ -101,49 +104,60 @@ def register_vector_store(self, vector_store):
         """Register a vector store and enable RetrievalService."""
         self._factories[RetrievalService] = RetrievalServiceFactory(vector_store)
 
+    def _get_service_cache(self, service_type: Type) -> OrderedDictImpl[str, Any]:
+        """Get or create the cache for a specific service type."""
+        if service_type not in self._services:
+            self._services[service_type] = OrderedDictImpl()
+        return self._services[service_type]
+
+    def _evict_oldest_if_needed(self, service_cache: OrderedDictImpl[str, Any]) -> None:
+        """Remove the oldest cached service if cache is at capacity."""
+        if len(service_cache) >= self._max_cache_per_service:
+            oldest_key = next(iter(service_cache))
+            removed_service = service_cache.pop(oldest_key)
+            logger.debug(f"Evicted oldest cached service: {oldest_key}")
+            # Clean up the service if it has cleanup methods
+            if hasattr(removed_service, "cleanup"):
+                try:
+                    removed_service.cleanup()
+                except Exception as e:
+                    logger.warning(f"Error cleaning up evicted service: {e}")
+
     def get_service(self, service_type: Type[T], config: Optional[Any] = None) -> T:
-        """Get or create a service instance with config change detection."""
+        """Get or create a service instance with LRU cache per service type."""
        import hashlib
        import json
 
-        # If no config is provided, use singleton pattern
-        if config is None:
-            cache_key = (service_type, "singleton")
-            if cache_key not in self._services:
-                if service_type not in self._factories:
-                    raise ValueError(
-                        f"No factory registered for service type: {service_type}"
-                    )
-
-                factory = self._factories[service_type]
-                logger.debug(f"Creating singleton instance of {service_type.__name__}")
-                service = factory.create(None)
-                self._services[cache_key] = service
-
-            return self._services[cache_key]
-
-        # Create config hash for change detection when config is provided
-        config_dict = (
-            config.model_dump() if hasattr(config, "model_dump") else str(config)
-        )
-        config_hash = hashlib.md5(
-            json.dumps(config_dict, sort_keys=True).encode()
-        ).hexdigest()
-
-        # Check if we need to recreate the service
-        cache_key = (service_type, config_hash)
-        if (
-            cache_key not in self._services
-            or self._config_hashes.get(service_type) != config_hash
-        ):
-            if service_type not in self._factories:
-                raise ValueError(
-                    f"No factory registered for service type: {service_type}"
-                )
+        if service_type not in self._factories:
+            raise ValueError(f"No factory registered for service type: {service_type}")
 
-            factory = self._factories[service_type]
+        # Get the cache for this service type
+        service_cache = self._get_service_cache(service_type)
 
-            # Validate config type (skip validation if factory doesn't specify a config type)
+        # Determine cache key
+        if config is None:
+            cache_key = "singleton"
+        else:
+            config_dict = (
+                config.model_dump() if hasattr(config, "model_dump") else str(config)
+            )
+            cache_key = hashlib.md5(
+                json.dumps(config_dict, sort_keys=True).encode()
+            ).hexdigest()
+
+        # Check if service exists in cache
+        if cache_key in service_cache:
+            # Move to end (most recently used)
+            service = service_cache.pop(cache_key)
+            service_cache[cache_key] = service
+            logger.debug(f"Retrieved cached {service_type.__name__} instance")
+            return service
+
+        # Service not in cache, create new instance
+        factory = self._factories[service_type]
+
+        # Validate config type (skip validation if factory doesn't specify a config type)
+        if config is not None:
            expected_config_type = factory.get_config_type()
            if expected_config_type is not None and not isinstance(
                config, expected_config_type
@@ -152,14 +166,52 @@ def get_service(self, service_type: Type[T], config: Optional[Any] = None) -> T:
                 f"Expected config of type {expected_config_type}, got {type(config)}"
             )
 
-        logger.debug(f"Creating new instance of {service_type.__name__}")
-        service = factory.create(config)
-        self._services[cache_key] = service
-        self._config_hashes[service_type] = config_hash
-
-        return self._services[cache_key]
-
-    def clear_cache(self):
-        """Clear all cached services."""
-        self._services.clear()
-        self._config_hashes.clear()
+        # Evict oldest if at capacity
+        self._evict_oldest_if_needed(service_cache)
+
+        # Create new service
+        logger.debug(f"Creating new {service_type.__name__} instance")
+        service = factory.create(config)
+        service_cache[cache_key] = service
+
+        return service
+
+    def clear_cache(self, service_type: Optional[Type] = None):
+        """Clear cached services. If service_type is None, clear all caches."""
+        if service_type is None:
+            # Clean up all services before clearing
+            for service_cache in self._services.values():
+                for service in service_cache.values():
+                    if hasattr(service, "cleanup"):
+                        try:
+                            service.cleanup()
+                        except Exception as e:
+                            logger.warning(
+                                f"Error cleaning up service during cache clear: {e}"
+                            )
+            self._services.clear()
+            self._config_hashes.clear()
+        else:
+            # Clear cache for specific service type
+            if service_type in self._services:
+                service_cache = self._services[service_type]
+                for service in service_cache.values():
+                    if hasattr(service, "cleanup"):
+                        try:
+                            service.cleanup()
+                        except Exception as e:
+                            logger.warning(
+                                f"Error cleaning up {service_type.__name__} service: {e}"
+                            )
+                service_cache.clear()
+
+    def get_cache_stats(self) -> Dict[str, Dict[str, Any]]:
+        """Get cache statistics for monitoring."""
+        stats = {}
+        for service_type, service_cache in self._services.items():
+            stats[service_type.__name__] = {
+                "cached_instances": len(service_cache),
+                "max_capacity": self._max_cache_per_service,
+                "cache_keys": list(service_cache.keys()),
+            }
+        return stats
```
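The LRU bookkeeping rests on two `OrderedDict` idioms visible in `get_service` and `_evict_oldest_if_needed`: popping a key and reinserting it moves the entry to the most-recently-used end, and `next(iter(cache))` yields the least-recently-used key at the front. A minimal self-contained sketch of the same pattern, with toy string values standing in for service instances:

```python
from collections import OrderedDict

cache: OrderedDict[str, str] = OrderedDict()
capacity = 2  # stands in for max_cache_per_service

for key in ["a", "b", "a", "c"]:
    if key in cache:
        # Cache hit: pop and reinsert to move the entry to the MRU end.
        cache[key] = cache.pop(key)
    else:
        # Cache miss: evict the LRU entry (front of the OrderedDict) if full.
        if len(cache) >= capacity:
            evicted = cache.pop(next(iter(cache)))
            # The container would call evicted.cleanup() here, if it exists.
        cache[key] = f"service-{key}"

print(list(cache))  # ['a', 'c'] -- 'b' was evicted; the hit on 'a' saved it
```

`OrderedDict` also provides `move_to_end(key)` for the hit path; the pop-and-reinsert form used in the commit is behaviorally equivalent.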
