Skip to content

Commit b114cdb

Browse files
authored
Merge pull request #206 from AllenNeuralDynamics/development
Merge development branch
2 parents 98dabd8 + 8b6efa8 commit b114cdb

File tree

10 files changed

+188
-22
lines changed

10 files changed

+188
-22
lines changed

examples/behavior_launcher.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,7 @@ def fmt(value: str) -> str:
7070
repository=launcher.repository,
7171
script_path=Path("./mock/script.py"),
7272
output_parameters={"suggestion": suggestion.model_dump()},
73-
)
74-
launcher.copy_logs()
75-
73+
).map()
7674
return
7775

7876

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
"semver",
3030
"rich",
3131
"aind_behavior_services < 1",
32+
"questionary",
3233
]
3334

3435
[project.urls]

src/clabe/cache_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
logger = logging.getLogger(__name__)
1313

1414
T = TypeVar("T")
15+
_DEFAULT_MAX_HISTORY = 9
1516

1617

1718
class SyncStrategy(str, Enum):
@@ -40,7 +41,7 @@ class CachedSettings(BaseModel, Generic[T]):
4041
"""
4142

4243
values: list[T] = Field(default_factory=list)
43-
max_history: int = Field(default=5, gt=0)
44+
max_history: int = Field(default=_DEFAULT_MAX_HISTORY, gt=0)
4445

4546
def add(self, value: T) -> None:
4647
"""
@@ -180,7 +181,7 @@ def _save_unlocked(self) -> None:
180181
with self.cache_path.open("w", encoding="utf-8") as f:
181182
f.write(cache_data.model_dump_json(indent=2))
182183

183-
def register_cache(self, name: str, max_history: int = 5) -> None:
184+
def register_cache(self, name: str, max_history: int = _DEFAULT_MAX_HISTORY) -> None:
184185
"""
185186
Register a new cache with a specific history limit (thread-safe).
186187
@@ -206,7 +207,7 @@ def add_to_cache(self, name: str, value: Any) -> None:
206207
"""
207208
with self._instance_lock:
208209
if name not in self.caches:
209-
self.caches[name] = CachedSettings(max_history=5)
210+
self.caches[name] = CachedSettings(max_history=_DEFAULT_MAX_HISTORY)
210211

211212
cache = self.caches[name]
212213

src/clabe/pickers/default_behavior.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
launcher: Launcher,
7272
ui_helper: Optional[ui.UiHelper] = None,
7373
experimenter_validator: Optional[Callable[[str], bool]] = validate_aind_username,
74+
use_cache: bool = True,
7475
):
7576
"""
7677
Initializes the DefaultBehaviorPicker.
@@ -80,6 +81,7 @@ def __init__(
8081
launcher: The launcher instance for managing experiment execution
8182
ui_helper: Helper for user interface interactions. If None, uses launcher's ui_helper. Defaults to None
8283
experimenter_validator: Function to validate the experimenter's username. If None, no validation is performed. Defaults to validate_aind_username
84+
use_cache: Whether to use caching for selections. Defaults to True
8385
"""
8486
self._launcher = launcher
8587
self._ui_helper = launcher.ui_helper if ui_helper is None else ui_helper
@@ -89,6 +91,7 @@ def __init__(
8991
self._trainer_state: Optional[TrainerState] = None
9092
self._session: Optional[AindBehaviorSessionModel] = None
9193
self._cache_manager = CacheManager.get_instance()
94+
self._use_cache = use_cache
9295

9396
@property
9497
def ui_helper(self) -> ui.UiHelper:
@@ -202,8 +205,13 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
202205
rig_path: str | None = None
203206

204207
# Check cache for previously used rigs
205-
cache = self._cache_manager.try_get_cache("rigs")
208+
if self._use_cache:
209+
cache = self._cache_manager.try_get_cache("rigs")
210+
else:
211+
cache = None
212+
206213
if cache:
214+
cache.sort()
207215
rig_path = self.ui_helper.prompt_pick_from_list(
208216
cache,
209217
prompt="Choose a rig:",
@@ -214,7 +222,7 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
214222
rig = self._load_rig_from_path(Path(rig_path), model)
215223

216224
# Prompt user to select a rig if not already selected
217-
while rig is None:
225+
while rig_path is None:
218226
available_rigs = glob.glob(os.path.join(self.rig_dir, "*.json"))
219227
# We raise if no rigs are found to prevent an infinite loop
220228
if len(available_rigs) == 0:
@@ -230,23 +238,23 @@ def pick_rig(self, model: Type[TRig]) -> TRig:
230238
if rig_path is not None:
231239
rig = self._load_rig_from_path(Path(rig_path), model)
232240
assert rig_path is not None
241+
assert rig is not None
233242
# Add the selected rig path to the cache
234-
cache = self._cache_manager.add_to_cache("rigs", rig_path)
243+
self._cache_manager.add_to_cache("rigs", rig_path)
235244
return rig
236245

237246
@staticmethod
238247
def _load_rig_from_path(path: Path, model: Type[TRig]) -> TRig | None:
239248
"""Load a rig configuration from a given path."""
240249
try:
241-
if not isinstance(path, str):
242-
raise ValueError("Invalid choice.")
243250
rig = model_from_json_file(path, model)
244251
logger.info("Using %s.", path)
245252
return rig
246253
except pydantic.ValidationError as e:
247254
logger.error("Failed to validate pydantic model. Try again. %s", e)
248255
except ValueError as e:
249256
logger.info("Invalid choice. Try again. %s", e)
257+
return None
250258

251259
def pick_session(self, model: Type[TSession] = AindBehaviorSessionModel) -> TSession:
252260
"""
@@ -408,8 +416,12 @@ def choose_subject(self, directory: str | os.PathLike) -> str:
408416
subject = picker.choose_subject("Subjects")
409417
```
410418
"""
411-
subjects = self._cache_manager.try_get_cache("subjects")
419+
if self._use_cache:
420+
subjects = self._cache_manager.try_get_cache("subjects")
421+
else:
422+
subjects = None
412423
if subjects:
424+
subjects.sort()
413425
subject = self.ui_helper.prompt_pick_from_list(
414426
subjects,
415427
prompt="Choose a subject:",
@@ -446,11 +458,15 @@ def prompt_experimenter(self, strict: bool = True) -> Optional[List[str]]:
446458
print("Experimenters:", names)
447459
```
448460
"""
449-
experimenters_cache = self._cache_manager.try_get_cache("experimenters")
461+
if self._use_cache:
462+
experimenters_cache = self._cache_manager.try_get_cache("experimenters")
463+
else:
464+
experimenters_cache = None
450465
experimenter: Optional[List[str]] = None
451466
_picked: str | None = None
452467
while experimenter is None:
453468
if experimenters_cache:
469+
experimenters_cache.sort()
454470
_picked = self.ui_helper.prompt_pick_from_list(
455471
experimenters_cache,
456472
prompt="Choose an experimenter:",

src/clabe/ui/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1-
from .ui_helper import DefaultUIHelper, UiHelper, prompt_field_from_input
1+
from .questionary_ui_helper import QuestionaryUIHelper
2+
from .ui_helper import NativeUiHelper, UiHelper, prompt_field_from_input
23

3-
__all__ = ["DefaultUIHelper", "UiHelper", "prompt_field_from_input"]
4+
DefaultUIHelper = QuestionaryUIHelper
5+
6+
__all__ = ["DefaultUIHelper", "UiHelper", "prompt_field_from_input", "NativeUiHelper", "QuestionaryUIHelper"]
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import asyncio
2+
import logging
3+
from typing import List, Optional
4+
5+
import questionary
6+
from questionary import Style
7+
8+
from .ui_helper import _UiHelperBase
9+
10+
logger = logging.getLogger(__name__)
11+
12+
custom_style = Style(
13+
[
14+
("qmark", "fg:#5f87ff bold"), # Question mark - blue
15+
("question", "fg:#ffffff bold"), # Question text - white bold
16+
("answer", "fg:#5f87ff bold"), # Selected answer - blue
17+
("pointer", "fg:#5f87ff bold"), # Pointer - blue arrow
18+
("highlighted", "fg:#000000 bg:#5f87ff bold"), # INVERTED: black text on blue background
19+
("selected", "fg:#5f87ff"), # After selection
20+
("separator", "fg:#666666"), # Separator
21+
("instruction", "fg:#888888"), # Instructions
22+
("text", ""), # Plain text
23+
("disabled", "fg:#858585 italic"), # Disabled
24+
]
25+
)
26+
27+
28+
def _ask_sync(question):
29+
"""Ask question, handling both sync and async contexts.
30+
31+
When in an async context, runs the questionary prompt in a thread pool
32+
to avoid the "asyncio.run() cannot be called from a running event loop" error.
33+
"""
34+
try:
35+
# Check if we're in an async context
36+
asyncio.get_running_loop()
37+
# We are in an async context - use thread pool to avoid nested event loop
38+
import concurrent.futures
39+
40+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
41+
future = executor.submit(question.ask)
42+
return future.result()
43+
except RuntimeError:
44+
# No running loop - use normal ask()
45+
return question.ask()
46+
47+
48+
class QuestionaryUIHelper(_UiHelperBase):
49+
"""UI helper implementation using Questionary for interactive prompts."""
50+
51+
def __init__(self, style: Optional[questionary.Style] = None) -> None:
52+
"""Initializes the QuestionaryUIHelper with an optional custom style."""
53+
self.style = style or custom_style
54+
55+
def print(self, message: str) -> None:
56+
"""Prints a message with custom styling."""
57+
questionary.print(message, "bold italic")
58+
59+
def input(self, prompt: str) -> str:
60+
"""Prompts the user for input with custom styling."""
61+
return _ask_sync(questionary.text(prompt, style=self.style)) or ""
62+
63+
def prompt_pick_from_list(self, value: List[str], prompt: str, **kwargs) -> Optional[str]:
64+
"""Interactive list selection with visual highlighting using arrow keys or number shortcuts."""
65+
allow_0_as_none = kwargs.get("allow_0_as_none", True)
66+
zero_label = kwargs.get("zero_label", "None")
67+
68+
choices = []
69+
70+
if allow_0_as_none:
71+
choices.append(zero_label)
72+
73+
choices.extend(value)
74+
75+
result = _ask_sync(
76+
questionary.select(
77+
prompt,
78+
choices=choices,
79+
style=self.style,
80+
use_arrow_keys=True,
81+
use_indicator=True,
82+
use_shortcuts=True,
83+
)
84+
)
85+
86+
if result is None:
87+
return None
88+
89+
if result == zero_label and allow_0_as_none:
90+
return None
91+
92+
return result
93+
94+
def prompt_yes_no_question(self, prompt: str) -> bool:
95+
"""Prompts the user with a yes/no question using custom styling."""
96+
return _ask_sync(questionary.confirm(prompt, style=self.style)) or False
97+
98+
def prompt_text(self, prompt: str) -> str:
99+
"""Prompts the user for generic text input using custom styling."""
100+
return _ask_sync(questionary.text(prompt, style=self.style)) or ""
101+
102+
def prompt_float(self, prompt: str) -> float:
103+
"""Prompts the user for a float input using custom styling."""
104+
while True:
105+
try:
106+
value_str = _ask_sync(questionary.text(prompt, style=self.style))
107+
if value_str:
108+
return float(value_str)
109+
except ValueError:
110+
self.print("Invalid input. Please enter a valid float.")

src/clabe/ui/ui_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def prompt_float(self, prompt: str) -> float:
137137
UiHelper: TypeAlias = _UiHelperBase
138138

139139

140-
class DefaultUIHelper(_UiHelperBase):
140+
class NativeUiHelper(_UiHelperBase):
141141
"""
142142
Default implementation of the UI helper for user interaction.
143143

tests/test_cached_settings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import tempfile
44
from pathlib import Path
55

6-
from clabe.cache_manager import CachedSettings, CacheManager, SyncStrategy
6+
from clabe.cache_manager import _DEFAULT_MAX_HISTORY, CachedSettings, CacheManager, SyncStrategy
77

88

99
class TestCachedSettings:
@@ -18,7 +18,7 @@ def test_add_single_value(self):
1818

1919
def test_add_multiple_values(self):
2020
"""Test adding multiple values maintains order (newest first)."""
21-
cache = CachedSettings[str](max_history=5)
21+
cache = CachedSettings[str](max_history=_DEFAULT_MAX_HISTORY)
2222
cache.add("first")
2323
cache.add("second")
2424
cache.add("third")
@@ -38,7 +38,7 @@ def test_max_history_limit(self):
3838

3939
def test_duplicate_values_moved_to_front(self):
4040
"""Test that adding a duplicate moves it to the front."""
41-
cache = CachedSettings[str](max_history=5)
41+
cache = CachedSettings[str](max_history=_DEFAULT_MAX_HISTORY)
4242
cache.add("first")
4343
cache.add("second")
4444
cache.add("third")
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
import pytest
44

5-
from clabe.ui import DefaultUIHelper
5+
from clabe.ui import NativeUiHelper
66

77

88
@pytest.fixture
99
def ui_helper():
10-
return DefaultUIHelper(print_func=MagicMock())
10+
return NativeUiHelper(print_func=MagicMock())
1111

1212

13-
class TestDefaultUiHelper:
13+
class TestNativeUiHelper:
1414
@patch("builtins.input", side_effect=["Some notes"])
1515
def test_prompt_get_text(self, mock_input, ui_helper):
1616
result = ui_helper.prompt_text("")

0 commit comments

Comments
 (0)