Skip to content

Commit f76bc6f

Browse files
authored
Add general API for by-animal model modifiers and prompt caching (#201)
* Add initial implementation for generic model modifier
* Implement generic `ByAnimalModifier`
* Move modifiers to their own module
* Implement cached settings
* Add CLI interface for cache
* Add tests for clearing cache
* Implement cache manager in picker
* Add missing docstrings
1 parent f1adc66 commit f76bc6f

File tree

9 files changed

+1058
-38
lines changed

9 files changed

+1058
-38
lines changed

src/clabe/cache_manager.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
import logging
2+
import threading
3+
from enum import Enum
4+
from pathlib import Path
5+
from typing import Any, ClassVar, Generic, TypeVar
6+
7+
from pydantic import BaseModel, Field
8+
from pydantic_settings import BaseSettings, CliApp, CliSubCommand
9+
10+
from .constants import TMP_DIR
11+
12+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type variable for the values stored in a ``CachedSettings`` cache.
T = TypeVar("T")
15+
16+
17+
class SyncStrategy(str, Enum):
    """
    How cached data is persisted to disk.

    Members are ``str`` instances as well as enum members, so each one
    compares equal to its value (``SyncStrategy.AUTO == "auto"``).
    """

    # Write to disk only when a save is requested explicitly.
    MANUAL = "manual"
    # Write to disk after every modification.
    AUTO = "auto"
22+
23+
24+
class CachedSettings(BaseModel, Generic[T]):
    """
    Bounded, newest-first history of values.

    Adding a value places it at the front of the list; once the list grows
    beyond ``max_history`` the entries at the tail (the oldest) are dropped.

    Attributes:
        values: Cached values ordered from newest to oldest
        max_history: Upper bound on the number of retained values (must be > 0)

    Example:
        >>> cache = CachedSettings[str](max_history=3)
        >>> cache.add("first")
        >>> cache.add("second")
        >>> cache.get_all()
        ['second', 'first']
    """

    values: list[T] = Field(default_factory=list)
    max_history: int = Field(default=5, gt=0)

    def add(self, value: T) -> None:
        """
        Insert a value at the front of the cache.

        An equal entry that is already present is removed first, so re-adding
        a value moves it to the front instead of duplicating it. The list is
        then cut back to at most ``max_history`` entries.

        Args:
            value: The value to store
        """
        history = self.values
        if value in history:
            history.remove(value)
        history.insert(0, value)

        # Only rebind the field when the limit was actually exceeded.
        if len(history) > self.max_history:
            self.values = history[: self.max_history]

    def get_all(self) -> list[T]:
        """
        Return a snapshot of the cached values.

        Returns:
            A new list with all cached values, newest first
        """
        return list(self.values)

    def get_latest(self) -> T | None:
        """
        Return the value that was added most recently.

        Returns:
            The newest value, or None when nothing is cached
        """
        if not self.values:
            return None
        return self.values[0]

    def clear(self) -> None:
        """Drop every cached value."""
        self.values = []
83+
84+
85+
class CacheData(BaseModel):
    """Serializable container holding every named cache."""

    # Maps each cache name to its stored history of values.
    caches: dict[str, CachedSettings[Any]] = Field(default_factory=dict)
89+
90+
91+
class CacheManager:
    """
    Singleton manager for multiple named caches, persisted to a JSON file.

    Caches are kept in memory as ``CachedSettings`` objects and serialized
    through the ``CacheData`` Pydantic model. Public instance methods acquire a
    per-instance re-entrant lock; creation of the singleton in ``get_instance``
    is guarded by a class-level lock.

    Example:
        >>> # Get singleton instance with auto sync (default) - saves after every change
        >>> manager = CacheManager.get_instance()
        >>> manager.add_to_cache("subjects", "mouse_001")  # Automatically saved
        >>>
        >>> # Recreate the singleton with manual sync
        >>> manager = CacheManager.get_instance(sync_strategy=SyncStrategy.MANUAL, reset=True)
        >>> manager.add_to_cache("subjects", "mouse_002")
        >>> manager.save()  # Explicitly save
        >>>
        >>> # Recreate the singleton with a custom path
        >>> manager = CacheManager.get_instance(cache_path="custom/cache.json", reset=True)
    """

    _instance: ClassVar["CacheManager | None"] = None
    _lock: ClassVar[threading.RLock] = threading.RLock()

    def __init__(
        self,
        cache_path: Path | str | None = None,
        sync_strategy: SyncStrategy = SyncStrategy.AUTO,
    ) -> None:
        """
        Initialize a CacheManager instance.

        Nothing is read from disk here; ``get_instance`` is responsible for
        loading an existing cache file.

        Args:
            cache_path: Path to cache file. If None, uses default location.
            sync_strategy: Strategy for syncing to disk (MANUAL or AUTO)
        """
        self.caches: dict[str, CachedSettings[Any]] = {}
        self.sync_strategy: SyncStrategy = sync_strategy
        self.cache_path: Path = Path(cache_path) if cache_path else Path(TMP_DIR) / ".cache_manager.json"
        self._instance_lock: threading.RLock = threading.RLock()

    @classmethod
    def get_instance(
        cls,
        cache_path: Path | str | None = None,
        sync_strategy: SyncStrategy = SyncStrategy.AUTO,
        reset: bool = False,
    ) -> "CacheManager":
        """
        Get the singleton instance of CacheManager (thread-safe).

        ``cache_path`` and ``sync_strategy`` are only used when a new instance
        is created, i.e. on the first call or when ``reset`` is True. On any
        other call the existing instance is returned unchanged.

        Args:
            cache_path: Path to cache file. If None, uses default location.
            sync_strategy: Strategy for syncing to disk (MANUAL or AUTO)
            reset: If True, reset the singleton and create a new instance

        Returns:
            The singleton CacheManager instance
        """
        with cls._lock:
            if reset or cls._instance is None:
                # __init__ resolves the default location, so the resolved path is read back from it.
                instance = cls(cache_path=cache_path, sync_strategy=sync_strategy)
                resolved_path = instance.cache_path

                if resolved_path.exists():
                    # Best effort: an unreadable or invalid file must not prevent start-up.
                    try:
                        with resolved_path.open("r", encoding="utf-8") as f:
                            cache_data = CacheData.model_validate_json(f.read())
                        instance.caches = cache_data.caches
                    except Exception as e:
                        logger.warning(f"Cache file {resolved_path} is corrupted: {e}. Creating new instance.")

                cls._instance = instance
            elif cache_path is not None and Path(cache_path) != cls._instance.cache_path:
                # Make the otherwise silent "arguments ignored" case visible to the caller.
                logger.warning(
                    f"CacheManager singleton already exists with cache path {cls._instance.cache_path}; "
                    f"ignoring requested path {cache_path}. Pass reset=True to create a new instance."
                )

            return cls._instance

    def _auto_save(self) -> None:
        """Save to disk if auto-sync is enabled (caller must hold lock)."""
        if self.sync_strategy == SyncStrategy.AUTO:
            self._save_unlocked()

    def _save_unlocked(self) -> None:
        """Internal save method without locking (caller must hold lock)."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        cache_data = CacheData(caches=self.caches)
        with self.cache_path.open("w", encoding="utf-8") as f:
            f.write(cache_data.model_dump_json(indent=2))

    def register_cache(self, name: str, max_history: int = 5) -> None:
        """
        Register a new cache with a specific history limit (thread-safe).

        If a cache with this name already exists it is left untouched,
        including its current ``max_history``.

        Args:
            name: Unique name for the cache
            max_history: Maximum number of items to retain
        """
        with self._instance_lock:
            if name not in self.caches:
                self.caches[name] = CachedSettings(max_history=max_history)
                # Persist only when a cache was actually created.
                self._auto_save()

    def add_to_cache(self, name: str, value: Any) -> None:
        """
        Add a value to a named cache (thread-safe).

        A cache that has not been registered yet is created on the fly with a
        history limit of 5. The value is moved to the front if it is already
        present, and the oldest entries are dropped once the limit is exceeded.

        Args:
            name: Name of the cache
            value: Value to add
        """
        with self._instance_lock:
            if name not in self.caches:
                self.caches[name] = CachedSettings(max_history=5)

            # CachedSettings.add handles de-duplication, ordering and trimming.
            self.caches[name].add(value)
            self._auto_save()

    def get_cache(self, name: str) -> list[Any]:
        """
        Get all values from a named cache (thread-safe).

        Args:
            name: Name of the cache

        Returns:
            A copy of the cached values, newest first

        Raises:
            KeyError: If cache name is not registered
        """
        with self._instance_lock:
            if name not in self.caches:
                raise KeyError(f"Cache '{name}' not registered.")
            return self.caches[name].values.copy()

    def try_get_cache(self, name: str) -> list[Any] | None:
        """
        Attempt to get all values from a named cache.

        Args:
            name: Name of the cache

        Returns:
            A copy of the cached values, newest first, or None if the cache is not registered
        """
        try:
            return self.get_cache(name)
        except KeyError:
            return None

    def get_latest(self, name: str) -> Any | None:
        """
        Get the most recent value from a named cache (thread-safe).

        Args:
            name: Name of the cache

        Returns:
            The latest value, or None if cache is empty

        Raises:
            KeyError: If cache name is not registered
        """
        with self._instance_lock:
            values = self.get_cache(name)
            return values[0] if values else None

    def clear_cache(self, name: str) -> None:
        """
        Clear all values from a named cache (thread-safe).

        The cache itself stays registered and keeps its history limit.

        Args:
            name: Name of the cache

        Raises:
            KeyError: If cache name is not registered
        """
        with self._instance_lock:
            if name not in self.caches:
                raise KeyError(f"Cache '{name}' not registered.")
            self.caches[name].clear()
            self._auto_save()

    def clear_all_caches(self) -> None:
        """Remove every cache, including its registration (thread-safe)."""
        with self._instance_lock:
            self.caches = {}
            self._auto_save()

    def save(self) -> None:
        """
        Save all caches to disk using Pydantic serialization (thread-safe).

        With the AUTO strategy the modifying methods already write to disk
        after each change; with MANUAL this method has to be called explicitly.
        """
        with self._instance_lock:
            self._save_unlocked()
297+
298+
299+
class _ListCacheCli(BaseSettings):
    """CLI subcommand that logs every cache together with its stored values."""

    def cli_cmd(self):
        """Log the contents of each cache, or a notice when there are none."""
        caches = CacheManager.get_instance().caches
        if not caches:
            logger.info("No caches available.")
            return
        for name, cache in caches.items():
            logger.info(f"Cache '{name}': {cache.values}")
309+
310+
311+
class _ResetCacheCli(BaseSettings):
    """CLI subcommand that empties all caches of the cache manager."""

    def cli_cmd(self):
        """Clear all caches on the singleton manager and log a confirmation."""
        manager = CacheManager.get_instance()
        manager.clear_all_caches()
        logger.info("All caches have been cleared.")
318+
319+
320+
class _CacheManagerCli(BaseSettings):
    """CLI application wrapper for the cache manager."""

    # Subcommand that clears all caches.
    reset: CliSubCommand[_ResetCacheCli]
    # Subcommand that lists all caches and their contents. The field name shadows the
    # ``list`` builtin inside this class body only; it is kept as-is because renaming
    # the field would rename the CLI subcommand.
    list: CliSubCommand[_ListCacheCli]

    def cli_cmd(self):
        """Run the selected cache subcommand (``reset`` or ``list``)."""
        CliApp.run_subcommand(self)

src/clabe/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from pydantic_settings import BaseSettings, CliApp, CliSubCommand
22

3+
from .cache_manager import _CacheManagerCli
34
from .xml_rpc._server import _XmlRpcServerStartCli
45

56

67
class CliAppSettings(BaseSettings, cli_prog_name="clabe", cli_kebab_case=True):
78
"""CLI application settings."""
89

910
xml_rpc_server: CliSubCommand[_XmlRpcServerStartCli]
11+
cache: CliSubCommand[_CacheManagerCli]
1012

1113
def cli_cmd(self):
1214
"""Run the selected subcommand."""

src/clabe/pickers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from ._by_animal_modifier import ByAnimalModifier
12
from .default_behavior import DefaultBehaviorPicker, DefaultBehaviorPickerSettings
23

34
# Names exported by ``from <package> import *``.
__all__ = ["DefaultBehaviorPicker", "DefaultBehaviorPickerSettings", "ByAnimalModifier"]

0 commit comments

Comments
 (0)