diff --git a/src/clabe/cache_manager.py b/src/clabe/cache_manager.py new file mode 100644 index 00000000..5c1c9f45 --- /dev/null +++ b/src/clabe/cache_manager.py @@ -0,0 +1,328 @@ +import logging +import threading +from enum import Enum +from pathlib import Path +from typing import Any, ClassVar, Generic, TypeVar + +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, CliApp, CliSubCommand + +from .constants import TMP_DIR + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class SyncStrategy(str, Enum): + """Strategy for syncing cache to disk.""" + + MANUAL = "manual" # Only save when explicitly called + AUTO = "auto" # Save after every modification + + +class CachedSettings(BaseModel, Generic[T]): + """ + Manages a cache of values with a configurable history limit. + + When a new value is added and the cache is full, the oldest value is removed. + + Attributes: + values: List of cached values, newest first + max_history: Maximum number of items to retain in cache + + Example: + >>> cache = CachedSettings[str](max_history=3) + >>> cache.add("first") + >>> cache.add("second") + >>> cache.get_all() + ['second', 'first'] + """ + + values: list[T] = Field(default_factory=list) + max_history: int = Field(default=5, gt=0) + + def add(self, value: T) -> None: + """ + Add a new value to the cache. + + If the value already exists, it's moved to the front. + If the cache is full, the oldest value is removed. + + Args: + value: The value to add to the cache + """ + if value in self.values: + self.values.remove(value) + self.values.insert(0, value) + + if len(self.values) > self.max_history: + self.values = self.values[: self.max_history] + + def get_all(self) -> list[T]: + """ + Get all cached values. + + Returns: + List of all cached values, newest first + """ + return self.values.copy() + + def get_latest(self) -> T | None: + """ + Get the most recently added value. 
+ + Returns: + The latest value, or None if cache is empty + """ + return self.values[0] if self.values else None + + def clear(self) -> None: + """Clear all values from the cache.""" + self.values = [] + + +class CacheData(BaseModel): + """Pydantic model for cache serialization.""" + + caches: dict[str, CachedSettings[Any]] = Field(default_factory=dict) + + +class CacheManager: + """ + Thread-safe singleton cache manager with multiple named caches. + + Uses Pydantic for proper serialization/deserialization with automatic + disk synchronization support. All operations are thread-safe. + + Example: + >>> # Get singleton instance with manual sync (default) + >>> manager = CacheManager.get_instance() + >>> manager.add_to_cache("subjects", "mouse_001") + >>> manager.save() # Explicitly save + >>> + >>> # Get instance with auto sync - saves after every change + >>> manager = CacheManager.get_instance(sync_strategy=SyncStrategy.AUTO) + >>> manager.add_to_cache("subjects", "mouse_002") # Automatically saved + >>> + >>> # Custom path + >>> manager = CacheManager.get_instance(cache_path="custom/cache.json") + """ + + _instance: ClassVar["CacheManager | None"] = None + _lock: ClassVar[threading.RLock] = threading.RLock() + + def __init__( + self, + cache_path: Path | str | None = None, + sync_strategy: SyncStrategy = SyncStrategy.AUTO, + ) -> None: + """ + Initialize a CacheManager instance. + + Args: + cache_path: Path to cache file. If None, uses default location. 
+ sync_strategy: Strategy for syncing to disk (MANUAL or AUTO) + """ + self.caches: dict[str, CachedSettings[Any]] = {} + self.sync_strategy: SyncStrategy = sync_strategy + self.cache_path: Path = Path(cache_path) if cache_path else Path(TMP_DIR) / ".cache_manager.json" + self._instance_lock: threading.RLock = threading.RLock() + + @classmethod + def get_instance( + cls, + cache_path: Path | str | None = None, + sync_strategy: SyncStrategy = SyncStrategy.AUTO, + reset: bool = False, + ) -> "CacheManager": + """ + Get the singleton instance of CacheManager (thread-safe). + + Args: + cache_path: Path to cache file. If None, uses default location. + sync_strategy: Strategy for syncing to disk (MANUAL or AUTO) + reset: If True, reset the singleton and create a new instance + + Returns: + The singleton CacheManager instance + """ + with cls._lock: + if reset or cls._instance is None: + if cache_path is None: + cache_path = Path(TMP_DIR) / ".cache_manager.json" + else: + cache_path = Path(cache_path) + + instance = cls(cache_path=cache_path, sync_strategy=sync_strategy) + + if cache_path.exists(): + try: + with cache_path.open("r", encoding="utf-8") as f: + cache_data = CacheData.model_validate_json(f.read()) + instance.caches = cache_data.caches + except Exception as e: + logger.warning(f"Cache file {cache_path} is corrupted: {e}. 
Creating new instance.") + + cls._instance = instance + + return cls._instance + + def _auto_save(self) -> None: + """Save to disk if auto-sync is enabled (caller must hold lock).""" + if self.sync_strategy == SyncStrategy.AUTO: + self._save_unlocked() + + def _save_unlocked(self) -> None: + """Internal save method without locking (caller must hold lock).""" + self.cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_data = CacheData(caches=self.caches) + with self.cache_path.open("w", encoding="utf-8") as f: + f.write(cache_data.model_dump_json(indent=2)) + + def register_cache(self, name: str, max_history: int = 5) -> None: + """ + Register a new cache with a specific history limit (thread-safe). + + Args: + name: Unique name for the cache + max_history: Maximum number of items to retain + """ + with self._instance_lock: + if name not in self.caches: + self.caches[name] = CachedSettings(max_history=max_history) + self._auto_save() + + def add_to_cache(self, name: str, value: Any) -> None: + """ + Add a value to a named cache (thread-safe). + + Args: + name: Name of the cache + value: Value to add + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + if name not in self.caches: + self.caches[name] = CachedSettings(max_history=5) + + cache = self.caches[name] + + # we remove it first to avoid duplicates + if value in cache.values: + cache.values.remove(value) + # but add it to the front + cache.values.insert(0, value) + + if len(cache.values) > cache.max_history: + cache.values = cache.values[: cache.max_history] + + self._auto_save() + + def get_cache(self, name: str) -> list[Any]: + """ + Get all values from a named cache (thread-safe). 
+ + Args: + name: Name of the cache + + Returns: + List of cached values, newest first + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + if name not in self.caches: + raise KeyError(f"Cache '{name}' not registered.") + return self.caches[name].values.copy() + + def try_get_cache(self, name: str) -> Any | None: + """Attempt to get all values from a named cache, returning None if not found.""" + try: + return self.get_cache(name) + except KeyError: + return None + + def get_latest(self, name: str) -> Any | None: + """ + Get the most recent value from a named cache (thread-safe). + + Args: + name: Name of the cache + + Returns: + The latest value, or None if cache is empty + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + values = self.get_cache(name) + return values[0] if values else None + + def clear_cache(self, name: str) -> None: + """ + Clear all values from a named cache (thread-safe). + + Args: + name: Name of the cache + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + if name not in self.caches: + raise KeyError(f"Cache '{name}' not registered.") + self.caches[name].values = [] + self._auto_save() + + def clear_all_caches(self) -> None: + """Clear all caches (thread-safe).""" + with self._instance_lock: + self.caches = {} + self._auto_save() + + def save(self) -> None: + """ + Save all caches to disk using Pydantic serialization (thread-safe). + + This method is called automatically if sync_strategy is AUTO, + or can be called manually for MANUAL strategy. 
+ """ + with self._instance_lock: + self._save_unlocked() + + +class _ListCacheCli(BaseSettings): + """CLI command to list all caches and their contents.""" + + def cli_cmd(self): + """Run the list cache CLI command.""" + manager = CacheManager.get_instance() + if not manager.caches: + logger.info("No caches available.") + for name, cache in manager.caches.items(): + logger.info(f"Cache '{name}': {cache.values}") + + +class _ResetCacheCli(BaseSettings): + """CLI command to reset all caches.""" + + def cli_cmd(self): + """Run the reset cache CLI command.""" + CacheManager.get_instance().clear_all_caches() + logger.info("All caches have been cleared.") + + +class _CacheManagerCli(BaseSettings): + """CLI application wrapper for the RPC server.""" + + reset: CliSubCommand[_ResetCacheCli] + list: CliSubCommand[_ListCacheCli] + + def cli_cmd(self): + """Run the cache manager CLI.""" + CliApp.run_subcommand(self) diff --git a/src/clabe/cli.py b/src/clabe/cli.py index 73484073..9b417978 100644 --- a/src/clabe/cli.py +++ b/src/clabe/cli.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings, CliApp, CliSubCommand +from .cache_manager import _CacheManagerCli from .xml_rpc._server import _XmlRpcServerStartCli @@ -7,6 +8,7 @@ class CliAppSettings(BaseSettings, cli_prog_name="clabe", cli_kebab_case=True): """CLI application settings.""" xml_rpc_server: CliSubCommand[_XmlRpcServerStartCli] + cache: CliSubCommand[_CacheManagerCli] def cli_cmd(self): """Run the selected subcommand.""" diff --git a/src/clabe/pickers/__init__.py b/src/clabe/pickers/__init__.py index 6d198352..b475897c 100644 --- a/src/clabe/pickers/__init__.py +++ b/src/clabe/pickers/__init__.py @@ -1,6 +1,8 @@ +from ._by_animal_modifier import ByAnimalModifier from .default_behavior import DefaultBehaviorPicker, DefaultBehaviorPickerSettings __all__ = [ "DefaultBehaviorPicker", "DefaultBehaviorPickerSettings", + "ByAnimalModifier", ] diff --git a/src/clabe/pickers/_by_animal_modifier.py 
b/src/clabe/pickers/_by_animal_modifier.py new file mode 100644 index 00000000..de7ca786 --- /dev/null +++ b/src/clabe/pickers/_by_animal_modifier.py @@ -0,0 +1,238 @@ +import abc +import functools +import logging +from pathlib import Path +from typing import Any, Generic, Protocol, TypeVar, runtime_checkable + +from pydantic import TypeAdapter + +from .._typing import TRig + +logger = logging.getLogger(__name__) +T = TypeVar("T") +TInjectable = TypeVar("TInjectable") + + +@runtime_checkable +class _IByAnimalModifier(Protocol, Generic[TRig]): + """ + Protocol defining the interface for by-animal modifiers. + + This protocol defines the contract that any by-animal modifier must implement + to inject and dump subject-specific configurations. + """ + + def inject(self, rig: TRig) -> TRig: + """Injects subject-specific configuration into the rig model.""" + ... + + def dump(self) -> None: + """Dumps the configuration to a JSON file.""" + ... + + +class ByAnimalModifier(abc.ABC, _IByAnimalModifier[TRig]): + """ + Abstract base class for modifying rig configurations with subject-specific data. + + This class provides a framework for loading and saving subject-specific + configuration data to/from JSON files. It uses reflection to access nested + attributes in the rig model and automatically handles serialization. 
+ + Attributes: + _subject_db_path: Path to the directory containing subject-specific files + _model_path: Dot-separated path to the attribute in the rig model (e.g., "nested.field") + _model_name: Base name for the JSON file (without extension) + _tp: TypeAdapter for the model type, set during inject() + + Example: + ```python + from pathlib import Path + from clabe.pickers.default_behavior import ByAnimalModifier + import pydantic + + class MyModel(pydantic.BaseModel): + nested: "NestedConfig" + + class NestedConfig(pydantic.BaseModel): + value: int + + class MyModifier(ByAnimalModifier[MyModel]): + def __init__(self, subject_db_path: Path, **kwargs): + super().__init__( + subject_db_path=subject_db_path, + model_path="nested", + model_name="nested_config", + **kwargs + ) + + def _process_before_dump(self): + return NestedConfig(value=42) + + modifier = MyModifier(Path("./subject_db")) + model = MyModel(nested=NestedConfig(value=1)) + modified = modifier.inject(model) + modifier.dump() + ``` + """ + + def __init__(self, subject_db_path: Path, model_path: str, model_name: str, **kwargs) -> None: + """ + Initializes the ByAnimalModifier. + + Args: + subject_db_path: Path to the directory containing subject-specific JSON files + model_path: Dot-separated path to the target attribute in the rig model + model_name: Base name for the JSON file (without .json extension) + **kwargs: Additional keyword arguments (reserved for future use) + """ + self._subject_db_path = Path(subject_db_path) + self._model_path = model_path + self._model_name = model_name + self._tp: TypeAdapter[Any] | None = None + + def _process_before_inject(self, deserialized: T) -> T: + """ + Hook method called after deserialization but before injection. + + Override this method to modify the deserialized data before it's + injected into the rig model. 
+ + Args: + deserialized: The deserialized object from the JSON file + + Returns: + The processed object to be injected + """ + return deserialized + + @abc.abstractmethod + def _process_before_dump(self) -> Any: + """ + Abstract method to generate the data to be dumped to JSON. + + Subclasses must implement this method to return the object that + should be serialized and saved to the JSON file. + + Returns: + The object to be serialized and dumped to JSON + """ + ... + + def inject(self, rig: TRig) -> TRig: + """ + Injects subject-specific configuration into the rig model. + + Loads configuration from a JSON file and injects it into the specified + path in the rig model. If the file doesn't exist, the rig is returned + unmodified with a warning logged. + + Args: + rig: The rig model to modify + + Returns: + The modified rig model + """ + target_file = self._subject_db_path / f"{self._model_name}.json" + if not target_file.exists(): + logger.warning(f"File not found: {target_file}. Using default.") + else: + target = rgetattr(rig, self._model_path) + self._tp = TypeAdapter(type(target)) + deserialized = self._tp.validate_json(target_file.read_text(encoding="utf-8")) + logger.info(f"Loading {self._model_name} from: {target_file}. Deserialized: {deserialized}") + self._process_before_inject(deserialized) + rsetattr(rig, self._model_path, deserialized) + return rig + + def dump(self) -> None: + """ + Dumps the configuration to a JSON file. + + Calls _process_before_dump() to get the data, then serializes it + to JSON and writes it to the target file. Creates parent directories + if they don't exist. + + Raises: + Exception: If _process_before_dump() fails or serialization fails + """ + target_folder = self._subject_db_path + target_file = target_folder / f"{self._model_name}.json" + + if (tp := self._tp) is None: + logger.warning("TypeAdapter is not set. 
Using TypeAdapter(Any) as fallback.") + tp = TypeAdapter(Any) + + try: + to_inject = self._process_before_dump() + logger.info(f"Saving {self._model_name} to: {target_file}. Serialized: {to_inject}") + target_folder.mkdir(parents=True, exist_ok=True) + target_file.write_text(tp.dump_json(to_inject, indent=2).decode("utf-8"), encoding="utf-8") + except Exception as e: + logger.error(f"Failed to process before dumping modifier: {e}") + raise + + +def rsetattr(obj, attr, val): + """ + Sets an attribute value using a dot-separated path. + + Args: + obj: The object to modify + attr: Dot-separated attribute path (e.g., "nested.field.value") + val: The value to set + + Returns: + The result of setattr on the final attribute + + Example: + ```python + class Inner: + value = 1 + + class Outer: + inner = Inner() + + obj = Outer() + rsetattr(obj, "inner.value", 42) + assert obj.inner.value == 42 + ``` + """ + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + """ + Gets an attribute value using a dot-separated path. 
+ + Args: + obj: The object to query + attr: Dot-separated attribute path (e.g., "nested.field.value") + *args: Optional default value if attribute doesn't exist + + Returns: + The attribute value at the specified path + + Example: + ```python + class Inner: + value = 42 + + class Outer: + inner = Inner() + + obj = Outer() + result = rgetattr(obj, "inner.value") + assert result == 42 + + default = rgetattr(obj, "nonexistent.path", "default") + assert default == "default" + ``` + """ + + def _getattr(obj, attr): + """Helper function to get attribute with optional default.""" + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index 6f566e9e..6dd718e5 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -2,7 +2,7 @@ import logging import os from pathlib import Path -from typing import Callable, ClassVar, List, Optional, Type, Union +from typing import Callable, ClassVar, List, Optional, Type, TypeVar, Union import pydantic from aind_behavior_curriculum import TrainerState @@ -11,12 +11,15 @@ from .. 
import ui from .._typing import TRig, TSession, TTaskLogic +from ..cache_manager import CacheManager from ..constants import ByAnimalFiles from ..launcher import Launcher from ..services import ServiceSettings from ..utils.aind_auth import validate_aind_username logger = logging.getLogger(__name__) +T = TypeVar("T") +TInjectable = TypeVar("TInjectable") class DefaultBehaviorPickerSettings(ServiceSettings): @@ -85,6 +88,7 @@ def __init__( self._experimenter_validator = experimenter_validator self._trainer_state: Optional[TrainerState] = None self._session: Optional[AindBehaviorSessionModel] = None + self._cache_manager = CacheManager.get_instance() @property def ui_helper(self) -> ui.UiHelper: @@ -116,6 +120,16 @@ def trainer_state(self) -> TrainerState: raise ValueError("Trainer state not set.") return self._trainer_state + @property + def session_directory(self) -> Path: + """Returns the directory path for the current session.""" + return self._launcher.session_directory + + @property + def session(self) -> AindBehaviorSessionModel: + """Returns the current session model.""" + return self._launcher.session + @property def config_library_dir(self) -> Path: """ @@ -184,27 +198,55 @@ def pick_rig(self, model: Type[TRig]) -> TRig: Raises: ValueError: If no rig configuration files are found or an invalid choice is made """ - available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json")) - if len(available_rigs) == 0: - logger.error("No rig config files found.") - raise ValueError("No rig config files found.") - elif len(available_rigs) == 1: - logger.info("Found a single rig config file. 
Using %s.", {available_rigs[0]}) - rig = model_from_json_file(available_rigs[0], model) + rig: TRig | None = None + rig_path: str | None = None + + # Check cache for previously used rigs + cache = self._cache_manager.try_get_cache("rigs") + if cache: + rig_path = self.ui_helper.prompt_pick_from_list( + cache, + prompt="Choose a rig:", + allow_0_as_none=True, + zero_label="Select from library", + ) + if rig_path is not None: + rig = self._load_rig_from_path(Path(rig_path), model) + + # Prompt user to select a rig if not already selected + while rig is None: + available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json")) + # We raise if no rigs are found to prevent an infinite loop + if len(available_rigs) == 0: + logger.error("No rig config files found.") + raise ValueError("No rig config files found.") + # Use the single available rig config file + elif len(available_rigs) == 1: + logger.info("Found a single rig config file. Using %s.", {available_rigs[0]}) + rig_path = available_rigs[0] + rig = model_from_json_file(rig_path, model) + else: + rig_path = self.ui_helper.prompt_pick_from_list(available_rigs, prompt="Choose a rig:") + if rig_path is not None: + rig = self._load_rig_from_path(Path(rig_path), model) + assert rig_path is not None + # Add the selected rig path to the cache + cache = self._cache_manager.add_to_cache("rigs", rig_path) + return rig + + @staticmethod + def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None: + """Load a rig configuration from a given path.""" + try: + if not isinstance(path, str): + raise ValueError("Invalid choice.") + rig = model_from_json_file(path, model) + logger.info("Using %s.", path) return rig - else: - while True: - try: - path = self.ui_helper.prompt_pick_from_list(available_rigs, prompt="Choose a rig:") - if not isinstance(path, str): - raise ValueError("Invalid choice.") - rig = model_from_json_file(path, model) - logger.info("Using %s.", path) - return rig - except pydantic.ValidationError as e: 
- logger.error("Failed to validate pydantic model. Try again. %s", e) - except ValueError as e: - logger.info("Invalid choice. Try again. %s", e) + except pydantic.ValidationError as e: + logger.error("Failed to validate pydantic model. Try again. %s", e) + except ValueError as e: + logger.info("Invalid choice. Try again. %s", e) def pick_session(self, model: Type[TSession] = AindBehaviorSessionModel) -> TSession: """ @@ -366,22 +408,22 @@ def choose_subject(self, directory: str | os.PathLike) -> str: subject = picker.choose_subject("Subjects") ``` """ - subject = None + subjects = self._cache_manager.try_get_cache("subjects") + if subjects: + subject = self.ui_helper.prompt_pick_from_list( + subjects, + prompt="Choose a subject:", + allow_0_as_none=True, + zero_label="Enter manually", + ) + else: + subject = None + while subject is None: subject = self.ui_helper.input("Enter subject name: ") if subject == "": - subject = self.ui_helper.prompt_pick_from_list( - [ - os.path.basename(folder) - for folder in os.listdir(directory) - if os.path.isdir(os.path.join(directory, folder)) - ], - prompt="Choose a subject:", - allow_0_as_none=True, - ) - else: - return subject - + subject = None + self._cache_manager.add_to_cache("subjects", subject) return subject def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: @@ -404,10 +446,22 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: print("Experimenters:", names) ``` """ + experimenters_cache = self._cache_manager.try_get_cache("experimenters") experimenter: Optional[List[str]] = None + _picked: str | None = None while experimenter is None: - _user_input = self.ui_helper.prompt_text("Experimenter name: ") - experimenter = _user_input.replace(",", " ").split() + if experimenters_cache: + _picked = self.ui_helper.prompt_pick_from_list( + experimenters_cache, + prompt="Choose an experimenter:", + allow_0_as_none=True, + zero_label="Enter manually", + ) + if _picked is None: + 
_input = self.ui_helper.prompt_text("Experimenter name: ") + else: + _input = _picked + experimenter = _input.replace(",", " ").split() if strict & (len(experimenter) == 0): logger.info("Experimenter name is not valid. Try again.") experimenter = None @@ -418,6 +472,7 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: logger.warning("Experimenter name: %s, is not valid. Try again", name) experimenter = None break + self._cache_manager.add_to_cache("experimenters", ",".join(experimenter)) return experimenter def dump_model( diff --git a/src/clabe/ui/ui_helper.py b/src/clabe/ui/ui_helper.py index 252b0c2d..f4a157bf 100644 --- a/src/clabe/ui/ui_helper.py +++ b/src/clabe/ui/ui_helper.py @@ -191,7 +191,7 @@ def input(self, prompt: str) -> str: return self._input(prompt) def prompt_pick_from_list( - self, value: List[str], prompt: str, allow_0_as_none: bool = True, **kwargs + self, value: List[str], prompt: str, *, allow_0_as_none: bool = True, zero_label: str = "None", **kwargs ) -> Optional[str]: """ Prompts the user to pick an item from a list. 
@@ -223,7 +223,7 @@ def prompt_pick_from_list( try: self.print(prompt) if allow_0_as_none: - self.print("0: None") + self.print(f"0: {zero_label}") for i, item in enumerate(value): self.print(f"{i + 1}: {item}") choice = int(input("Choice: ")) diff --git a/tests/pickers/__init__.py b/tests/pickers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pickers/test_by_animal_modifier.py b/tests/pickers/test_by_animal_modifier.py new file mode 100644 index 00000000..17bf0d46 --- /dev/null +++ b/tests/pickers/test_by_animal_modifier.py @@ -0,0 +1,151 @@ +from pathlib import Path +from typing import Optional + +import pydantic +import pytest + +from clabe.pickers import ByAnimalModifier + + +class NestedModel(pydantic.BaseModel): + foo: str + bar: int + nested2: Optional["NestedModel"] = None + + +class Model(pydantic.BaseModel): + nested: NestedModel + something: float + + +class CustomModifier(ByAnimalModifier[Model]): + def __init__(self, subject_db_path: Path, model_path="nested", model_name="nested_model", **kwargs): + super().__init__(subject_db_path=subject_db_path, model_path=model_path, model_name=model_name, **kwargs) + + def _process_before_dump(self): + return NestedModel(foo="Modified", bar=10, nested2=NestedModel(foo="Modified Nested", bar=20, nested2=None)) + + +class TestByAnimalModifier: + @pytest.fixture + def temp_subject_db(self, tmp_path: Path): + subject_db = tmp_path / "subject_db" + subject_db.mkdir(parents=True, exist_ok=True) + return subject_db + + @pytest.fixture + def sample_model(self): + return Model( + nested=NestedModel(foo="Original", bar=5, nested2=NestedModel(foo="Nested", bar=5, nested2=None)), + something=3.14, + ) + + def test_inject_with_existing_file(self, temp_subject_db: Path, sample_model: Model): + nested_data = NestedModel(foo="Loaded", bar=99, nested2=None) + target_file = temp_subject_db / "nested_model.json" + target_file.write_text(nested_data.model_dump_json(indent=2), encoding="utf-8") + + 
modifier = CustomModifier(subject_db_path=temp_subject_db) + modified = modifier.inject(sample_model) + + assert modified.nested.foo == "Loaded" + assert modified.nested.bar == 99 + assert modified.nested.nested2 is None + + def test_inject_without_existing_file(self, temp_subject_db: Path, sample_model: Model): + modifier = CustomModifier(subject_db_path=temp_subject_db) + modified = modifier.inject(sample_model) + + assert modified.nested.foo == "Original" + assert modified.nested.bar == 5 + assert modified.something == 3.14 + + def test_dump_creates_file(self, temp_subject_db: Path, sample_model: Model): + modifier = CustomModifier(subject_db_path=temp_subject_db) + modifier.inject(sample_model) + modifier.dump() + + target_file = temp_subject_db / "nested_model.json" + assert target_file.exists() + + loaded = NestedModel.model_validate_json(target_file.read_text(encoding="utf-8")) + assert loaded.foo == "Modified" + assert loaded.bar == 10 + assert loaded.nested2.foo == "Modified Nested" + assert loaded.nested2.bar == 20 + + def test_dump_without_inject_uses_fallback(self, temp_subject_db: Path): + modifier = CustomModifier(subject_db_path=temp_subject_db) + modifier.dump() + + target_file = temp_subject_db / "nested_model.json" + assert target_file.exists() + + def test_inject_and_dump_roundtrip(self, temp_subject_db: Path, sample_model: Model): + modifier1 = CustomModifier(subject_db_path=temp_subject_db) + modifier1.inject(sample_model) + modifier1.dump() + + modifier2 = CustomModifier(subject_db_path=temp_subject_db) + modified = modifier2.inject(sample_model) + + assert modified.nested.foo == "Modified" + assert modified.nested.bar == 10 + assert modified.nested.nested2.foo == "Modified Nested" + + def test_nested_path_access(self, temp_subject_db: Path): + class DeepModel(pydantic.BaseModel): + level1: "Level1Model" + + class Level1Model(pydantic.BaseModel): + level2: "Level2Model" + + class Level2Model(pydantic.BaseModel): + value: int + + class 
DeepModifier(ByAnimalModifier[DeepModel]): + def __init__(self, subject_db_path: Path, **kwargs): + super().__init__( + subject_db_path=subject_db_path, model_path="level1.level2", model_name="deep_value", **kwargs + ) + + def _process_before_dump(self): + return Level2Model(value=999) + + model = DeepModel(level1=Level1Model(level2=Level2Model(value=1))) + + level2_data = Level2Model(value=42) + target_file = temp_subject_db / "deep_value.json" + target_file.write_text(level2_data.model_dump_json(indent=2), encoding="utf-8") + + modifier = DeepModifier(subject_db_path=temp_subject_db) + modified = modifier.inject(model) + + assert modified.level1.level2.value == 42 + + def test_process_before_inject_hook(self, temp_subject_db: Path, sample_model: Model): + class ModifierWithPreProcess(CustomModifier): + def _process_before_inject(self, deserialized): + deserialized.foo = "PreProcessed" + return deserialized + + nested_data = NestedModel(foo="Loaded", bar=99, nested2=None) + target_file = temp_subject_db / "nested_model.json" + target_file.write_text(nested_data.model_dump_json(indent=2), encoding="utf-8") + + modifier = ModifierWithPreProcess(subject_db_path=temp_subject_db) + modified = modifier.inject(sample_model) + + assert modified.nested.foo == "PreProcessed" + assert modified.nested.bar == 99 + + def test_dump_creates_parent_directories(self, tmp_path: Path, sample_model: Model): + nested_subject_db = tmp_path / "parent" / "child" / "subject_db" + + modifier = CustomModifier(subject_db_path=nested_subject_db) + modifier.inject(sample_model) + modifier.dump() + + target_file = nested_subject_db / "nested_model.json" + assert target_file.exists() + assert target_file.parent.exists() diff --git a/tests/test_cached_settings.py b/tests/test_cached_settings.py new file mode 100644 index 00000000..fd0fe4d2 --- /dev/null +++ b/tests/test_cached_settings.py @@ -0,0 +1,244 @@ +"""Tests for cached settings manager with auto-sync capabilities.""" + +import tempfile 
from pathlib import Path

from clabe.cache_manager import CachedSettings, CacheManager, SyncStrategy


class TestCachedSettings:
    """Tests for the generic CachedSettings class."""

    def test_add_single_value(self):
        """Test adding a single value to cache."""
        cache = CachedSettings[str](max_history=3)
        cache.add("first")
        assert cache.get_all() == ["first"]
        assert cache.get_latest() == "first"

    def test_add_multiple_values(self):
        """Test adding multiple values maintains order (newest first)."""
        cache = CachedSettings[str](max_history=5)
        cache.add("first")
        cache.add("second")
        cache.add("third")
        assert cache.get_all() == ["third", "second", "first"]
        assert cache.get_latest() == "third"

    def test_max_history_limit(self):
        """Test that oldest values are removed when limit is exceeded."""
        cache = CachedSettings[str](max_history=3)
        cache.add("first")
        cache.add("second")
        cache.add("third")
        cache.add("fourth")  # Should remove "first"

        assert cache.get_all() == ["fourth", "third", "second"]
        assert len(cache.get_all()) == 3

    def test_duplicate_values_moved_to_front(self):
        """Test that adding a duplicate moves it to the front."""
        cache = CachedSettings[str](max_history=5)
        cache.add("first")
        cache.add("second")
        cache.add("third")
        cache.add("first")  # Should move "first" to front

        assert cache.get_all() == ["first", "third", "second"]

    def test_clear(self):
        """Test clearing the cache."""
        cache = CachedSettings[str](max_history=3)
        cache.add("first")
        cache.add("second")
        cache.clear()

        assert cache.get_all() == []
        assert cache.get_latest() is None

    def test_get_latest_empty(self):
        """Test getting latest from empty cache returns None."""
        cache = CachedSettings[str](max_history=3)
        assert cache.get_latest() is None


class TestCacheManagerManualSync:
    """Tests for CacheManager with manual sync strategy."""

    def test_register_and_add(self):
        """Test registering a cache and adding values."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.register_cache("subjects", max_history=3)
        manager.add_to_cache("subjects", "mouse_001")
        manager.add_to_cache("subjects", "mouse_002")

        assert manager.get_cache("subjects") == ["mouse_002", "mouse_001"]

    def test_auto_register_on_add(self):
        """Test that adding to unregistered cache auto-registers it."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.add_to_cache("subjects", "mouse_001")
        assert manager.get_cache("subjects") == ["mouse_001"]

    def test_multiple_caches(self):
        """Test managing multiple independent caches."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.register_cache("subjects", max_history=3)
        manager.register_cache("experimenters", max_history=2)

        manager.add_to_cache("subjects", "mouse_001")
        manager.add_to_cache("subjects", "mouse_002")
        manager.add_to_cache("experimenters", "alice")
        manager.add_to_cache("experimenters", "bob")

        assert manager.get_cache("subjects") == ["mouse_002", "mouse_001"]
        assert manager.get_cache("experimenters") == ["bob", "alice"]

    def test_get_latest(self):
        """Test getting the latest value from a cache."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.add_to_cache("test", "first")
        manager.add_to_cache("test", "second")

        assert manager.get_latest("test") == "second"

    def test_get_latest_empty(self):
        """Test getting latest from empty cache returns None."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.register_cache("test", max_history=3)

        assert manager.get_latest("test") is None

    def test_clear_cache(self):
        """Test clearing a specific cache."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.add_to_cache("test", "value1")
        manager.add_to_cache("test", "value2")
        manager.clear_cache("test")

        assert manager.get_cache("test") == []

    def test_clear_all_caches(self):
        """Test clearing all caches at once."""
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager.add_to_cache("subjects", "mouse_001")
        manager.add_to_cache("subjects", "mouse_002")
        manager.add_to_cache("experimenters", "alice")
        manager.add_to_cache("projects", "project_a")

        assert len(manager.caches) == 3
        assert manager.get_cache("subjects") == ["mouse_002", "mouse_001"]

        manager.clear_all_caches()

        assert manager.caches == {}
        assert len(manager.caches) == 0

    def test_singleton_behavior(self):
        """Test that get_instance returns the same instance."""
        manager1 = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        manager1.add_to_cache("test", "value1")

        # A subsequent call without reset must hand back the same object
        # with the previously added state intact.
        manager2 = CacheManager.get_instance()
        assert manager2.get_cache("test") == ["value1"]
        assert manager1 is manager2


class TestCacheManagerAutoSync:
    """Tests for CacheManager with auto-sync to disk."""

    def test_auto_sync_on_add(self):
        """Test that AUTO sync saves after adding values."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "cache.json"

            manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO)
            manager.add_to_cache("subjects", "mouse_001")

            # No explicit save() call — AUTO must have written the file already.
            assert path.exists()

            manager2 = CacheManager.get_instance(reset=True, cache_path=path)
            assert manager2.get_cache("subjects") == ["mouse_001"]

    def test_auto_sync_on_clear(self):
        """Test that AUTO sync saves after clearing."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "cache.json"

            manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO)
            manager.add_to_cache("test", "value1")
            manager.clear_cache("test")

            # Reload and verify clear persisted
            manager2 = CacheManager.get_instance(reset=True, cache_path=path)
            assert manager2.get_cache("test") == []

    def test_auto_sync_on_clear_all(self):
        """Test that AUTO sync saves after clearing all caches."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "cache.json"

            manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO)
            manager.add_to_cache("subjects", "mouse_001")
            manager.add_to_cache("projects", "project_a")

            assert path.exists()

            manager.clear_all_caches()

            manager2 = CacheManager.get_instance(reset=True, cache_path=path)
            assert manager2.caches == {}

    def test_manual_sync_does_not_auto_save(self):
        """Test that MANUAL sync does not save automatically."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "cache.json"

            manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.MANUAL)
            manager.add_to_cache("test", "value1")

            # File should not exist yet
            assert not path.exists()

            # Explicit save
            manager.save()
            assert path.exists()

    def test_load_nonexistent_file(self):
        """Test loading from non-existent file returns empty manager."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "nonexistent.json"
            manager = CacheManager.get_instance(reset=True, cache_path=path)

            assert manager.caches == {}

    def test_persistence_across_loads(self):
        """Test data persists correctly across save/load cycles."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "cache.json"

            manager1 = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO)
            manager1.add_to_cache("subjects", "mouse_001")
            manager1.add_to_cache("subjects", "mouse_002")
            manager1.add_to_cache("projects", "project_a")

            manager2 = CacheManager.get_instance(reset=True, cache_path=path)
            assert manager2.get_cache("subjects") == ["mouse_002", "mouse_001"]
            assert manager2.get_cache("projects") == ["project_a"]

            manager2.add_to_cache("subjects", "mouse_003")
            manager2.save()

            manager3 = CacheManager.get_instance(reset=True, cache_path=path)
            assert manager3.get_cache("subjects") == ["mouse_003", "mouse_002", "mouse_001"]

    def test_default_cache_path(self):
        """Test that default cache path is used when none specified.

        This test necessarily touches the manager's real default cache
        location, so cleanup must run even when an assertion fails —
        otherwise a failing run leaves a stale cache file behind that can
        corrupt later test runs (and the user's actual cache).
        """
        manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL)
        try:
            manager.add_to_cache("test", "value")
            manager.save()

            manager2 = CacheManager.get_instance(reset=True)
            assert manager2.get_cache("test") == ["value"]
        finally:
            # Always remove the file written to the default location.
            manager.cache_path.unlink(missing_ok=True)