From d1fca4947005a56c717f5efb9bacfea2dc86af15 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 13:28:42 -0800 Subject: [PATCH 1/8] Add initial implementation for generic model modifier --- src/clabe/pickers/default_behavior.py | 78 ++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index 6f566e9..35277ff 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -1,13 +1,16 @@ +import abc +import functools import glob import logging import os from pathlib import Path -from typing import Callable, ClassVar, List, Optional, Type, Union +from typing import Any, Callable, ClassVar, Generic, List, Optional, Protocol, Type, TypeVar, Union, runtime_checkable import pydantic from aind_behavior_curriculum import TrainerState from aind_behavior_services import AindBehaviorRigModel, AindBehaviorSessionModel, AindBehaviorTaskLogicModel from aind_behavior_services.utils import model_from_json_file +from pydantic import TypeAdapter from .. 
import ui from .._typing import TRig, TSession, TTaskLogic @@ -17,6 +20,8 @@ from ..utils.aind_auth import validate_aind_username logger = logging.getLogger(__name__) +T = TypeVar("T") +TInjectable = TypeVar("TInjectable") class DefaultBehaviorPickerSettings(ServiceSettings): @@ -116,6 +121,14 @@ def trainer_state(self) -> TrainerState: raise ValueError("Trainer state not set.") return self._trainer_state + @property + def session_directory(self) -> Path: + return self._launcher.session_directory + + @property + def session(self) -> AindBehaviorSessionModel: + return self._launcher.session + @property def config_library_dir(self) -> Path: """ @@ -458,3 +471,66 @@ def dump_model( f.write(model.model_dump_json(indent=2)) logger.info("Saved model to %s", path) return path + + +@runtime_checkable +class _IByAnimalModifier(Protocol, Generic[TRig]): + def inject(self, rig: TRig) -> TRig: ... + + def dump(self, rig: TRig) -> None: ... + + +class ByAnimalModifier(abc.ABC, _IByAnimalModifier[TRig]): + def __init__(self, picker: DefaultBehaviorPicker, model_path: str, model_name: str, **kwargs) -> None: + self._picker = picker + self._model_path = model_path + self._model_name = model_name + + def _process_before_inject(self, deserialized: T) -> T: + return deserialized + + @abc.abstractmethod + def _process_before_dump(self) -> Any: ... + + def inject(self, rig: TRig) -> TRig: + subject = self._picker.session.subject + target_folder = self._picker.subject_dir / subject + target_file = target_folder / self._model_name + if not target_file.exists(): + logger.warning(f"File not found: {target_file}. Using default.") + else: + target = rgetattr(rig, self._model_path) + deserialized = TypeAdapter(type(target)).validate_json(target_file.read_text(encoding="utf-8")) + logger.info(f"Loading {self._model_name} from: {target_file}. 
Deserialized: {deserialized}") + self._process_before_inject(deserialized) + rsetattr(rig, self._model_path, deserialized) + return rig + + def dump(self, rig: TRig) -> None: + subject = self._picker.session.subject + target_folder = self._picker.subject_dir / subject + target_file = target_folder / self._model_name + target = rgetattr(rig, self._model_path) + tp = TypeAdapter(type(target)) + + try: + to_inject = self._process_before_dump() + logger.info(f"Saving {self._model_name} to: {target_file}. Serialized: {to_inject}") + target_folder.mkdir(parents=True, exist_ok=True) + target_file.write_text(tp.dump_json(to_inject, indent=2).decode("utf-8"), encoding="utf-8") + except Exception as e: + logger.error(f"Failed to process before dumping modifier: {e}") + raise + + +# from https://stackoverflow.com/a/31174427 +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) From 9474c98e350f97745b8523a81b9f08e0797ad056 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 14:15:29 -0800 Subject: [PATCH 2/8] Implement generic `ByAnimalModifier` --- src/clabe/pickers/default_behavior.py | 196 +++++++++++++++++++++-- tests/pickers/__init__.py | 0 tests/pickers/test_by_animal_modifier.py | 151 +++++++++++++++++ 3 files changed, 331 insertions(+), 16 deletions(-) create mode 100644 tests/pickers/__init__.py create mode 100644 tests/pickers/test_by_animal_modifier.py diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index 35277ff..b825a61 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -123,10 +123,12 @@ def trainer_state(self) -> TrainerState: @property def session_directory(self) -> Path: + 
"""Returns the directory path for the current session.""" return self._launcher.session_directory @property def session(self) -> AindBehaviorSessionModel: + """Returns the current session model.""" return self._launcher.session @property @@ -475,43 +477,153 @@ def dump_model( @runtime_checkable class _IByAnimalModifier(Protocol, Generic[TRig]): - def inject(self, rig: TRig) -> TRig: ... + """ + Protocol defining the interface for by-animal modifiers. + + This protocol defines the contract that any by-animal modifier must implement + to inject and dump subject-specific configurations. + """ + + def inject(self, rig: TRig) -> TRig: + """Injects subject-specific configuration into the rig model.""" + ... - def dump(self, rig: TRig) -> None: ... + def dump(self) -> None: + """Dumps the configuration to a JSON file.""" + ... class ByAnimalModifier(abc.ABC, _IByAnimalModifier[TRig]): - def __init__(self, picker: DefaultBehaviorPicker, model_path: str, model_name: str, **kwargs) -> None: - self._picker = picker + """ + Abstract base class for modifying rig configurations with subject-specific data. + + This class provides a framework for loading and saving subject-specific + configuration data to/from JSON files. It uses reflection to access nested + attributes in the rig model and automatically handles serialization. 
+ + Attributes: + _subject_db_path: Path to the directory containing subject-specific files + _model_path: Dot-separated path to the attribute in the rig model (e.g., "nested.field") + _model_name: Base name for the JSON file (without extension) + _tp: TypeAdapter for the model type, set during inject() + + Example: + ```python + from pathlib import Path + from clabe.pickers.default_behavior import ByAnimalModifier + import pydantic + + class MyModel(pydantic.BaseModel): + nested: "NestedConfig" + + class NestedConfig(pydantic.BaseModel): + value: int + + class MyModifier(ByAnimalModifier[MyModel]): + def __init__(self, subject_db_path: Path, **kwargs): + super().__init__( + subject_db_path=subject_db_path, + model_path="nested", + model_name="nested_config", + **kwargs + ) + + def _process_before_dump(self): + return NestedConfig(value=42) + + modifier = MyModifier(Path("./subject_db")) + model = MyModel(nested=NestedConfig(value=1)) + modified = modifier.inject(model) + modifier.dump() + ``` + """ + + def __init__(self, subject_db_path: Path, model_path: str, model_name: str, **kwargs) -> None: + """ + Initializes the ByAnimalModifier. + + Args: + subject_db_path: Path to the directory containing subject-specific JSON files + model_path: Dot-separated path to the target attribute in the rig model + model_name: Base name for the JSON file (without .json extension) + **kwargs: Additional keyword arguments (reserved for future use) + """ + self._subject_db_path = Path(subject_db_path) self._model_path = model_path self._model_name = model_name + self._tp: TypeAdapter[Any] | None = None def _process_before_inject(self, deserialized: T) -> T: + """ + Hook method called after deserialization but before injection. + + Override this method to modify the deserialized data before it's + injected into the rig model. 
+ + Args: + deserialized: The deserialized object from the JSON file + + Returns: + The processed object to be injected + """ return deserialized @abc.abstractmethod - def _process_before_dump(self) -> Any: ... + def _process_before_dump(self) -> Any: + """ + Abstract method to generate the data to be dumped to JSON. + + Subclasses must implement this method to return the object that + should be serialized and saved to the JSON file. + + Returns: + The object to be serialized and dumped to JSON + """ + ... def inject(self, rig: TRig) -> TRig: - subject = self._picker.session.subject - target_folder = self._picker.subject_dir / subject - target_file = target_folder / self._model_name + """ + Injects subject-specific configuration into the rig model. + + Loads configuration from a JSON file and injects it into the specified + path in the rig model. If the file doesn't exist, the rig is returned + unmodified with a warning logged. + + Args: + rig: The rig model to modify + + Returns: + The modified rig model + """ + target_file = self._subject_db_path / f"{self._model_name}.json" if not target_file.exists(): logger.warning(f"File not found: {target_file}. Using default.") else: target = rgetattr(rig, self._model_path) - deserialized = TypeAdapter(type(target)).validate_json(target_file.read_text(encoding="utf-8")) + self._tp = TypeAdapter(type(target)) + deserialized = self._tp.validate_json(target_file.read_text(encoding="utf-8")) logger.info(f"Loading {self._model_name} from: {target_file}. Deserialized: {deserialized}") self._process_before_inject(deserialized) rsetattr(rig, self._model_path, deserialized) return rig - def dump(self, rig: TRig) -> None: - subject = self._picker.session.subject - target_folder = self._picker.subject_dir / subject - target_file = target_folder / self._model_name - target = rgetattr(rig, self._model_path) - tp = TypeAdapter(type(target)) + def dump(self) -> None: + """ + Dumps the configuration to a JSON file. 
+ + Calls _process_before_dump() to get the data, then serializes it + to JSON and writes it to the target file. Creates parent directories + if they don't exist. + + Raises: + Exception: If _process_before_dump() fails or serialization fails + """ + target_folder = self._subject_db_path + target_file = target_folder / f"{self._model_name}.json" + + if (tp := self._tp) is None: + logger.warning("TypeAdapter is not set. Using TypeAdapter(Any) as fallback.") + tp = TypeAdapter(Any) try: to_inject = self._process_before_dump() @@ -523,14 +635,66 @@ def dump(self, rig: TRig) -> None: raise -# from https://stackoverflow.com/a/31174427 def rsetattr(obj, attr, val): + """ + Sets an attribute value using a dot-separated path. + + Args: + obj: The object to modify + attr: Dot-separated attribute path (e.g., "nested.field.value") + val: The value to set + + Returns: + The result of setattr on the final attribute + + Example: + ```python + class Inner: + value = 1 + + class Outer: + inner = Inner() + + obj = Outer() + rsetattr(obj, "inner.value", 42) + assert obj.inner.value == 42 + ``` + """ pre, _, post = attr.rpartition(".") return setattr(rgetattr(obj, pre) if pre else obj, post, val) def rgetattr(obj, attr, *args): + """ + Gets an attribute value using a dot-separated path. 
+ + Args: + obj: The object to query + attr: Dot-separated attribute path (e.g., "nested.field.value") + *args: Optional default value if attribute doesn't exist + + Returns: + The attribute value at the specified path + + Example: + ```python + class Inner: + value = 42 + + class Outer: + inner = Inner() + + obj = Outer() + result = rgetattr(obj, "inner.value") + assert result == 42 + + default = rgetattr(obj, "nonexistent.path", "default") + assert default == "default" + ``` + """ + def _getattr(obj, attr): + """Helper function to get attribute with optional default.""" return getattr(obj, attr, *args) return functools.reduce(_getattr, [obj] + attr.split(".")) diff --git a/tests/pickers/__init__.py b/tests/pickers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/pickers/test_by_animal_modifier.py b/tests/pickers/test_by_animal_modifier.py new file mode 100644 index 0000000..8090bff --- /dev/null +++ b/tests/pickers/test_by_animal_modifier.py @@ -0,0 +1,151 @@ +from pathlib import Path +from typing import Optional + +import pydantic +import pytest + +from clabe.pickers.default_behavior import ByAnimalModifier + + +class NestedModel(pydantic.BaseModel): + foo: str + bar: int + nested2: Optional["NestedModel"] = None + + +class Model(pydantic.BaseModel): + nested: NestedModel + something: float + + +class CustomModifier(ByAnimalModifier[Model]): + def __init__(self, subject_db_path: Path, model_path="nested", model_name="nested_model", **kwargs): + super().__init__(subject_db_path=subject_db_path, model_path=model_path, model_name=model_name, **kwargs) + + def _process_before_dump(self): + return NestedModel(foo="Modified", bar=10, nested2=NestedModel(foo="Modified Nested", bar=20, nested2=None)) + + +class TestByAnimalModifier: + @pytest.fixture + def temp_subject_db(self, tmp_path: Path): + subject_db = tmp_path / "subject_db" + subject_db.mkdir(parents=True, exist_ok=True) + return subject_db + + @pytest.fixture + def sample_model(self): 
+ return Model( + nested=NestedModel(foo="Original", bar=5, nested2=NestedModel(foo="Nested", bar=5, nested2=None)), + something=3.14, + ) + + def test_inject_with_existing_file(self, temp_subject_db: Path, sample_model: Model): + nested_data = NestedModel(foo="Loaded", bar=99, nested2=None) + target_file = temp_subject_db / "nested_model.json" + target_file.write_text(nested_data.model_dump_json(indent=2), encoding="utf-8") + + modifier = CustomModifier(subject_db_path=temp_subject_db) + modified = modifier.inject(sample_model) + + assert modified.nested.foo == "Loaded" + assert modified.nested.bar == 99 + assert modified.nested.nested2 is None + + def test_inject_without_existing_file(self, temp_subject_db: Path, sample_model: Model): + modifier = CustomModifier(subject_db_path=temp_subject_db) + modified = modifier.inject(sample_model) + + assert modified.nested.foo == "Original" + assert modified.nested.bar == 5 + assert modified.something == 3.14 + + def test_dump_creates_file(self, temp_subject_db: Path, sample_model: Model): + modifier = CustomModifier(subject_db_path=temp_subject_db) + modifier.inject(sample_model) + modifier.dump() + + target_file = temp_subject_db / "nested_model.json" + assert target_file.exists() + + loaded = NestedModel.model_validate_json(target_file.read_text(encoding="utf-8")) + assert loaded.foo == "Modified" + assert loaded.bar == 10 + assert loaded.nested2.foo == "Modified Nested" + assert loaded.nested2.bar == 20 + + def test_dump_without_inject_uses_fallback(self, temp_subject_db: Path): + modifier = CustomModifier(subject_db_path=temp_subject_db) + modifier.dump() + + target_file = temp_subject_db / "nested_model.json" + assert target_file.exists() + + def test_inject_and_dump_roundtrip(self, temp_subject_db: Path, sample_model: Model): + modifier1 = CustomModifier(subject_db_path=temp_subject_db) + modifier1.inject(sample_model) + modifier1.dump() + + modifier2 = CustomModifier(subject_db_path=temp_subject_db) + modified = 
modifier2.inject(sample_model) + + assert modified.nested.foo == "Modified" + assert modified.nested.bar == 10 + assert modified.nested.nested2.foo == "Modified Nested" + + def test_nested_path_access(self, temp_subject_db: Path): + class DeepModel(pydantic.BaseModel): + level1: "Level1Model" + + class Level1Model(pydantic.BaseModel): + level2: "Level2Model" + + class Level2Model(pydantic.BaseModel): + value: int + + class DeepModifier(ByAnimalModifier[DeepModel]): + def __init__(self, subject_db_path: Path, **kwargs): + super().__init__( + subject_db_path=subject_db_path, model_path="level1.level2", model_name="deep_value", **kwargs + ) + + def _process_before_dump(self): + return Level2Model(value=999) + + model = DeepModel(level1=Level1Model(level2=Level2Model(value=1))) + + level2_data = Level2Model(value=42) + target_file = temp_subject_db / "deep_value.json" + target_file.write_text(level2_data.model_dump_json(indent=2), encoding="utf-8") + + modifier = DeepModifier(subject_db_path=temp_subject_db) + modified = modifier.inject(model) + + assert modified.level1.level2.value == 42 + + def test_process_before_inject_hook(self, temp_subject_db: Path, sample_model: Model): + class ModifierWithPreProcess(CustomModifier): + def _process_before_inject(self, deserialized): + deserialized.foo = "PreProcessed" + return deserialized + + nested_data = NestedModel(foo="Loaded", bar=99, nested2=None) + target_file = temp_subject_db / "nested_model.json" + target_file.write_text(nested_data.model_dump_json(indent=2), encoding="utf-8") + + modifier = ModifierWithPreProcess(subject_db_path=temp_subject_db) + modified = modifier.inject(sample_model) + + assert modified.nested.foo == "PreProcessed" + assert modified.nested.bar == 99 + + def test_dump_creates_parent_directories(self, tmp_path: Path, sample_model: Model): + nested_subject_db = tmp_path / "parent" / "child" / "subject_db" + + modifier = CustomModifier(subject_db_path=nested_subject_db) + 
modifier.inject(sample_model) + modifier.dump() + + target_file = nested_subject_db / "nested_model.json" + assert target_file.exists() + assert target_file.parent.exists() From 7482b5a057ebbc8693b9a3733e1d37867609b980 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 14:19:55 -0800 Subject: [PATCH 3/8] Move modifiers to their own module --- src/clabe/pickers/__init__.py | 2 + src/clabe/pickers/_by_animal_modifier.py | 238 +++++++++++++++++++++++ src/clabe/pickers/default_behavior.py | 230 +--------------------- tests/pickers/test_by_animal_modifier.py | 2 +- 4 files changed, 242 insertions(+), 230 deletions(-) create mode 100644 src/clabe/pickers/_by_animal_modifier.py diff --git a/src/clabe/pickers/__init__.py b/src/clabe/pickers/__init__.py index 6d19835..b475897 100644 --- a/src/clabe/pickers/__init__.py +++ b/src/clabe/pickers/__init__.py @@ -1,6 +1,8 @@ +from ._by_animal_modifier import ByAnimalModifier from .default_behavior import DefaultBehaviorPicker, DefaultBehaviorPickerSettings __all__ = [ "DefaultBehaviorPicker", "DefaultBehaviorPickerSettings", + "ByAnimalModifier", ] diff --git a/src/clabe/pickers/_by_animal_modifier.py b/src/clabe/pickers/_by_animal_modifier.py new file mode 100644 index 0000000..de7ca78 --- /dev/null +++ b/src/clabe/pickers/_by_animal_modifier.py @@ -0,0 +1,238 @@ +import abc +import functools +import logging +from pathlib import Path +from typing import Any, Generic, Protocol, TypeVar, runtime_checkable + +from pydantic import TypeAdapter + +from .._typing import TRig + +logger = logging.getLogger(__name__) +T = TypeVar("T") +TInjectable = TypeVar("TInjectable") + + +@runtime_checkable +class _IByAnimalModifier(Protocol, Generic[TRig]): + """ + Protocol defining the interface for by-animal modifiers. + + This protocol defines the contract that any by-animal modifier must implement + to inject and dump subject-specific configurations. 
+ """ + + def inject(self, rig: TRig) -> TRig: + """Injects subject-specific configuration into the rig model.""" + ... + + def dump(self) -> None: + """Dumps the configuration to a JSON file.""" + ... + + +class ByAnimalModifier(abc.ABC, _IByAnimalModifier[TRig]): + """ + Abstract base class for modifying rig configurations with subject-specific data. + + This class provides a framework for loading and saving subject-specific + configuration data to/from JSON files. It uses reflection to access nested + attributes in the rig model and automatically handles serialization. + + Attributes: + _subject_db_path: Path to the directory containing subject-specific files + _model_path: Dot-separated path to the attribute in the rig model (e.g., "nested.field") + _model_name: Base name for the JSON file (without extension) + _tp: TypeAdapter for the model type, set during inject() + + Example: + ```python + from pathlib import Path + from clabe.pickers.default_behavior import ByAnimalModifier + import pydantic + + class MyModel(pydantic.BaseModel): + nested: "NestedConfig" + + class NestedConfig(pydantic.BaseModel): + value: int + + class MyModifier(ByAnimalModifier[MyModel]): + def __init__(self, subject_db_path: Path, **kwargs): + super().__init__( + subject_db_path=subject_db_path, + model_path="nested", + model_name="nested_config", + **kwargs + ) + + def _process_before_dump(self): + return NestedConfig(value=42) + + modifier = MyModifier(Path("./subject_db")) + model = MyModel(nested=NestedConfig(value=1)) + modified = modifier.inject(model) + modifier.dump() + ``` + """ + + def __init__(self, subject_db_path: Path, model_path: str, model_name: str, **kwargs) -> None: + """ + Initializes the ByAnimalModifier. 
+ + Args: + subject_db_path: Path to the directory containing subject-specific JSON files + model_path: Dot-separated path to the target attribute in the rig model + model_name: Base name for the JSON file (without .json extension) + **kwargs: Additional keyword arguments (reserved for future use) + """ + self._subject_db_path = Path(subject_db_path) + self._model_path = model_path + self._model_name = model_name + self._tp: TypeAdapter[Any] | None = None + + def _process_before_inject(self, deserialized: T) -> T: + """ + Hook method called after deserialization but before injection. + + Override this method to modify the deserialized data before it's + injected into the rig model. + + Args: + deserialized: The deserialized object from the JSON file + + Returns: + The processed object to be injected + """ + return deserialized + + @abc.abstractmethod + def _process_before_dump(self) -> Any: + """ + Abstract method to generate the data to be dumped to JSON. + + Subclasses must implement this method to return the object that + should be serialized and saved to the JSON file. + + Returns: + The object to be serialized and dumped to JSON + """ + ... + + def inject(self, rig: TRig) -> TRig: + """ + Injects subject-specific configuration into the rig model. + + Loads configuration from a JSON file and injects it into the specified + path in the rig model. If the file doesn't exist, the rig is returned + unmodified with a warning logged. + + Args: + rig: The rig model to modify + + Returns: + The modified rig model + """ + target_file = self._subject_db_path / f"{self._model_name}.json" + if not target_file.exists(): + logger.warning(f"File not found: {target_file}. Using default.") + else: + target = rgetattr(rig, self._model_path) + self._tp = TypeAdapter(type(target)) + deserialized = self._tp.validate_json(target_file.read_text(encoding="utf-8")) + logger.info(f"Loading {self._model_name} from: {target_file}. 
Deserialized: {deserialized}") + self._process_before_inject(deserialized) + rsetattr(rig, self._model_path, deserialized) + return rig + + def dump(self) -> None: + """ + Dumps the configuration to a JSON file. + + Calls _process_before_dump() to get the data, then serializes it + to JSON and writes it to the target file. Creates parent directories + if they don't exist. + + Raises: + Exception: If _process_before_dump() fails or serialization fails + """ + target_folder = self._subject_db_path + target_file = target_folder / f"{self._model_name}.json" + + if (tp := self._tp) is None: + logger.warning("TypeAdapter is not set. Using TypeAdapter(Any) as fallback.") + tp = TypeAdapter(Any) + + try: + to_inject = self._process_before_dump() + logger.info(f"Saving {self._model_name} to: {target_file}. Serialized: {to_inject}") + target_folder.mkdir(parents=True, exist_ok=True) + target_file.write_text(tp.dump_json(to_inject, indent=2).decode("utf-8"), encoding="utf-8") + except Exception as e: + logger.error(f"Failed to process before dumping modifier: {e}") + raise + + +def rsetattr(obj, attr, val): + """ + Sets an attribute value using a dot-separated path. + + Args: + obj: The object to modify + attr: Dot-separated attribute path (e.g., "nested.field.value") + val: The value to set + + Returns: + The result of setattr on the final attribute + + Example: + ```python + class Inner: + value = 1 + + class Outer: + inner = Inner() + + obj = Outer() + rsetattr(obj, "inner.value", 42) + assert obj.inner.value == 42 + ``` + """ + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + """ + Gets an attribute value using a dot-separated path. 
+ + Args: + obj: The object to query + attr: Dot-separated attribute path (e.g., "nested.field.value") + *args: Optional default value if attribute doesn't exist + + Returns: + The attribute value at the specified path + + Example: + ```python + class Inner: + value = 42 + + class Outer: + inner = Inner() + + obj = Outer() + result = rgetattr(obj, "inner.value") + assert result == 42 + + default = rgetattr(obj, "nonexistent.path", "default") + assert default == "default" + ``` + """ + + def _getattr(obj, attr): + """Helper function to get attribute with optional default.""" + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index b825a61..c256032 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -1,16 +1,13 @@ -import abc -import functools import glob import logging import os from pathlib import Path -from typing import Any, Callable, ClassVar, Generic, List, Optional, Protocol, Type, TypeVar, Union, runtime_checkable +from typing import Callable, ClassVar, List, Optional, Type, TypeVar, Union import pydantic from aind_behavior_curriculum import TrainerState from aind_behavior_services import AindBehaviorRigModel, AindBehaviorSessionModel, AindBehaviorTaskLogicModel from aind_behavior_services.utils import model_from_json_file -from pydantic import TypeAdapter from .. import ui from .._typing import TRig, TSession, TTaskLogic @@ -473,228 +470,3 @@ def dump_model( f.write(model.model_dump_json(indent=2)) logger.info("Saved model to %s", path) return path - - -@runtime_checkable -class _IByAnimalModifier(Protocol, Generic[TRig]): - """ - Protocol defining the interface for by-animal modifiers. - - This protocol defines the contract that any by-animal modifier must implement - to inject and dump subject-specific configurations. 
- """ - - def inject(self, rig: TRig) -> TRig: - """Injects subject-specific configuration into the rig model.""" - ... - - def dump(self) -> None: - """Dumps the configuration to a JSON file.""" - ... - - -class ByAnimalModifier(abc.ABC, _IByAnimalModifier[TRig]): - """ - Abstract base class for modifying rig configurations with subject-specific data. - - This class provides a framework for loading and saving subject-specific - configuration data to/from JSON files. It uses reflection to access nested - attributes in the rig model and automatically handles serialization. - - Attributes: - _subject_db_path: Path to the directory containing subject-specific files - _model_path: Dot-separated path to the attribute in the rig model (e.g., "nested.field") - _model_name: Base name for the JSON file (without extension) - _tp: TypeAdapter for the model type, set during inject() - - Example: - ```python - from pathlib import Path - from clabe.pickers.default_behavior import ByAnimalModifier - import pydantic - - class MyModel(pydantic.BaseModel): - nested: "NestedConfig" - - class NestedConfig(pydantic.BaseModel): - value: int - - class MyModifier(ByAnimalModifier[MyModel]): - def __init__(self, subject_db_path: Path, **kwargs): - super().__init__( - subject_db_path=subject_db_path, - model_path="nested", - model_name="nested_config", - **kwargs - ) - - def _process_before_dump(self): - return NestedConfig(value=42) - - modifier = MyModifier(Path("./subject_db")) - model = MyModel(nested=NestedConfig(value=1)) - modified = modifier.inject(model) - modifier.dump() - ``` - """ - - def __init__(self, subject_db_path: Path, model_path: str, model_name: str, **kwargs) -> None: - """ - Initializes the ByAnimalModifier. 
- - Args: - subject_db_path: Path to the directory containing subject-specific JSON files - model_path: Dot-separated path to the target attribute in the rig model - model_name: Base name for the JSON file (without .json extension) - **kwargs: Additional keyword arguments (reserved for future use) - """ - self._subject_db_path = Path(subject_db_path) - self._model_path = model_path - self._model_name = model_name - self._tp: TypeAdapter[Any] | None = None - - def _process_before_inject(self, deserialized: T) -> T: - """ - Hook method called after deserialization but before injection. - - Override this method to modify the deserialized data before it's - injected into the rig model. - - Args: - deserialized: The deserialized object from the JSON file - - Returns: - The processed object to be injected - """ - return deserialized - - @abc.abstractmethod - def _process_before_dump(self) -> Any: - """ - Abstract method to generate the data to be dumped to JSON. - - Subclasses must implement this method to return the object that - should be serialized and saved to the JSON file. - - Returns: - The object to be serialized and dumped to JSON - """ - ... - - def inject(self, rig: TRig) -> TRig: - """ - Injects subject-specific configuration into the rig model. - - Loads configuration from a JSON file and injects it into the specified - path in the rig model. If the file doesn't exist, the rig is returned - unmodified with a warning logged. - - Args: - rig: The rig model to modify - - Returns: - The modified rig model - """ - target_file = self._subject_db_path / f"{self._model_name}.json" - if not target_file.exists(): - logger.warning(f"File not found: {target_file}. Using default.") - else: - target = rgetattr(rig, self._model_path) - self._tp = TypeAdapter(type(target)) - deserialized = self._tp.validate_json(target_file.read_text(encoding="utf-8")) - logger.info(f"Loading {self._model_name} from: {target_file}. 
Deserialized: {deserialized}") - self._process_before_inject(deserialized) - rsetattr(rig, self._model_path, deserialized) - return rig - - def dump(self) -> None: - """ - Dumps the configuration to a JSON file. - - Calls _process_before_dump() to get the data, then serializes it - to JSON and writes it to the target file. Creates parent directories - if they don't exist. - - Raises: - Exception: If _process_before_dump() fails or serialization fails - """ - target_folder = self._subject_db_path - target_file = target_folder / f"{self._model_name}.json" - - if (tp := self._tp) is None: - logger.warning("TypeAdapter is not set. Using TypeAdapter(Any) as fallback.") - tp = TypeAdapter(Any) - - try: - to_inject = self._process_before_dump() - logger.info(f"Saving {self._model_name} to: {target_file}. Serialized: {to_inject}") - target_folder.mkdir(parents=True, exist_ok=True) - target_file.write_text(tp.dump_json(to_inject, indent=2).decode("utf-8"), encoding="utf-8") - except Exception as e: - logger.error(f"Failed to process before dumping modifier: {e}") - raise - - -def rsetattr(obj, attr, val): - """ - Sets an attribute value using a dot-separated path. - - Args: - obj: The object to modify - attr: Dot-separated attribute path (e.g., "nested.field.value") - val: The value to set - - Returns: - The result of setattr on the final attribute - - Example: - ```python - class Inner: - value = 1 - - class Outer: - inner = Inner() - - obj = Outer() - rsetattr(obj, "inner.value", 42) - assert obj.inner.value == 42 - ``` - """ - pre, _, post = attr.rpartition(".") - return setattr(rgetattr(obj, pre) if pre else obj, post, val) - - -def rgetattr(obj, attr, *args): - """ - Gets an attribute value using a dot-separated path. 
- - Args: - obj: The object to query - attr: Dot-separated attribute path (e.g., "nested.field.value") - *args: Optional default value if attribute doesn't exist - - Returns: - The attribute value at the specified path - - Example: - ```python - class Inner: - value = 42 - - class Outer: - inner = Inner() - - obj = Outer() - result = rgetattr(obj, "inner.value") - assert result == 42 - - default = rgetattr(obj, "nonexistent.path", "default") - assert default == "default" - ``` - """ - - def _getattr(obj, attr): - """Helper function to get attribute with optional default.""" - return getattr(obj, attr, *args) - - return functools.reduce(_getattr, [obj] + attr.split(".")) diff --git a/tests/pickers/test_by_animal_modifier.py b/tests/pickers/test_by_animal_modifier.py index 8090bff..17bf0d4 100644 --- a/tests/pickers/test_by_animal_modifier.py +++ b/tests/pickers/test_by_animal_modifier.py @@ -4,7 +4,7 @@ import pydantic import pytest -from clabe.pickers.default_behavior import ByAnimalModifier +from clabe.pickers import ByAnimalModifier class NestedModel(pydantic.BaseModel): From d28758809c6ddd7a17d0bdd593b981e6833d7c88 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 16:12:41 -0800 Subject: [PATCH 4/8] Implement cached settings --- src/clabe/cached_settings.py | 284 ++++++++++++++++++++++++++++++++++ tests/test_cached_settings.py | 215 +++++++++++++++++++++++++ 2 files changed, 499 insertions(+) create mode 100644 src/clabe/cached_settings.py create mode 100644 tests/test_cached_settings.py diff --git a/src/clabe/cached_settings.py b/src/clabe/cached_settings.py new file mode 100644 index 0000000..85d6bae --- /dev/null +++ b/src/clabe/cached_settings.py @@ -0,0 +1,284 @@ +"""Local cache manager for maintaining settings history with configurable limits.""" + +import logging +import threading +from enum import Enum +from pathlib import Path +from typing import Any, ClassVar, Generic, TypeVar + +from 
pydantic import BaseModel, Field + +from .constants import TMP_DIR + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class SyncStrategy(str, Enum): + """Strategy for syncing cache to disk.""" + + MANUAL = "manual" # Only save when explicitly called + AUTO = "auto" # Save after every modification + + +class CachedSettings(BaseModel, Generic[T]): + """ + Manages a cache of values with a configurable history limit. + + When a new value is added and the cache is full, the oldest value is removed. + + Attributes: + values: List of cached values, newest first + max_history: Maximum number of items to retain in cache + + Example: + >>> cache = CachedSettings[str](max_history=3) + >>> cache.add("first") + >>> cache.add("second") + >>> cache.get_all() + ['second', 'first'] + """ + + values: list[T] = Field(default_factory=list) + max_history: int = Field(default=5, gt=0) + + def add(self, value: T) -> None: + """ + Add a new value to the cache. + + If the value already exists, it's moved to the front. + If the cache is full, the oldest value is removed. + + Args: + value: The value to add to the cache + """ + if value in self.values: + self.values.remove(value) + self.values.insert(0, value) + + if len(self.values) > self.max_history: + self.values = self.values[: self.max_history] + + def get_all(self) -> list[T]: + """ + Get all cached values. + + Returns: + List of all cached values, newest first + """ + return self.values.copy() + + def get_latest(self) -> T | None: + """ + Get the most recently added value. + + Returns: + The latest value, or None if cache is empty + """ + return self.values[0] if self.values else None + + def clear(self) -> None: + """Clear all values from the cache.""" + self.values = [] + + +class CacheData(BaseModel): + """Pydantic model for cache serialization.""" + + caches: dict[str, CachedSettings[Any]] = Field(default_factory=dict) + + +class CacheManager: + """ + Thread-safe singleton cache manager with multiple named caches. 
+ + Uses Pydantic for proper serialization/deserialization with automatic + disk synchronization support. All operations are thread-safe. + + Example: + >>> # Get singleton instance with manual sync (default) + >>> manager = CacheManager.get_instance() + >>> manager.add_to_cache("subjects", "mouse_001") + >>> manager.save() # Explicitly save + >>> + >>> # Get instance with auto sync - saves after every change + >>> manager = CacheManager.get_instance(sync_strategy=SyncStrategy.AUTO) + >>> manager.add_to_cache("subjects", "mouse_002") # Automatically saved + >>> + >>> # Custom path + >>> manager = CacheManager.get_instance(cache_path="custom/cache.json") + """ + + _instance: ClassVar["CacheManager | None"] = None + _lock: ClassVar[threading.RLock] = threading.RLock() + + def __init__( + self, + cache_path: Path | str | None = None, + sync_strategy: SyncStrategy = SyncStrategy.MANUAL, + ) -> None: + """ + Initialize a CacheManager instance. + + Args: + cache_path: Path to cache file. If None, uses default location. + sync_strategy: Strategy for syncing to disk (MANUAL or AUTO) + """ + self.caches: dict[str, CachedSettings[Any]] = {} + self.sync_strategy: SyncStrategy = sync_strategy + self.cache_path: Path = Path(cache_path) if cache_path else Path(TMP_DIR) / ".cache_manager.json" + self._instance_lock: threading.RLock = threading.RLock() + + @classmethod + def get_instance( + cls, + cache_path: Path | str | None = None, + sync_strategy: SyncStrategy = SyncStrategy.MANUAL, + reset: bool = False, + ) -> "CacheManager": + """ + Get the singleton instance of CacheManager (thread-safe). + + Args: + cache_path: Path to cache file. If None, uses default location. 
+ sync_strategy: Strategy for syncing to disk (MANUAL or AUTO) + reset: If True, reset the singleton and create a new instance + + Returns: + The singleton CacheManager instance + """ + with cls._lock: + if reset or cls._instance is None: + if cache_path is None: + cache_path = Path(TMP_DIR) / ".cache_manager.json" + else: + cache_path = Path(cache_path) + + instance = cls(cache_path=cache_path, sync_strategy=sync_strategy) + + if cache_path.exists(): + try: + with cache_path.open("r", encoding="utf-8") as f: + cache_data = CacheData.model_validate_json(f.read()) + instance.caches = cache_data.caches + except Exception as e: + logger.warning(f"Cache file {cache_path} is corrupted: {e}. Creating new instance.") + + cls._instance = instance + + return cls._instance + + def _auto_save(self) -> None: + """Save to disk if auto-sync is enabled (caller must hold lock).""" + if self.sync_strategy == SyncStrategy.AUTO: + self._save_unlocked() + + def _save_unlocked(self) -> None: + """Internal save method without locking (caller must hold lock).""" + self.cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_data = CacheData(caches=self.caches) + with self.cache_path.open("w", encoding="utf-8") as f: + f.write(cache_data.model_dump_json(indent=2)) + + def register_cache(self, name: str, max_history: int = 5) -> None: + """ + Register a new cache with a specific history limit (thread-safe). + + Args: + name: Unique name for the cache + max_history: Maximum number of items to retain + """ + with self._instance_lock: + if name not in self.caches: + self.caches[name] = CachedSettings(max_history=max_history) + self._auto_save() + + def add_to_cache(self, name: str, value: Any) -> None: + """ + Add a value to a named cache (thread-safe). 
+ + Args: + name: Name of the cache + value: Value to add + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + if name not in self.caches: + self.caches[name] = CachedSettings(max_history=5) + + cache = self.caches[name] + + # we remove it first to avoid duplicates + if value in cache.values: + cache.values.remove(value) + # but add it to the front + cache.values.insert(0, value) + + if len(cache.values) > cache.max_history: + cache.values = cache.values[: cache.max_history] + + self._auto_save() + + def get_cache(self, name: str) -> list[Any]: + """ + Get all values from a named cache (thread-safe). + + Args: + name: Name of the cache + + Returns: + List of cached values, newest first + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + if name not in self.caches: + raise KeyError(f"Cache '{name}' not registered.") + return self.caches[name].values.copy() + + def get_latest(self, name: str) -> Any | None: + """ + Get the most recent value from a named cache (thread-safe). + + Args: + name: Name of the cache + + Returns: + The latest value, or None if cache is empty + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + values = self.get_cache(name) + return values[0] if values else None + + def clear_cache(self, name: str) -> None: + """ + Clear all values from a named cache (thread-safe). + + Args: + name: Name of the cache + + Raises: + KeyError: If cache name is not registered + """ + with self._instance_lock: + if name not in self.caches: + raise KeyError(f"Cache '{name}' not registered.") + self.caches[name].values = [] + self._auto_save() + + def save(self) -> None: + """ + Save all caches to disk using Pydantic serialization (thread-safe). + + This method is called automatically if sync_strategy is AUTO, + or can be called manually for MANUAL strategy. 
+ """ + with self._instance_lock: + self._save_unlocked() diff --git a/tests/test_cached_settings.py b/tests/test_cached_settings.py new file mode 100644 index 0000000..091fd16 --- /dev/null +++ b/tests/test_cached_settings.py @@ -0,0 +1,215 @@ +"""Tests for cached settings manager with auto-sync capabilities.""" + +import tempfile +from pathlib import Path + +from clabe.cached_settings import CachedSettings, CacheManager, SyncStrategy + + +class TestCachedSettings: + """Tests for the generic CachedSettings class.""" + + def test_add_single_value(self): + """Test adding a single value to cache.""" + cache = CachedSettings[str](max_history=3) + cache.add("first") + assert cache.get_all() == ["first"] + assert cache.get_latest() == "first" + + def test_add_multiple_values(self): + """Test adding multiple values maintains order (newest first).""" + cache = CachedSettings[str](max_history=5) + cache.add("first") + cache.add("second") + cache.add("third") + assert cache.get_all() == ["third", "second", "first"] + assert cache.get_latest() == "third" + + def test_max_history_limit(self): + """Test that oldest values are removed when limit is exceeded.""" + cache = CachedSettings[str](max_history=3) + cache.add("first") + cache.add("second") + cache.add("third") + cache.add("fourth") # Should remove "first" + + assert cache.get_all() == ["fourth", "third", "second"] + assert len(cache.get_all()) == 3 + + def test_duplicate_values_moved_to_front(self): + """Test that adding a duplicate moves it to the front.""" + cache = CachedSettings[str](max_history=5) + cache.add("first") + cache.add("second") + cache.add("third") + cache.add("first") # Should move "first" to front + + assert cache.get_all() == ["first", "third", "second"] + + def test_clear(self): + """Test clearing the cache.""" + cache = CachedSettings[str](max_history=3) + cache.add("first") + cache.add("second") + cache.clear() + + assert cache.get_all() == [] + assert cache.get_latest() is None + + def 
test_get_latest_empty(self): + """Test getting latest from empty cache returns None.""" + cache = CachedSettings[str](max_history=3) + assert cache.get_latest() is None + + +class TestCacheManagerManualSync: + """Tests for CacheManager with manual sync strategy.""" + + def test_register_and_add(self): + """Test registering a cache and adding values.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.register_cache("subjects", max_history=3) + manager.add_to_cache("subjects", "mouse_001") + manager.add_to_cache("subjects", "mouse_002") + + assert manager.get_cache("subjects") == ["mouse_002", "mouse_001"] + + def test_auto_register_on_add(self): + """Test that adding to unregistered cache auto-registers it.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.add_to_cache("subjects", "mouse_001") + assert manager.get_cache("subjects") == ["mouse_001"] + + def test_multiple_caches(self): + """Test managing multiple independent caches.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.register_cache("subjects", max_history=3) + manager.register_cache("experimenters", max_history=2) + + manager.add_to_cache("subjects", "mouse_001") + manager.add_to_cache("subjects", "mouse_002") + manager.add_to_cache("experimenters", "alice") + manager.add_to_cache("experimenters", "bob") + + assert manager.get_cache("subjects") == ["mouse_002", "mouse_001"] + assert manager.get_cache("experimenters") == ["bob", "alice"] + + def test_get_latest(self): + """Test getting the latest value from a cache.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.add_to_cache("test", "first") + manager.add_to_cache("test", "second") + + assert manager.get_latest("test") == "second" + + def test_get_latest_empty(self): + """Test getting latest from empty cache returns None.""" + manager = 
CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.register_cache("test", max_history=3) + + assert manager.get_latest("test") is None + + def test_clear_cache(self): + """Test clearing a specific cache.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.add_to_cache("test", "value1") + manager.add_to_cache("test", "value2") + manager.clear_cache("test") + + assert manager.get_cache("test") == [] + + def test_singleton_behavior(self): + """Test that get_instance returns the same instance.""" + manager1 = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager1.add_to_cache("test", "value1") + + manager2 = CacheManager.get_instance() + assert manager2.get_cache("test") == ["value1"] + assert manager1 is manager2 + + +class TestCacheManagerAutoSync: + """Tests for CacheManager with auto-sync to disk.""" + + def test_auto_sync_on_add(self): + """Test that AUTO sync saves after adding values.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "cache.json" + + # Create with auto-sync + manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO) + manager.add_to_cache("subjects", "mouse_001") + + # Verify file was created automatically + assert path.exists() + + # Load in a new instance and verify data persisted + manager2 = CacheManager.get_instance(reset=True, cache_path=path) + assert manager2.get_cache("subjects") == ["mouse_001"] + + def test_auto_sync_on_clear(self): + """Test that AUTO sync saves after clearing.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "cache.json" + + manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO) + manager.add_to_cache("test", "value1") + manager.clear_cache("test") + + # Reload and verify clear persisted + manager2 = CacheManager.get_instance(reset=True, cache_path=path) + assert 
manager2.get_cache("test") == [] + + def test_manual_sync_does_not_auto_save(self): + """Test that MANUAL sync does not save automatically.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "cache.json" + + manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.MANUAL) + manager.add_to_cache("test", "value1") + + # File should not exist yet + assert not path.exists() + + # Explicit save + manager.save() + assert path.exists() + + def test_load_nonexistent_file(self): + """Test loading from non-existent file returns empty manager.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "nonexistent.json" + manager = CacheManager.get_instance(reset=True, cache_path=path) + + assert manager.caches == {} + + def test_persistence_across_loads(self): + """Test data persists correctly across save/load cycles.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "cache.json" + + manager1 = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO) + manager1.add_to_cache("subjects", "mouse_001") + manager1.add_to_cache("subjects", "mouse_002") + manager1.add_to_cache("projects", "project_a") + + manager2 = CacheManager.get_instance(reset=True, cache_path=path) + assert manager2.get_cache("subjects") == ["mouse_002", "mouse_001"] + assert manager2.get_cache("projects") == ["project_a"] + + manager2.add_to_cache("subjects", "mouse_003") + manager2.save() + + manager3 = CacheManager.get_instance(reset=True, cache_path=path) + assert manager3.get_cache("subjects") == ["mouse_003", "mouse_002", "mouse_001"] + + def test_default_cache_path(self): + """Test that default cache path is used when none specified.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.add_to_cache("test", "value") + manager.save() + + manager2 = CacheManager.get_instance(reset=True) + assert manager2.get_cache("test") == 
["value"] + + manager.cache_path.unlink(missing_ok=True) From b9880f75d3a7820287c4986be1fee567b41986e3 Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:06:18 -0800 Subject: [PATCH 5/8] Add cli interface for cache --- .../{cached_settings.py => cache_manager.py} | 38 ++++++++++++++++++- src/clabe/cli.py | 2 + 2 files changed, 38 insertions(+), 2 deletions(-) rename src/clabe/{cached_settings.py => cache_manager.py} (89%) diff --git a/src/clabe/cached_settings.py b/src/clabe/cache_manager.py similarity index 89% rename from src/clabe/cached_settings.py rename to src/clabe/cache_manager.py index 85d6bae..ea838a7 100644 --- a/src/clabe/cached_settings.py +++ b/src/clabe/cache_manager.py @@ -1,5 +1,3 @@ -"""Local cache manager for maintaining settings history with configurable limits.""" - import logging import threading from enum import Enum @@ -7,6 +5,7 @@ from typing import Any, ClassVar, Generic, TypeVar from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, CliApp, CliSubCommand from .constants import TMP_DIR @@ -273,6 +272,12 @@ def clear_cache(self, name: str) -> None: self.caches[name].values = [] self._auto_save() + def clear_all_caches(self) -> None: + """Clear all caches (thread-safe).""" + with self._instance_lock: + self.caches = {} + self._auto_save() + def save(self) -> None: """ Save all caches to disk using Pydantic serialization (thread-safe). 
@@ -282,3 +287,32 @@ def save(self) -> None:
         """
         with self._instance_lock:
             self._save_unlocked()
+
+
+class _ListCacheCli(BaseSettings):
+    """CLI command to list all caches and their contents."""
+
+    def cli_cmd(self):
+        manager = CacheManager.get_instance()
+        if not manager.caches:
+            logger.info("No caches available.")
+        for name, cache in manager.caches.items():
+            logger.info(f"Cache '{name}': {cache.values}")
+
+
+class _ResetCacheCli(BaseSettings):
+    """CLI command to reset all caches."""
+
+    def cli_cmd(self):
+        CacheManager.get_instance().clear_all_caches()
+        logger.info("All caches have been cleared.")
+
+
+class _CacheManagerCli(BaseSettings):
+    """CLI application wrapper for the cache manager subcommands."""
+
+    reset: CliSubCommand[_ResetCacheCli]
+    list: CliSubCommand[_ListCacheCli]
+
+    def cli_cmd(self):
+        CliApp.run_subcommand(self)
diff --git a/src/clabe/cli.py b/src/clabe/cli.py
index 7348407..9b41797 100644
--- a/src/clabe/cli.py
+++ b/src/clabe/cli.py
@@ -1,5 +1,6 @@
 from pydantic_settings import BaseSettings, CliApp, CliSubCommand
 
+from .cache_manager import _CacheManagerCli
 from .xml_rpc._server import _XmlRpcServerStartCli
 
 
@@ -7,6 +8,7 @@ class CliAppSettings(BaseSettings, cli_prog_name="clabe", cli_kebab_case=True):
     """CLI application settings."""
 
     xml_rpc_server: CliSubCommand[_XmlRpcServerStartCli]
+    cache: CliSubCommand[_CacheManagerCli]
 
     def cli_cmd(self):
         """Run the selected subcommand."""

From 870809cb20e7723a65fca2bc07e0f4d333968a11 Mon Sep 17 00:00:00 2001
From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com>
Date: Thu, 13 Nov 2025 20:07:28 -0800
Subject: [PATCH 6/8] Add tests for clearing cache

---
 tests/test_cached_settings.py | 37 +++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/tests/test_cached_settings.py b/tests/test_cached_settings.py
index 091fd16..fd0fe4d 100644
--- a/tests/test_cached_settings.py
+++ b/tests/test_cached_settings.py
@@ -3,7 +3,7 @@
 import tempfile
 from pathlib 
import Path -from clabe.cached_settings import CachedSettings, CacheManager, SyncStrategy +from clabe.cache_manager import CachedSettings, CacheManager, SyncStrategy class TestCachedSettings: @@ -118,6 +118,22 @@ def test_clear_cache(self): assert manager.get_cache("test") == [] + def test_clear_all_caches(self): + """Test clearing all caches at once.""" + manager = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) + manager.add_to_cache("subjects", "mouse_001") + manager.add_to_cache("subjects", "mouse_002") + manager.add_to_cache("experimenters", "alice") + manager.add_to_cache("projects", "project_a") + + assert len(manager.caches) == 3 + assert manager.get_cache("subjects") == ["mouse_002", "mouse_001"] + + manager.clear_all_caches() + + assert manager.caches == {} + assert len(manager.caches) == 0 + def test_singleton_behavior(self): """Test that get_instance returns the same instance.""" manager1 = CacheManager.get_instance(reset=True, sync_strategy=SyncStrategy.MANUAL) @@ -136,14 +152,11 @@ def test_auto_sync_on_add(self): with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) / "cache.json" - # Create with auto-sync manager = CacheManager.get_instance(reset=True, cache_path=path, sync_strategy=SyncStrategy.AUTO) manager.add_to_cache("subjects", "mouse_001") - # Verify file was created automatically assert path.exists() - # Load in a new instance and verify data persisted manager2 = CacheManager.get_instance(reset=True, cache_path=path) assert manager2.get_cache("subjects") == ["mouse_001"] @@ -160,6 +173,22 @@ def test_auto_sync_on_clear(self): manager2 = CacheManager.get_instance(reset=True, cache_path=path) assert manager2.get_cache("test") == [] + def test_auto_sync_on_clear_all(self): + """Test that AUTO sync saves after clearing all caches.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "cache.json" + + manager = CacheManager.get_instance(reset=True, cache_path=path, 
sync_strategy=SyncStrategy.AUTO) + manager.add_to_cache("subjects", "mouse_001") + manager.add_to_cache("projects", "project_a") + + assert path.exists() + + manager.clear_all_caches() + + manager2 = CacheManager.get_instance(reset=True, cache_path=path) + assert manager2.caches == {} + def test_manual_sync_does_not_auto_save(self): """Test that MANUAL sync does not save automatically.""" with tempfile.TemporaryDirectory() as tmpdir: From d76cb4ceb3ef5502d798671e8a4f354cdd9c0a9e Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:38:40 -0800 Subject: [PATCH 7/8] Implement cache manager in picker --- src/clabe/cache_manager.py | 10 ++- src/clabe/pickers/default_behavior.py | 112 ++++++++++++++++++-------- src/clabe/ui/ui_helper.py | 4 +- 3 files changed, 87 insertions(+), 39 deletions(-) diff --git a/src/clabe/cache_manager.py b/src/clabe/cache_manager.py index ea838a7..3538f38 100644 --- a/src/clabe/cache_manager.py +++ b/src/clabe/cache_manager.py @@ -115,7 +115,7 @@ class CacheManager: def __init__( self, cache_path: Path | str | None = None, - sync_strategy: SyncStrategy = SyncStrategy.MANUAL, + sync_strategy: SyncStrategy = SyncStrategy.AUTO, ) -> None: """ Initialize a CacheManager instance. @@ -133,7 +133,7 @@ def __init__( def get_instance( cls, cache_path: Path | str | None = None, - sync_strategy: SyncStrategy = SyncStrategy.MANUAL, + sync_strategy: SyncStrategy = SyncStrategy.AUTO, reset: bool = False, ) -> "CacheManager": """ @@ -239,6 +239,12 @@ def get_cache(self, name: str) -> list[Any]: raise KeyError(f"Cache '{name}' not registered.") return self.caches[name].values.copy() + def try_get_cache(self, name: str) -> Any | None: + try: + return self.get_cache(name) + except KeyError: + return None + def get_latest(self, name: str) -> Any | None: """ Get the most recent value from a named cache (thread-safe). 
diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index c256032..4e93621 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -11,6 +11,7 @@ from .. import ui from .._typing import TRig, TSession, TTaskLogic +from ..cache_manager import CacheManager from ..constants import ByAnimalFiles from ..launcher import Launcher from ..services import ServiceSettings @@ -87,6 +88,7 @@ def __init__( self._experimenter_validator = experimenter_validator self._trainer_state: Optional[TrainerState] = None self._session: Optional[AindBehaviorSessionModel] = None + self._cache_manager = CacheManager.get_instance() @property def ui_helper(self) -> ui.UiHelper: @@ -196,27 +198,54 @@ def pick_rig(self, model: Type[TRig]) -> TRig: Raises: ValueError: If no rig configuration files are found or an invalid choice is made """ - available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json")) - if len(available_rigs) == 0: - logger.error("No rig config files found.") - raise ValueError("No rig config files found.") - elif len(available_rigs) == 1: - logger.info("Found a single rig config file. 
Using %s.", {available_rigs[0]}) - rig = model_from_json_file(available_rigs[0], model) + rig: TRig | None = None + rig_path: str | None = None + + # Check cache for previously used rigs + cache = self._cache_manager.try_get_cache("rigs") + if cache: + rig_path = self.ui_helper.prompt_pick_from_list( + cache, + prompt="Choose a rig:", + allow_0_as_none=True, + zero_label="Select from library", + ) + if rig_path is not None: + rig = self._load_rig_from_path(Path(rig_path), model) + + # Prompt user to select a rig if not already selected + while rig is None: + available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json")) + # We raise if no rigs are found to prevent an infinite loop + if len(available_rigs) == 0: + logger.error("No rig config files found.") + raise ValueError("No rig config files found.") + # Use the single available rig config file + elif len(available_rigs) == 1: + logger.info("Found a single rig config file. Using %s.", {available_rigs[0]}) + rig_path = available_rigs[0] + rig = model_from_json_file(rig_path, model) + else: + rig_path = self.ui_helper.prompt_pick_from_list(available_rigs, prompt="Choose a rig:") + if rig_path is not None: + rig = self._load_rig_from_path(Path(rig_path), model) + assert rig_path is not None + # Add the selected rig path to the cache + cache = self._cache_manager.add_to_cache("rigs", rig_path) + return rig + + @staticmethod + def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None: + try: + if not isinstance(path, str): + raise ValueError("Invalid choice.") + rig = model_from_json_file(path, model) + logger.info("Using %s.", path) return rig - else: - while True: - try: - path = self.ui_helper.prompt_pick_from_list(available_rigs, prompt="Choose a rig:") - if not isinstance(path, str): - raise ValueError("Invalid choice.") - rig = model_from_json_file(path, model) - logger.info("Using %s.", path) - return rig - except pydantic.ValidationError as e: - logger.error("Failed to validate pydantic model. 
Try again. %s", e) - except ValueError as e: - logger.info("Invalid choice. Try again. %s", e) + except pydantic.ValidationError as e: + logger.error("Failed to validate pydantic model. Try again. %s", e) + except ValueError as e: + logger.info("Invalid choice. Try again. %s", e) def pick_session(self, model: Type[TSession] = AindBehaviorSessionModel) -> TSession: """ @@ -378,22 +407,22 @@ def choose_subject(self, directory: str | os.PathLike) -> str: subject = picker.choose_subject("Subjects") ``` """ - subject = None + subjects = self._cache_manager.try_get_cache("subjects") + if subjects: + subject = self.ui_helper.prompt_pick_from_list( + subjects, + prompt="Choose a subject:", + allow_0_as_none=True, + zero_label="Enter manually", + ) + else: + subject = None + while subject is None: subject = self.ui_helper.input("Enter subject name: ") if subject == "": - subject = self.ui_helper.prompt_pick_from_list( - [ - os.path.basename(folder) - for folder in os.listdir(directory) - if os.path.isdir(os.path.join(directory, folder)) - ], - prompt="Choose a subject:", - allow_0_as_none=True, - ) - else: - return subject - + subject = None + self._cache_manager.add_to_cache("subjects", subject) return subject def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: @@ -416,10 +445,22 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: print("Experimenters:", names) ``` """ + experimenters_cache = self._cache_manager.try_get_cache("experimenters") experimenter: Optional[List[str]] = None + _picked: str | None = None while experimenter is None: - _user_input = self.ui_helper.prompt_text("Experimenter name: ") - experimenter = _user_input.replace(",", " ").split() + if experimenters_cache: + _picked = self.ui_helper.prompt_pick_from_list( + experimenters_cache, + prompt="Choose an experimenter:", + allow_0_as_none=True, + zero_label="Enter manually", + ) + if _picked is None: + _input = self.ui_helper.prompt_text("Experimenter 
name: ") + else: + _input = _picked + experimenter = _input.replace(",", " ").split() if strict & (len(experimenter) == 0): logger.info("Experimenter name is not valid. Try again.") experimenter = None @@ -430,6 +471,7 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: logger.warning("Experimenter name: %s, is not valid. Try again", name) experimenter = None break + self._cache_manager.add_to_cache("experimenters", ",".join(experimenter)) return experimenter def dump_model( diff --git a/src/clabe/ui/ui_helper.py b/src/clabe/ui/ui_helper.py index 252b0c2..f4a157b 100644 --- a/src/clabe/ui/ui_helper.py +++ b/src/clabe/ui/ui_helper.py @@ -191,7 +191,7 @@ def input(self, prompt: str) -> str: return self._input(prompt) def prompt_pick_from_list( - self, value: List[str], prompt: str, allow_0_as_none: bool = True, **kwargs + self, value: List[str], prompt: str, *, allow_0_as_none: bool = True, zero_label: str = "None", **kwargs ) -> Optional[str]: """ Prompts the user to pick an item from a list. 
@@ -223,7 +223,7 @@ def prompt_pick_from_list( try: self.print(prompt) if allow_0_as_none: - self.print("0: None") + self.print(f"0: {zero_label}") for i, item in enumerate(value): self.print(f"{i + 1}: {item}") choice = int(input("Choice: ")) From f3027abe86999d7a59e2a31c00848696c294d30a Mon Sep 17 00:00:00 2001 From: bruno-f-cruz <7049351+bruno-f-cruz@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:44:48 -0800 Subject: [PATCH 8/8] Add missing doc strings --- src/clabe/cache_manager.py | 4 ++++ src/clabe/pickers/default_behavior.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/clabe/cache_manager.py b/src/clabe/cache_manager.py index 3538f38..5c1c9f4 100644 --- a/src/clabe/cache_manager.py +++ b/src/clabe/cache_manager.py @@ -240,6 +240,7 @@ def get_cache(self, name: str) -> list[Any]: return self.caches[name].values.copy() def try_get_cache(self, name: str) -> Any | None: + """Attempt to get all values from a named cache, returning None if not found.""" try: return self.get_cache(name) except KeyError: @@ -299,6 +300,7 @@ class _ListCacheCli(BaseSettings): """CLI command to list all caches and their contents.""" def cli_cmd(self): + """Run the list cache CLI command.""" manager = CacheManager.get_instance() if not manager.caches: logger.info("No caches available.") @@ -310,6 +312,7 @@ class _ResetCacheCli(BaseSettings): """CLI command to reset all caches.""" def cli_cmd(self): + """Run the reset cache CLI command.""" CacheManager.get_instance().clear_all_caches() logger.info("All caches have been cleared.") @@ -321,4 +324,5 @@ class _CacheManagerCli(BaseSettings): list: CliSubCommand[_ListCacheCli] def cli_cmd(self): + """Run the cache manager CLI.""" CliApp.run_subcommand(self) diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index 4e93621..6dd718e 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -236,6 +236,7 @@ def pick_rig(self, model: Type[TRig]) -> TRig: 
@staticmethod def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None: + """Load a rig configuration from a given path.""" try: if not isinstance(path, str): raise ValueError("Invalid choice.")