diff --git a/examples/behavior_launcher.py b/examples/behavior_launcher.py index 7f5e8a12..04a11d34 100644 --- a/examples/behavior_launcher.py +++ b/examples/behavior_launcher.py @@ -70,9 +70,7 @@ def fmt(value: str) -> str: repository=launcher.repository, script_path=Path("./mock/script.py"), output_parameters={"suggestion": suggestion.model_dump()}, - ) - launcher.copy_logs() - + ).map() return diff --git a/pyproject.toml b/pyproject.toml index a007cfb2..4b611ed6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "semver", "rich", "aind_behavior_services < 1", + "questionary", ] [project.urls] diff --git a/src/clabe/cache_manager.py b/src/clabe/cache_manager.py index 5c1c9f45..69085db3 100644 --- a/src/clabe/cache_manager.py +++ b/src/clabe/cache_manager.py @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T") +_DEFAULT_MAX_HISTORY = 9 class SyncStrategy(str, Enum): @@ -40,7 +41,7 @@ class CachedSettings(BaseModel, Generic[T]): """ values: list[T] = Field(default_factory=list) - max_history: int = Field(default=5, gt=0) + max_history: int = Field(default=_DEFAULT_MAX_HISTORY, gt=0) def add(self, value: T) -> None: """ @@ -180,7 +181,7 @@ def _save_unlocked(self) -> None: with self.cache_path.open("w", encoding="utf-8") as f: f.write(cache_data.model_dump_json(indent=2)) - def register_cache(self, name: str, max_history: int = 5) -> None: + def register_cache(self, name: str, max_history: int = _DEFAULT_MAX_HISTORY) -> None: """ Register a new cache with a specific history limit (thread-safe). @@ -206,7 +207,7 @@ def add_to_cache(self, name: str, value: Any) -> None: """ with self._instance_lock: if name not in self.caches: - self.caches[name] = CachedSettings(max_history=5) + self.caches[name] = CachedSettings(max_history=_DEFAULT_MAX_HISTORY) cache = self.caches[name] diff --git a/src/clabe/pickers/default_behavior.py b/src/clabe/pickers/default_behavior.py index 6dd718e5..2bbef821 100644 --- a/src/clabe/pickers/default_behavior.py +++ b/src/clabe/pickers/default_behavior.py @@ -71,6 +71,7 @@ def __init__( launcher: Launcher, ui_helper: Optional[ui.UiHelper] = None, experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username, + use_cache: bool = True, ): """ Initializes the DefaultBehaviorPicker. @@ -80,6 +81,7 @@ def __init__( launcher: The launcher instance for managing experiment execution ui_helper: Helper for user interface interactions. If None, uses launcher's ui_helper. Defaults to None experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_aind_username + use_cache: Whether to use caching for selections. Defaults to True """ self._launcher = launcher self._ui_helper = launcher.ui_helper if ui_helper is None else ui_helper @@ -89,6 +91,7 @@ def __init__( self._trainer_state: Optional[TrainerState] = None self._session: Optional[AindBehaviorSessionModel] = None self._cache_manager = CacheManager.get_instance() + self._use_cache = use_cache @property def ui_helper(self) -> ui.UiHelper: @@ -202,8 +205,13 @@ def pick_rig(self, model: Type[TRig]) -> TRig: rig_path: str | None = None # Check cache for previously used rigs - cache = self._cache_manager.try_get_cache("rigs") + if self._use_cache: + cache = self._cache_manager.try_get_cache("rigs") + else: + cache = None + if cache: + cache.sort() rig_path = self.ui_helper.prompt_pick_from_list( cache, prompt="Choose a rig:", @@ -214,7 +222,7 @@ def pick_rig(self, model: Type[TRig]) -> TRig: rig = self._load_rig_from_path(Path(rig_path), model) # Prompt user to select a rig if not already selected - while rig is None: + while rig_path is None: available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json")) # We raise if no rigs are found to prevent an infinite loop if len(available_rigs) == 0: @@ -230,16 +238,15 @@ def pick_rig(self, model: Type[TRig]) -> TRig: if rig_path is not None: rig = self._load_rig_from_path(Path(rig_path), model) assert rig_path is not None + assert rig is not None # Add the selected rig path to the cache - cache = self._cache_manager.add_to_cache("rigs", rig_path) + self._cache_manager.add_to_cache("rigs", rig_path) return rig @staticmethod def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None: """Load a rig configuration from a given path.""" try: - if not isinstance(path, str): - raise ValueError("Invalid choice.") rig = model_from_json_file(path, model) logger.info("Using %s.", path) return rig @@ -247,6 +254,7 @@ def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None: logger.error("Failed to validate pydantic model. Try again. %s", e) except ValueError as e: logger.info("Invalid choice. Try again. %s", e) + return None def pick_session(self, model: Type[TSession] = AindBehaviorSessionModel) -> TSession: """ @@ -408,8 +416,12 @@ def choose_subject(self, directory: str | os.PathLike) -> str: subject = picker.choose_subject("Subjects") ``` """ - subjects = self._cache_manager.try_get_cache("subjects") + if self._use_cache: + subjects = self._cache_manager.try_get_cache("subjects") + else: + subjects = None if subjects: + subjects.sort() subject = self.ui_helper.prompt_pick_from_list( subjects, prompt="Choose a subject:", @@ -446,11 +458,15 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]: print("Experimenters:", names) ``` """ - experimenters_cache = self._cache_manager.try_get_cache("experimenters") + if self._use_cache: + experimenters_cache = self._cache_manager.try_get_cache("experimenters") + else: + experimenters_cache = None experimenter: Optional[List[str]] = None _picked: str | None = None while experimenter is None: if experimenters_cache: + experimenters_cache.sort() _picked = self.ui_helper.prompt_pick_from_list( experimenters_cache, prompt="Choose an experimenter:", diff --git a/src/clabe/ui/__init__.py b/src/clabe/ui/__init__.py index e61dd87e..459fb20f 100644 --- a/src/clabe/ui/__init__.py +++ b/src/clabe/ui/__init__.py @@ -1,3 +1,6 @@ -from .ui_helper import DefaultUIHelper, UiHelper, prompt_field_from_input +from .questionary_ui_helper import QuestionaryUIHelper +from .ui_helper import NativeUiHelper, UiHelper, prompt_field_from_input -__all__ = ["DefaultUIHelper", "UiHelper", "prompt_field_from_input"] +DefaultUIHelper = QuestionaryUIHelper + +__all__ = ["DefaultUIHelper", "UiHelper", "prompt_field_from_input", "NativeUiHelper", "QuestionaryUIHelper"] diff --git a/src/clabe/ui/questionary_ui_helper.py b/src/clabe/ui/questionary_ui_helper.py new file mode 100644 index 00000000..7c0cc9b1 --- /dev/null +++ b/src/clabe/ui/questionary_ui_helper.py @@ -0,0 +1,110 @@ +import asyncio +import logging +from typing import List, Optional + +import questionary +from questionary import Style + +from .ui_helper import _UiHelperBase + +logger = logging.getLogger(__name__) + +custom_style = Style( + [ + ("qmark", "fg:#5f87ff bold"), # Question mark - blue + ("question", "fg:#ffffff bold"), # Question text - white bold + ("answer", "fg:#5f87ff bold"), # Selected answer - blue + ("pointer", "fg:#5f87ff bold"), # Pointer - blue arrow + ("highlighted", "fg:#000000 bg:#5f87ff bold"), # INVERTED: black text on blue background + ("selected", "fg:#5f87ff"), # After selection + ("separator", "fg:#666666"), # Separator + ("instruction", "fg:#888888"), # Instructions + ("text", ""), # Plain text + ("disabled", "fg:#858585 italic"), # Disabled + ] +) + + +def _ask_sync(question): + """Ask question, handling both sync and async contexts. + + When in an async context, runs the questionary prompt in a thread pool + to avoid the "asyncio.run() cannot be called from a running event loop" error. + """ + try: + # Check if we're in an async context + asyncio.get_running_loop() + # We are in an async context - use thread pool to avoid nested event loop + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(question.ask) + return future.result() + except RuntimeError: + # No running loop - use normal ask() + return question.ask() + + +class QuestionaryUIHelper(_UiHelperBase): + """UI helper implementation using Questionary for interactive prompts.""" + + def __init__(self, style: Optional[questionary.Style] = None) -> None: + """Initializes the QuestionaryUIHelper with an optional custom style.""" + self.style = style or custom_style + + def print(self, message: str) -> None: + """Prints a message with custom styling.""" + questionary.print(message, "bold italic") + + def input(self, prompt: str) -> str: + """Prompts the user for input with custom styling.""" + return _ask_sync(questionary.text(prompt, style=self.style)) or "" + + def prompt_pick_from_list(self, value: List[str], prompt: str, **kwargs) -> Optional[str]: + """Interactive list selection with visual highlighting using arrow keys or number shortcuts.""" + allow_0_as_none = kwargs.get("allow_0_as_none", True) + zero_label = kwargs.get("zero_label", "None") + + choices = [] + + if allow_0_as_none: + choices.append(zero_label) + + choices.extend(value) + + result = _ask_sync( + questionary.select( + prompt, + choices=choices, + style=self.style, + use_arrow_keys=True, + use_indicator=True, + use_shortcuts=True, + ) + ) + + if result is None: + return None + + if result == zero_label and allow_0_as_none: + return None + + return result + + def prompt_yes_no_question(self, prompt: str) -> bool: + """Prompts the user with a yes/no question using custom styling.""" + return _ask_sync(questionary.confirm(prompt, style=self.style)) or False + + def prompt_text(self, prompt: str) -> str: + """Prompts the user for generic text input using custom styling.""" + return _ask_sync(questionary.text(prompt, style=self.style)) or "" + + def prompt_float(self, prompt: str) -> float: + """Prompts the user for a float input using custom styling.""" + while True: + try: + value_str = _ask_sync(questionary.text(prompt, style=self.style)) + if value_str: + return float(value_str) + except ValueError: + self.print("Invalid input. Please enter a valid float.") diff --git a/src/clabe/ui/ui_helper.py b/src/clabe/ui/ui_helper.py index f4a157bf..5339ec72 100644 --- a/src/clabe/ui/ui_helper.py +++ b/src/clabe/ui/ui_helper.py @@ -137,7 +137,7 @@ def prompt_float(self, prompt: str) -> float: UiHelper: TypeAlias = _UiHelperBase -class DefaultUIHelper(_UiHelperBase): +class NativeUiHelper(_UiHelperBase): """ Default implementation of the UI helper for user interaction. diff --git a/tests/test_cached_settings.py b/tests/test_cached_settings.py index fd0fe4d2..6fbff8b1 100644 --- a/tests/test_cached_settings.py +++ b/tests/test_cached_settings.py @@ -3,7 +3,7 @@ import tempfile from pathlib import Path -from clabe.cache_manager import CachedSettings, CacheManager, SyncStrategy +from clabe.cache_manager import _DEFAULT_MAX_HISTORY, CachedSettings, CacheManager, SyncStrategy class TestCachedSettings: @@ -18,7 +18,7 @@ def test_add_single_value(self): def test_add_multiple_values(self): """Test adding multiple values maintains order (newest first).""" - cache = CachedSettings[str](max_history=5) + cache = CachedSettings[str](max_history=_DEFAULT_MAX_HISTORY) cache.add("first") cache.add("second") cache.add("third") @@ -38,7 +38,7 @@ def test_max_history_limit(self): def test_duplicate_values_moved_to_front(self): """Test that adding a duplicate moves it to the front.""" - cache = CachedSettings[str](max_history=5) + cache = CachedSettings[str](max_history=_DEFAULT_MAX_HISTORY) cache.add("first") cache.add("second") cache.add("third") diff --git a/tests/ui/test_ui.py b/tests/ui/test_native_ui.py similarity index 89% rename from tests/ui/test_ui.py rename to tests/ui/test_native_ui.py index 682b4f6f..ffe5975e 100644 --- a/tests/ui/test_ui.py +++ b/tests/ui/test_native_ui.py @@ -2,15 +2,15 @@ import pytest -from clabe.ui import DefaultUIHelper +from clabe.ui import NativeUiHelper @pytest.fixture def ui_helper(): - return DefaultUIHelper(print_func=MagicMock()) + return NativeUiHelper(print_func=MagicMock()) -class TestDefaultUiHelper: +class TestNativeUiHelper: @patch("builtins.input", side_effect=["Some notes"]) def test_prompt_get_text(self, mock_input, ui_helper): result = ui_helper.prompt_text("") diff --git a/uv.lock b/uv.lock index 47824042..be529e25 100644 --- a/uv.lock +++ b/uv.lock @@ -46,6 +46,7 @@ dependencies = [ { name = "gitpython" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "questionary" }, { name = "requests" }, { name = "rich" }, { name = "semver" }, @@ -99,6 +100,7 @@ requires-dist = [ { name = "pydantic-settings" }, { name = "pykeepass", marker = "extra == 'aind-services'" }, { name = "pyyaml", marker = "extra == 'aind-services'" }, + { name = "questionary" }, { name = "requests" }, { name = "requests", marker = "extra == 'aind-services'" }, { name = "rich" }, @@ -1291,6 +1293,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + [[package]] name = "py" version = "1.11.0" @@ -1715,6 +1729,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/11/432f32f8097b03e3cd5fe57e88efb685d964e2e5178a48ed61e841f7fdce/pyyaml_env_tag-1.1-py3-none-any.whl", hash = "sha256:17109e1a528561e32f026364712fee1264bc2ea6715120891174ed1b980d2e04", size = 4722, upload-time = "2025-05-13T15:23:59.629Z" }, ] +[[package]] +name = "questionary" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "prompt-toolkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/45/eafb0bba0f9988f6a2520f9ca2df2c82ddfa8d67c95d6625452e97b204a5/questionary-2.1.1.tar.gz", hash = "sha256:3d7e980292bb0107abaa79c68dd3eee3c561b83a0f89ae482860b181c8bd412d", size = 25845, upload-time = "2025-08-28T19:00:20.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/26/1062c7ec1b053db9e499b4d2d5bc231743201b74051c973dadeac80a8f43/questionary-2.1.1-py3-none-any.whl", hash = "sha256:a51af13f345f1cdea62347589fbb6df3b290306ab8930713bfae4d475a7d4a59", size = 36753, upload-time = "2025-08-28T19:00:19.56Z" }, +] + [[package]] name = "requests" version = "2.32.5" @@ -1795,6 +1821,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" }, { url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" }, { url = "https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" }, + { url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" }, + { url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" }, ] [[package]] @@ -1974,6 +2002,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "wcwidth" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, +] + [[package]] name = "winkerberos" version = "0.12.2"