Skip to content

Commit 6b00e03

Browse files
authored
Merge pull request #21 from damassi/feat/add-model-selection
feat: add model selection slash command
2 parents 4787c37 + e5881d5 commit 6b00e03

File tree

11 files changed

+244
-35
lines changed

11 files changed

+244
-35
lines changed

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Run commands should be prefixed with `uv`: `uv run ...`
99
- Use `asyncio` features, if such is needed
1010
- Prefer early returns
11+
- Private methods always go below public methods
1112
- Absolutely no useless comments! Every class and method does not need to be documented (unless it is legitimetly complex or "lib-ish")
1213
- Imports belong at the top of files, not inside functions (unless needed to avoid circular imports)
1314

src/agent_chat_cli/components/chat_history.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
class ChatHistory(Container):
1515
def add_message(self, message: Message) -> None:
16-
widget = self._create_message(message)
17-
self.mount(widget)
16+
message_item = self._create_message(message)
17+
self.mount(message_item)
1818

1919
def _create_message(
2020
self, message: Message
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import TYPE_CHECKING
2+
3+
from textual.widget import Widget
4+
from textual.app import ComposeResult
5+
from textual.containers import VerticalScroll
6+
from textual.widgets import OptionList
7+
from textual.widgets.option_list import Option
8+
9+
if TYPE_CHECKING:
10+
from agent_chat_cli.core.actions import Actions
11+
12+
MODELS = [
13+
{"id": "sonnet", "label": "Sonnet"},
14+
{"id": "haiku", "label": "Haiku"},
15+
{"id": "opus", "label": "Opus"},
16+
]
17+
18+
19+
class ModelSelectionMenu(Widget):
20+
def __init__(self, actions: Actions) -> None:
21+
super().__init__()
22+
self.actions = actions
23+
24+
def compose(self) -> ComposeResult:
25+
yield OptionList(*[Option(model["label"], id=model["id"]) for model in MODELS])
26+
27+
def show(self) -> None:
28+
self.add_class("visible")
29+
30+
scroll_containers = self.app.query(VerticalScroll)
31+
if scroll_containers:
32+
scroll_containers.first().scroll_end(animate=False)
33+
34+
option_list = self.query_one(OptionList)
35+
option_list.highlighted = 0
36+
option_list.focus()
37+
38+
def hide(self) -> None:
39+
self.remove_class("visible")
40+
41+
@property
42+
def is_visible(self) -> bool:
43+
return self.has_class("visible")
44+
45+
async def on_option_list_option_selected(
46+
self, event: OptionList.OptionSelected
47+
) -> None:
48+
self.hide()
49+
50+
if event.option_id:
51+
await self.actions.change_model(event.option_id)

src/agent_chat_cli/components/slash_command_menu.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
COMMANDS = [
1212
{"id": "new", "label": "/new - Start new conversation"},
1313
{"id": "clear", "label": "/clear - Clear chat history"},
14+
{"id": "model", "label": "/model - Change model"},
1415
{"id": "save", "label": "/save - Save conversation to markdown"},
1516
{"id": "exit", "label": "/exit - Exit"},
1617
]
@@ -84,5 +85,7 @@ async def on_option_list_option_selected(
8485
await self.actions.clear()
8586
case "new":
8687
await self.actions.new()
88+
case "model":
89+
self.actions.show_model_menu()
8790
case "save":
8891
await self.actions.save()

src/agent_chat_cli/components/user_input.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from agent_chat_cli.components.caret import Caret
88
from agent_chat_cli.components.flex import Flex
99
from agent_chat_cli.components.slash_command_menu import SlashCommandMenu
10+
from agent_chat_cli.components.model_selection_menu import ModelSelectionMenu
1011
from agent_chat_cli.core.actions import Actions
1112
from agent_chat_cli.utils.enums import Key
1213

@@ -35,6 +36,7 @@ def compose(self) -> ComposeResult:
3536
yield SlashCommandMenu(
3637
actions=self.actions, on_filter_change=self._on_filter_change
3738
)
39+
yield ModelSelectionMenu(actions=self.actions)
3840

3941
def _on_filter_change(self, char: str) -> None:
4042
text_area = self.query_one(TextArea)
@@ -51,13 +53,15 @@ def on_descendant_blur(self, event: DescendantBlur) -> None:
5153
if not self.display:
5254
return
5355

54-
menu = self.query_one(SlashCommandMenu)
56+
menu = self._get_visible_menu()
5557

56-
if isinstance(event.widget, TextArea) and not menu.is_visible:
58+
if isinstance(event.widget, TextArea) and not menu:
5759
event.widget.focus(scroll_visible=False)
58-
elif isinstance(event.widget, OptionList) and menu.is_visible:
59-
menu.hide()
60-
self.query_one(TextArea).focus(scroll_visible=False)
60+
elif isinstance(event.widget, OptionList) and menu:
61+
menu_option_list = menu.query_one(OptionList)
62+
if event.widget == menu_option_list:
63+
menu.hide()
64+
self.query_one(TextArea).focus(scroll_visible=False)
6165

6266
def on_text_area_changed(self, event: TextArea.Changed) -> None:
6367
menu = self.query_one(SlashCommandMenu)
@@ -68,10 +72,10 @@ def on_text_area_changed(self, event: TextArea.Changed) -> None:
6872
menu.show()
6973

7074
async def on_key(self, event) -> None:
71-
menu = self.query_one(SlashCommandMenu)
75+
menu = self._get_visible_menu()
7276

73-
if menu.is_visible:
74-
self._close_menu(event)
77+
if menu:
78+
self._close_menu(event, menu)
7579
return
7680

7781
if event.key == "up":
@@ -92,9 +96,7 @@ def _insert_newline(self, event) -> None:
9296
input_widget = self.query_one(TextArea)
9397
input_widget.insert("\n")
9498

95-
def _close_menu(self, event) -> None:
96-
menu = self.query_one(SlashCommandMenu)
97-
99+
def _close_menu(self, event, menu: SlashCommandMenu | ModelSelectionMenu) -> None:
98100
if event.key == Key.ESCAPE.value:
99101
event.stop()
100102
event.prevent_default()
@@ -104,7 +106,10 @@ def _close_menu(self, event) -> None:
104106
input_widget.focus()
105107
return
106108

107-
if event.key in (Key.BACKSPACE.value, Key.DELETE.value):
109+
if isinstance(menu, SlashCommandMenu) and event.key in (
110+
Key.BACKSPACE.value,
111+
Key.DELETE.value,
112+
):
108113
if menu.filter_text:
109114
menu.filter_text = menu.filter_text[:-1]
110115
menu._refresh_options()
@@ -147,10 +152,21 @@ async def _navigate_history(self, event, direction: int) -> None:
147152
input_widget.text = self.message_history[self.history_index]
148153
input_widget.move_cursor_relative(rows=999, columns=999)
149154

155+
def _get_visible_menu(self) -> SlashCommandMenu | ModelSelectionMenu | None:
156+
slash_menu = self.query_one(SlashCommandMenu)
157+
if slash_menu.is_visible:
158+
return slash_menu
159+
160+
model_menu = self.query_one(ModelSelectionMenu)
161+
if model_menu.is_visible:
162+
return model_menu
163+
164+
return None
165+
150166
async def action_submit(self) -> None:
151-
menu = self.query_one(SlashCommandMenu)
167+
menu = self._get_visible_menu()
152168

153-
if menu.is_visible:
169+
if menu:
154170
option_list = menu.query_one(OptionList)
155171
option_list.action_select()
156172
input_widget = self.query_one(TextArea)

src/agent_chat_cli/core/actions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from agent_chat_cli.components.messages import RoleType
55
from agent_chat_cli.components.chat_history import ChatHistory
66
from agent_chat_cli.components.tool_permission_prompt import ToolPermissionPrompt
7+
from agent_chat_cli.components.model_selection_menu import ModelSelectionMenu
78
from agent_chat_cli.utils.logger import log_json
89
from agent_chat_cli.utils.save_conversation import save_conversation
910

@@ -72,5 +73,13 @@ async def save(self) -> None:
7273
f"Conversation saved to {file_path}", thinking=False
7374
)
7475

76+
def show_model_menu(self) -> None:
77+
model_menu = self.app.query_one(ModelSelectionMenu)
78+
model_menu.show()
79+
80+
async def change_model(self, model: str) -> None:
81+
await self.app.agent_loop.change_model(model)
82+
await self.post_system_message(f"Switched to {model}", thinking=False)
83+
7584
async def _query(self, user_input: str) -> None:
7685
await self.app.agent_loop.query_queue.put(user_input)

src/agent_chat_cli/core/agent_loop.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
get_available_servers,
2525
get_sdk_config,
2626
)
27-
from agent_chat_cli.utils.enums import AppEventType, ContentType, ControlCommand
27+
from agent_chat_cli.utils.enums import (
28+
AppEventType,
29+
ContentType,
30+
ControlCommand,
31+
ModelChangeCommand,
32+
)
2833
from agent_chat_cli.utils.logger import log_json
2934
from agent_chat_cli.utils.mcp_server_status import MCPServerStatus
3035

@@ -51,36 +56,33 @@ def __init__(
5156

5257
self.client: ClaudeSDKClient
5358

54-
self.query_queue: asyncio.Queue[str | ControlCommand] = asyncio.Queue()
59+
self.query_queue: asyncio.Queue[str | ControlCommand | ModelChangeCommand] = (
60+
asyncio.Queue()
61+
)
5562
self.permission_response_queue: asyncio.Queue[str] = asyncio.Queue()
5663
self.permission_lock = asyncio.Lock()
5764

5865
self._running = False
5966

6067
async def start(self) -> None:
61-
mcp_servers = {
62-
name: config.model_dump() for name, config in self.available_servers.items()
63-
}
64-
65-
await self._initialize_client(mcp_servers=mcp_servers)
68+
await self._initialize_client()
6669

6770
self._running = True
6871

6972
while self._running:
7073
user_input = await self.query_queue.get()
7174

75+
if isinstance(user_input, ModelChangeCommand):
76+
self.config.model = user_input.model
77+
await self.client.disconnect()
78+
await self._initialize_client()
79+
continue
80+
7281
if isinstance(user_input, ControlCommand):
7382
if user_input == ControlCommand.NEW_CONVERSATION:
7483
await self.client.disconnect()
75-
7684
self.session_id = None
77-
78-
mcp_servers = {
79-
name: config.model_dump()
80-
for name, config in self.available_servers.items()
81-
}
82-
83-
await self._initialize_client(mcp_servers=mcp_servers)
85+
await self._initialize_client()
8486
continue
8587

8688
self.app.ui_state.set_interrupting(False)
@@ -97,7 +99,17 @@ async def start(self) -> None:
9799
AppEvent(type=AppEventType.RESULT, data=None)
98100
)
99101

100-
async def _initialize_client(self, mcp_servers: dict) -> None:
102+
async def change_model(self, model: str) -> None:
103+
await self.query_queue.put(
104+
ModelChangeCommand(ControlCommand.CHANGE_MODEL, model)
105+
)
106+
107+
async def _initialize_client(self, mcp_servers: dict | None = None) -> None:
108+
if mcp_servers is None:
109+
mcp_servers = {
110+
name: config.model_dump()
111+
for name, config in self.available_servers.items()
112+
}
101113
sdk_config = get_sdk_config(self.config)
102114

103115
sdk_config["mcp_servers"] = mcp_servers

src/agent_chat_cli/core/styles.tcss

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,17 @@ TextArea .text-area--cursor {
108108
padding-left: 2;
109109
}
110110

111-
SlashCommandMenu {
111+
SlashCommandMenu, ModelSelectionMenu {
112112
height: auto;
113113
max-height: 10;
114114
display: none;
115115
}
116116

117-
SlashCommandMenu.visible {
117+
SlashCommandMenu.visible, ModelSelectionMenu.visible {
118118
display: block;
119119
}
120120

121-
SlashCommandMenu OptionList {
121+
SlashCommandMenu OptionList, ModelSelectionMenu OptionList {
122122
height: auto;
123123
max-height: 10;
124124
border: solid $primary;

src/agent_chat_cli/utils/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from enum import Enum
2+
from typing import NamedTuple
23

34

45
class AppEventType(Enum):
@@ -20,10 +21,16 @@ class ContentType(Enum):
2021

2122
class ControlCommand(Enum):
2223
NEW_CONVERSATION = "new_conversation"
24+
CHANGE_MODEL = "change_model"
2325
EXIT = "exit"
2426
CLEAR = "clear"
2527

2628

29+
class ModelChangeCommand(NamedTuple):
30+
command: ControlCommand
31+
model: str
32+
33+
2734
class Key(Enum):
2835
ENTER = "enter"
2936
ESCAPE = "escape"

0 commit comments

Comments
 (0)