4 changes: 1 addition & 3 deletions examples/behavior_launcher.py
@@ -70,9 +70,7 @@ def fmt(value: str) -> str:
         repository=launcher.repository,
         script_path=Path("./mock/script.py"),
         output_parameters={"suggestion": suggestion.model_dump()},
-    )
-    launcher.copy_logs()
-
+    ).map()
     return


1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "semver",
     "rich",
     "aind_behavior_services < 1",
+    "questionary",
 ]

 [project.urls]
7 changes: 4 additions & 3 deletions src/clabe/cache_manager.py
@@ -12,6 +12,7 @@
 logger = logging.getLogger(__name__)

 T = TypeVar("T")
+_DEFAULT_MAX_HISTORY = 9


 class SyncStrategy(str, Enum):
@@ -40,7 +41,7 @@ class CachedSettings(BaseModel, Generic[T]):
     """

     values: list[T] = Field(default_factory=list)
-    max_history: int = Field(default=5, gt=0)
+    max_history: int = Field(default=_DEFAULT_MAX_HISTORY, gt=0)

     def add(self, value: T) -> None:
         """
@@ -180,7 +181,7 @@ def _save_unlocked(self) -> None:
         with self.cache_path.open("w", encoding="utf-8") as f:
             f.write(cache_data.model_dump_json(indent=2))

-    def register_cache(self, name: str, max_history: int = 5) -> None:
+    def register_cache(self, name: str, max_history: int = _DEFAULT_MAX_HISTORY) -> None:
         """
         Register a new cache with a specific history limit (thread-safe).

@@ -206,7 +207,7 @@ def add_to_cache(self, name: str, value: Any) -> None:
         """
         with self._instance_lock:
             if name not in self.caches:
-                self.caches[name] = CachedSettings(max_history=5)
+                self.caches[name] = CachedSettings(max_history=_DEFAULT_MAX_HISTORY)

             cache = self.caches[name]

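
A minimal sketch of what the bumped default means in practice, assuming CachedSettings.add() keeps entries newest-first and trims anything beyond max_history (the behavior the tests at the end of this diff exercise):

    from clabe.cache_manager import _DEFAULT_MAX_HISTORY, CachedSettings

    cache = CachedSettings[str]()  # max_history now defaults to _DEFAULT_MAX_HISTORY == 9
    for i in range(12):
        cache.add(f"rig_{i}.json")  # placeholder values, not from this PR
    assert len(cache.values) == _DEFAULT_MAX_HISTORY  # assumes add() drops the oldest entries
    assert cache.values[0] == "rig_11.json"  # newest entry first, per test_add_multiple_values
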
30 changes: 23 additions & 7 deletions src/clabe/pickers/default_behavior.py
@@ -71,6 +71,7 @@ def __init__(
         launcher: Launcher,
         ui_helper: Optional[ui.UiHelper] = None,
         experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username,
+        use_cache: bool = True,
     ):
         """
         Initializes the DefaultBehaviorPicker.
@@ -80,6 +81,7 @@ def __init__(
             launcher: The launcher instance for managing experiment execution
             ui_helper: Helper for user interface interactions. If None, uses launcher's ui_helper. Defaults to None
             experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_aind_username
+            use_cache: Whether to use caching for selections. Defaults to True
         """
         self._launcher = launcher
         self._ui_helper = launcher.ui_helper if ui_helper is None else ui_helper
@@ -89,6 +91,7 @@ def __init__(
         self._trainer_state: Optional[TrainerState] = None
         self._session: Optional[AindBehaviorSessionModel] = None
         self._cache_manager = CacheManager.get_instance()
+        self._use_cache = use_cache

     @property
     def ui_helper(self) -> ui.UiHelper:
@@ -202,8 +205,13 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
         rig_path: str | None = None

         # Check cache for previously used rigs
-        cache = self._cache_manager.try_get_cache("rigs")
+        if self._use_cache:
+            cache = self._cache_manager.try_get_cache("rigs")
+        else:
+            cache = None
+
         if cache:
+            cache.sort()
             rig_path = self.ui_helper.prompt_pick_from_list(
                 cache,
                 prompt="Choose a rig:",
@@ -214,7 +222,7 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
                 rig = self._load_rig_from_path(Path(rig_path), model)

         # Prompt user to select a rig if not already selected
-        while rig is None:
+        while rig_path is None:
             available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json"))
             # We raise if no rigs are found to prevent an infinite loop
             if len(available_rigs) == 0:
@@ -230,23 +238,23 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
             if rig_path is not None:
                 rig = self._load_rig_from_path(Path(rig_path), model)
+        assert rig_path is not None
+        assert rig is not None
         # Add the selected rig path to the cache
-        cache = self._cache_manager.add_to_cache("rigs", rig_path)
+        self._cache_manager.add_to_cache("rigs", rig_path)
         return rig

     @staticmethod
     def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None:
         """Load a rig configuration from a given path."""
         try:
-            if not isinstance(path, str):
-                raise ValueError("Invalid choice.")
             rig = model_from_json_file(path, model)
             logger.info("Using %s.", path)
             return rig
         except pydantic.ValidationError as e:
             logger.error("Failed to validate pydantic model. Try again. %s", e)
         except ValueError as e:
             logger.info("Invalid choice. Try again. %s", e)
         return None

     def pick_session(self, model: Type[TSession] = AindBehaviorSessionModel) -> TSession:
         """
@@ -408,8 +416,12 @@ def choose_subject(self, directory: str | os.PathLike) -> str:
         subject = picker.choose_subject("Subjects")
         ```
         """
-        subjects = self._cache_manager.try_get_cache("subjects")
+        if self._use_cache:
+            subjects = self._cache_manager.try_get_cache("subjects")
+        else:
+            subjects = None
         if subjects:
+            subjects.sort()
             subject = self.ui_helper.prompt_pick_from_list(
                 subjects,
                 prompt="Choose a subject:",
@@ -446,11 +458,15 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]:
         print("Experimenters:", names)
         ```
         """
-        experimenters_cache = self._cache_manager.try_get_cache("experimenters")
+        if self._use_cache:
+            experimenters_cache = self._cache_manager.try_get_cache("experimenters")
+        else:
+            experimenters_cache = None
         experimenter: Optional[List[str]] = None
         _picked: str | None = None
         while experimenter is None:
             if experimenters_cache:
+                experimenters_cache.sort()
                 _picked = self.ui_helper.prompt_pick_from_list(
                     experimenters_cache,
                     prompt="Choose an experimenter:",
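
A usage sketch of the new use_cache flag; the launcher and rig_model arguments below are hypothetical stand-ins (an already-configured Launcher and a concrete rig model type), not objects defined in this diff. With caching disabled, pick_rig() always globs rig_dir instead of offering the cached "rigs" entries, and the subject/experimenter prompts likewise skip their caches:

    from clabe.pickers.default_behavior import DefaultBehaviorPicker  # module path inferred from the file location above

    def pick_without_cache(launcher, rig_model):
        # launcher: hypothetical configured Launcher; rig_model: hypothetical Type[TRig] argument
        picker = DefaultBehaviorPicker(launcher, use_cache=False)
        rig = picker.pick_rig(rig_model)             # never consults the "rigs" cache
        subject = picker.choose_subject("Subjects")  # likewise skips the "subjects" cache
        return rig, subject
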
7 changes: 5 additions & 2 deletions src/clabe/ui/__init__.py
@@ -1,3 +1,6 @@
-from .ui_helper import DefaultUIHelper, UiHelper, prompt_field_from_input
+from .questionary_ui_helper import QuestionaryUIHelper
+from .ui_helper import NativeUiHelper, UiHelper, prompt_field_from_input

-__all__ = ["DefaultUIHelper", "UiHelper", "prompt_field_from_input"]
+DefaultUIHelper = QuestionaryUIHelper
+
+__all__ = ["DefaultUIHelper", "UiHelper", "prompt_field_from_input", "NativeUiHelper", "QuestionaryUIHelper"]
110 changes: 110 additions & 0 deletions src/clabe/ui/questionary_ui_helper.py
@@ -0,0 +1,110 @@
import asyncio
import logging
from typing import List, Optional

import questionary
from questionary import Style

from .ui_helper import _UiHelperBase

logger = logging.getLogger(__name__)

custom_style = Style(
    [
        ("qmark", "fg:#5f87ff bold"),  # Question mark - blue
        ("question", "fg:#ffffff bold"),  # Question text - white bold
        ("answer", "fg:#5f87ff bold"),  # Selected answer - blue
        ("pointer", "fg:#5f87ff bold"),  # Pointer - blue arrow
        ("highlighted", "fg:#000000 bg:#5f87ff bold"),  # INVERTED: black text on blue background
        ("selected", "fg:#5f87ff"),  # After selection
        ("separator", "fg:#666666"),  # Separator
        ("instruction", "fg:#888888"),  # Instructions
        ("text", ""),  # Plain text
        ("disabled", "fg:#858585 italic"),  # Disabled
    ]
)


def _ask_sync(question):
    """Ask question, handling both sync and async contexts.

    When in an async context, runs the questionary prompt in a thread pool
    to avoid the "asyncio.run() cannot be called from a running event loop" error.
    """
    try:
        # Check if we're in an async context
        asyncio.get_running_loop()
        # We are in an async context - use thread pool to avoid nested event loop
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(question.ask)
            return future.result()
    except RuntimeError:
        # No running loop - use normal ask()
        return question.ask()


class QuestionaryUIHelper(_UiHelperBase):
    """UI helper implementation using Questionary for interactive prompts."""

    def __init__(self, style: Optional[questionary.Style] = None) -> None:
        """Initializes the QuestionaryUIHelper with an optional custom style."""
        self.style = style or custom_style

    def print(self, message: str) -> None:
        """Prints a message with custom styling."""
        questionary.print(message, "bold italic")

    def input(self, prompt: str) -> str:
        """Prompts the user for input with custom styling."""
        return _ask_sync(questionary.text(prompt, style=self.style)) or ""

    def prompt_pick_from_list(self, value: List[str], prompt: str, **kwargs) -> Optional[str]:
        """Interactive list selection with visual highlighting using arrow keys or number shortcuts."""
        allow_0_as_none = kwargs.get("allow_0_as_none", True)
        zero_label = kwargs.get("zero_label", "None")

        choices = []

        if allow_0_as_none:
            choices.append(zero_label)

        choices.extend(value)

        result = _ask_sync(
            questionary.select(
                prompt,
                choices=choices,
                style=self.style,
                use_arrow_keys=True,
                use_indicator=True,
                use_shortcuts=True,
            )
        )

        if result is None:
            return None

        if result == zero_label and allow_0_as_none:
            return None

        return result

    def prompt_yes_no_question(self, prompt: str) -> bool:
        """Prompts the user with a yes/no question using custom styling."""
        return _ask_sync(questionary.confirm(prompt, style=self.style)) or False

    def prompt_text(self, prompt: str) -> str:
        """Prompts the user for generic text input using custom styling."""
        return _ask_sync(questionary.text(prompt, style=self.style)) or ""

    def prompt_float(self, prompt: str) -> float:
        """Prompts the user for a float input using custom styling."""
        while True:
            try:
                value_str = _ask_sync(questionary.text(prompt, style=self.style))
                if value_str:
                    return float(value_str)
            except ValueError:
                self.print("Invalid input. Please enter a valid float.")
2 changes: 1 addition & 1 deletion src/clabe/ui/ui_helper.py
@@ -137,7 +137,7 @@ def prompt_float(self, prompt: str) -> float:
 UiHelper: TypeAlias = _UiHelperBase


-class DefaultUIHelper(_UiHelperBase):
+class NativeUiHelper(_UiHelperBase):
    """
    Default implementation of the UI helper for user interaction.

6 changes: 3 additions & 3 deletions tests/test_cached_settings.py
@@ -3,7 +3,7 @@
 import tempfile
 from pathlib import Path

-from clabe.cache_manager import CachedSettings, CacheManager, SyncStrategy
+from clabe.cache_manager import _DEFAULT_MAX_HISTORY, CachedSettings, CacheManager, SyncStrategy


 class TestCachedSettings:
@@ -18,7 +18,7 @@ def test_add_single_value(self):

     def test_add_multiple_values(self):
         """Test adding multiple values maintains order (newest first)."""
-        cache = CachedSettings[str](max_history=5)
+        cache = CachedSettings[str](max_history=_DEFAULT_MAX_HISTORY)
         cache.add("first")
         cache.add("second")
         cache.add("third")
@@ -38,7 +38,7 @@ def test_max_history_limit(self):

     def test_duplicate_values_moved_to_front(self):
         """Test that adding a duplicate moves it to the front."""
-        cache = CachedSettings[str](max_history=5)
+        cache = CachedSettings[str](max_history=_DEFAULT_MAX_HISTORY)
         cache.add("first")
         cache.add("second")
         cache.add("third")
6 changes: 3 additions & 3 deletions tests/ui/test_ui.py → tests/ui/test_native_ui.py
@@ -2,15 +2,15 @@

 import pytest

-from clabe.ui import DefaultUIHelper
+from clabe.ui import NativeUiHelper


 @pytest.fixture
 def ui_helper():
-    return DefaultUIHelper(print_func=MagicMock())
+    return NativeUiHelper(print_func=MagicMock())


-class TestDefaultUiHelper:
+class TestNativeUiHelper:
     @patch("builtins.input", side_effect=["Some notes"])
     def test_prompt_get_text(self, mock_input, ui_helper):
         result = ui_helper.prompt_text("")