Skip to content

Commit 4f77f67

Browse files
committed
uses rustyline for managed input
1 parent 6de023d commit 4f77f67

File tree

5 files changed

+112
-86
lines changed

5 files changed

+112
-86
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/chat-cli-ui/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ tokio.workspace = true
2323
eyre.workspace = true
2424
tokio-util.workspace = true
2525
futures.workspace = true
26+
rustyline.workspace = true
2627
ratatui = "0.29.0"
2728

2829
[target.'cfg(unix)'.dependencies]

crates/chat-cli-ui/src/conduit.rs

Lines changed: 79 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@ use crossterm::{
1010
execute,
1111
queue,
1212
};
13+
use rustyline::DefaultEditor;
14+
use tracing::error;
1315

1416
use crate::legacy_ui_util::ThemeSource;
1517
use crate::protocol::{
1618
Event,
1719
InputEvent,
1820
LegacyPassThroughOutput,
21+
MetaEvent,
1922
ToolCallRejection,
2023
ToolCallStart,
2124
};
@@ -58,19 +61,25 @@ impl ViewEnd {
5861
mut stderr: std::io::Stderr,
5962
mut stdout: std::io::Stdout,
6063
) -> Result<(), ConduitError> {
64+
#[derive(Debug)]
6165
enum IncomingEvent {
6266
Input(String),
6367
Interrupt,
64-
Send,
65-
Backspace,
6668
}
6769

68-
#[derive(Default)]
70+
#[derive(Default, Debug)]
71+
struct PromptSignal {
72+
active_agent: Option<String>,
73+
trust_all: bool,
74+
}
75+
76+
#[derive(Default, Debug)]
6977
enum DisplayState {
70-
#[default]
7178
Prompting,
7279
UserInsertingText,
7380
StreamingOutput,
81+
#[default]
82+
Hidden,
7483
}
7584

7685
#[inline]
@@ -79,9 +88,8 @@ impl ViewEnd {
7988
stderr: &mut std::io::Stderr,
8089
stdout: &mut std::io::Stdout,
8190
theme_source: &impl ThemeSource,
82-
) -> Result<DisplayState, ConduitError> {
83-
let mut display_state = DisplayState::default();
84-
91+
display_state: Option<&mut DisplayState>,
92+
) -> Result<(), ConduitError> {
8593
match event {
8694
Event::LegacyPassThrough(content) => match content {
8795
LegacyPassThroughOutput::Stderr(content) => {
@@ -99,25 +107,29 @@ impl ViewEnd {
99107
Event::StepStarted(_step_started) => {},
100108
Event::StepFinished(_step_finished) => {},
101109
Event::TextMessageStart(_text_message_start) => {
102-
display_state = DisplayState::StreamingOutput;
110+
if let Some(display_state) = display_state {
111+
*display_state = DisplayState::StreamingOutput;
112+
}
103113

104114
queue!(stdout, theme_source.success_fg(), Print("> "), theme_source.reset(),)?;
105115
},
106116
Event::TextMessageContent(text_message_content) => {
107-
display_state = DisplayState::StreamingOutput;
117+
if let Some(display_state) = display_state {
118+
*display_state = DisplayState::StreamingOutput;
119+
}
108120

109121
stdout.write_all(&text_message_content.delta)?;
110122
stdout.flush()?;
111123
},
112124
Event::TextMessageEnd(_text_message_end) => {
113-
display_state = DisplayState::Prompting;
114-
115125
queue!(stderr, theme_source.reset(), theme_source.reset_attributes())?;
116126
execute!(stdout, style::Print("\n"))?;
117127
},
118128
Event::TextMessageChunk(_text_message_chunk) => {},
119129
Event::ToolCallStart(tool_call_start) => {
120-
display_state = DisplayState::StreamingOutput;
130+
if let Some(display_state) = display_state {
131+
*display_state = DisplayState::StreamingOutput;
132+
}
121133

122134
let ToolCallStart {
123135
tool_call_name,
@@ -167,12 +179,8 @@ impl ViewEnd {
167179
execute!(stdout, style::Print(tool_call_args.delta))?;
168180
}
169181
},
170-
Event::ToolCallEnd(_tool_call_end) => {
171-
// noop for now
172-
},
173-
Event::ToolCallResult(_tool_call_result) => {
174-
// noop for now (currently we don't show the tool call results to users)
175-
},
182+
Event::ToolCallEnd(_tool_call_end) => {},
183+
Event::ToolCallResult(_tool_call_result) => {},
176184
Event::StateSnapshot(_state_snapshot) => {},
177185
Event::StateDelta(_state_delta) => {},
178186
Event::MessagesSnapshot(_messages_snapshot) => {},
@@ -186,7 +194,17 @@ impl ViewEnd {
186194
Event::ReasoningMessageEnd(_reasoning_message_end) => {},
187195
Event::ReasoningMessageChunk(_reasoning_message_chunk) => {},
188196
Event::ReasoningEnd(_reasoning_end) => {},
189-
Event::MetaEvent(_meta_event) => {},
197+
Event::MetaEvent(MetaEvent { meta_type, payload }) => {
198+
if meta_type.as_str() == "timing" {
199+
if let serde_json::Value::String(s) = payload {
200+
if s.as_str() == "prompt_user" {
201+
if let Some(display_state) = display_state {
202+
*display_state = DisplayState::Prompting;
203+
}
204+
}
205+
}
206+
}
207+
},
190208
Event::ToolCallRejection(tool_call_rejection) => {
191209
let ToolCallRejection { reason, name, .. } = tool_call_rejection;
192210

@@ -205,105 +223,83 @@ impl ViewEnd {
205223
},
206224
}
207225

208-
Ok::<DisplayState, ConduitError>(display_state)
226+
Ok::<(), ConduitError>(())
209227
}
210228

211229
if handle_input {
212-
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<IncomingEvent>();
230+
let (incoming_events_tx, mut incoming_events_rx) = tokio::sync::mpsc::unbounded_channel::<IncomingEvent>();
231+
let (prompt_signal_tx, prompt_signal_rx) = std::sync::mpsc::channel::<PromptSignal>();
213232

214233
tokio::task::spawn_blocking(move || {
215-
loop {
216-
if let Ok(event) = crossterm::event::read() {
217-
match event {
218-
crossterm::event::Event::Key(key_event) => {
219-
let crossterm::event::KeyEvent { code, modifiers, .. } = key_event;
220-
221-
match (modifiers, code) {
222-
(crossterm::event::KeyModifiers::CONTROL, crossterm::event::KeyCode::Char('c')) => {
223-
_ = tx.send(IncomingEvent::Interrupt);
224-
},
225-
(_, crossterm::event::KeyCode::Char(input_char)) => {
226-
_ = tx.send(IncomingEvent::Input(input_char.to_string()));
227-
},
228-
(_, crossterm::event::KeyCode::Enter) => {
229-
_ = tx.send(IncomingEvent::Send);
230-
},
231-
(_, crossterm::event::KeyCode::Backspace) => {
232-
_ = tx.send(IncomingEvent::Backspace);
233-
},
234-
// TODO: make a handler for clearing the entire line
235-
(_, _) => {},
236-
}
237-
},
238-
crossterm::event::Event::Paste(content) => {
239-
let _ = tx.send(IncomingEvent::Input(content));
240-
},
241-
_ => {},
242-
}
243-
}
234+
while let Ok(prompt_signal) = prompt_signal_rx.recv() {
235+
let PromptSignal {
236+
active_agent: _,
237+
trust_all: _,
238+
} = prompt_signal;
239+
240+
// TODO: Actually utilize the info to spawn readline here
241+
let prompt = "> ";
242+
let mut rl = DefaultEditor::new().expect("Failed to spawn readline");
243+
244+
// std::thread::sleep(std::time::Duration::from_millis(5000));
245+
246+
match rl.readline(prompt) {
247+
Ok(input) => {
248+
_ = incoming_events_tx.send(IncomingEvent::Input(input));
249+
},
250+
Err(rustyline::error::ReadlineError::Interrupted) => {
251+
_ = incoming_events_tx.send(IncomingEvent::Interrupt);
252+
},
253+
Err(e) => panic!("Failed to spawn readline: {:?}", e),
254+
};
255+
256+
drop(rl);
244257
}
245258
});
246259

247260
tokio::spawn(async move {
248261
let mut display_state = DisplayState::default();
249-
let mut outgoing_buf = String::new();
262+
250263
loop {
251264
if matches!(display_state, DisplayState::Prompting) {
252-
_ = execute!(
253-
stderr,
254-
style::SetAttribute(style::Attribute::Bold),
255-
Print(">"),
256-
style::SetAttribute(style::Attribute::Reset),
257-
Print(" ")
258-
);
265+
tracing::info!("## ui: prompting sent");
266+
// TODO: fetch prompt related info from session and send it here
267+
if let Err(e) = prompt_signal_tx.send(Default::default()) {
268+
error!("Error sending prompt signal: {:?}", e);
269+
}
259270
display_state = DisplayState::UserInsertingText;
260271
}
261272

262273
tokio::select! {
263-
Some(incoming_event) = rx.recv() => {
274+
Some(incoming_event) = incoming_events_rx.recv() => {
264275
match display_state {
265-
DisplayState::Prompting | DisplayState::UserInsertingText => {
276+
DisplayState::UserInsertingText => {
266277
match incoming_event {
267278
IncomingEvent::Input(content) => {
268-
outgoing_buf.push_str(&content);
269-
_ = execute!(
270-
stderr,
271-
crossterm::cursor::MoveRight(content.len() as u16)
272-
);
279+
if let Err(e) = self.sender.send(InputEvent::Text(content)).await {
280+
error!("Error sending input event: {:?}", e);
281+
}
282+
display_state = DisplayState::StreamingOutput;
273283
},
274-
IncomingEvent::Send => {
275-
_ = self.sender.send(InputEvent::Text(outgoing_buf.clone()));
276-
outgoing_buf.clear();
277-
}
278284
IncomingEvent::Interrupt => {
279285
// If user is still inputting text, the session does
280286
// not need to be notified that they are hitting
281287
// control c.
282-
outgoing_buf.clear();
283288
display_state = DisplayState::default();
284289
},
285-
IncomingEvent::Backspace => {
286-
if outgoing_buf.pop().is_some() {
287-
_ = execute!(
288-
stderr,
289-
crossterm::cursor::MoveLeft(1),
290-
);
291-
}
292-
}
293290
}
294291
},
295292
DisplayState::StreamingOutput if matches!(incoming_event, IncomingEvent::Interrupt)=> {
296-
_ = self.sender.send(InputEvent::Interrupt);
297-
display_state = DisplayState::Prompting;
293+
_ = self.sender.send(InputEvent::Interrupt).await;
298294
},
299-
DisplayState::StreamingOutput => {
295+
DisplayState::Hidden | DisplayState::StreamingOutput | DisplayState::Prompting => {
300296
// We ignore everything that's not a sigint here
301297
}
302298
}
303299
},
304300
session_event = self.receiver.recv() => {
305301
if let Some(event) = session_event {
306-
display_state = handle_session_event_legacy_mode(event, &mut stderr, &mut stdout, &theme_source)?;
302+
handle_session_event_legacy_mode(event, &mut stderr, &mut stdout, &theme_source, Some(&mut display_state))?;
307303
} else {
308304
break;
309305
}
@@ -316,7 +312,7 @@ impl ViewEnd {
316312
} else {
317313
tokio::spawn(async move {
318314
while let Some(event) = self.receiver.recv().await {
319-
_ = handle_session_event_legacy_mode(event, &mut stderr, &mut stdout, &theme_source)?;
315+
handle_session_event_legacy_mode(event, &mut stderr, &mut stdout, &theme_source, None)?;
320316
}
321317

322318
Ok::<(), ConduitError>(())

crates/chat-cli/src/cli/chat/managed_input.rs

Whitespace-only changes.

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ mod prompt_parser;
2121
pub mod server_messenger;
2222
use crate::cli::chat::checkpoint::CHECKPOINT_MESSAGE_MAX_LENGTH;
2323
use crate::constants::ui_text;
24+
mod managed_input;
2425
#[cfg(unix)]
2526
mod skim_integration;
2627
mod token_counter;
@@ -54,6 +55,7 @@ use chat_cli_ui::conduit::{
5455
};
5556
use chat_cli_ui::protocol::{
5657
Event,
58+
InputEvent,
5759
MessageRole,
5860
TextMessageContent,
5961
TextMessageEnd,
@@ -605,6 +607,7 @@ pub struct ChatSession {
605607
inner: Option<ChatState>,
606608
ctrlc_rx: broadcast::Receiver<()>,
607609
wrap: Option<WrapMode>,
610+
managed_input: Option<tokio::sync::mpsc::Receiver<InputEvent>>,
608611
}
609612

610613
impl ChatSession {
@@ -628,12 +631,12 @@ impl ChatSession {
628631
let mut existing_conversation = false;
629632

630633
let should_send_structured_msg = should_send_structured_message(os);
631-
let (view_end, _byte_receiver, mut control_end_stderr, control_end_stdout) =
634+
let (view_end, managed_input, mut control_end_stderr, control_end_stdout) =
632635
get_legacy_conduits(should_send_structured_msg);
633636

634637
let stderr = std::io::stderr();
635638
let stdout = std::io::stdout();
636-
if let Err(e) = view_end.into_legacy_mode(false, StyledText, stderr, stdout) {
639+
if let Err(e) = view_end.into_legacy_mode(true, StyledText, stderr, stdout) {
637640
error!("Conduit view end legacy mode exited: {:?}", e);
638641
}
639642

@@ -737,6 +740,7 @@ impl ChatSession {
737740
inner: Some(ChatState::default()),
738741
ctrlc_rx,
739742
wrap,
743+
managed_input: Some(managed_input),
740744
})
741745
}
742746

@@ -1940,7 +1944,17 @@ impl ChatSession {
19401944

19411945
execute!(self.stderr, StyledText::reset(), StyledText::reset_attributes())?;
19421946
let prompt = self.generate_tool_trust_prompt(os).await;
1943-
let user_input = match self.read_user_input(&prompt, false) {
1947+
let user_input = if self.managed_input.is_some() {
1948+
self.stderr.send(Event::MetaEvent(chat_cli_ui::protocol::MetaEvent {
1949+
meta_type: "timing".to_string(),
1950+
payload: serde_json::Value::String("prompt_user".to_string()),
1951+
}))?;
1952+
self.read_user_input_managed().await
1953+
} else {
1954+
self.read_user_input(&prompt, false)
1955+
};
1956+
debug!("## ui: User input: {:?}", user_input);
1957+
let user_input = match user_input {
19441958
Some(input) => input,
19451959
None => return Ok(ChatState::Exit),
19461960
};
@@ -1970,6 +1984,20 @@ impl ChatSession {
19701984
Ok(ChatState::HandleInput { input: user_input })
19711985
}
19721986

1987+
async fn read_user_input_managed(&mut self) -> Option<String> {
1988+
if let Some(managed_input) = &mut self.managed_input {
1989+
if let Some(content) = managed_input.recv().await {
1990+
if let InputEvent::Text(content) = content {
1991+
return Some(content);
1992+
} else {
1993+
return None;
1994+
}
1995+
}
1996+
}
1997+
1998+
None
1999+
}
2000+
19732001
async fn handle_input(&mut self, os: &mut Os, mut user_input: String) -> Result<ChatState, ChatError> {
19742002
queue!(self.stderr, style::Print('\n'))?;
19752003
user_input = sanitize_unicode_tags(&user_input);

0 commit comments

Comments
 (0)