diff --git a/Cargo.lock b/Cargo.lock index af4f853433..34d000363b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1714,9 +1714,9 @@ dependencies = [ [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", @@ -3093,7 +3093,7 @@ dependencies = [ "num-traits", "parking_lot", "pin-project", - "portable-pty", + "portable-pty 0.8.1", "predicates", "radix_trie", "regex", @@ -3128,7 +3128,7 @@ dependencies = [ "fig_proto", "fig_remote_ipc", "fig_util", - "portable-pty", + "portable-pty 0.8.1", "tempfile", "tokio", "uuid", @@ -6189,6 +6189,27 @@ dependencies = [ "winreg 0.10.1", ] +[[package]] +name = "portable-pty" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e" +dependencies = [ + "anyhow", + "bitflags 1.3.2", + "downcast-rs", + "filedescriptor", + "lazy_static", + "libc", + "log", + "nix 0.28.0", + "serial2", + "shared_library", + "shell-words", + "winapi", + "winreg 0.10.1", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -6515,6 +6536,7 @@ dependencies = [ "clap_complete_fig", "color-eyre", "color-print", + "console", "convert_case 0.8.0", "criterion", "crossterm", @@ -6536,6 +6558,7 @@ dependencies = [ "fig_settings", "fig_telemetry", "fig_util", + "filedescriptor", "flume", "futures", "glob", @@ -6552,6 +6575,7 @@ dependencies = [ "owo-colors 4.2.0", "parking_lot", "paste", + "portable-pty 0.9.0", "predicates", "rand 0.9.0", "regex", @@ -7619,6 +7643,17 @@ dependencies = [ "serial-core", ] +[[package]] +name = "serial2" +version = "0.2.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cd0c773455b60177d1abe4c739cbfa316c4f2f0ef37465befcb72e8a15cdd02" +dependencies = [ + "cfg-if", + "libc", + "winapi", +] + [[package]] name = "servo_arc" version = "0.1.1" diff --git a/codebase-summary.md b/codebase-summary.md index f8fbb2a2ca..0389904eaf 100644 --- a/codebase-summary.md +++ b/codebase-summary.md @@ -57,7 +57,7 @@ The chat implementation includes a robust tool system that allows Amazon Q to in 1. **Available Tools**: - `fs_read`: Reads files or lists directories (similar to `cat` or `ls`) - `fs_write`: Creates or modifies files with various operations (create, append, replace) - - `execute_bash`: Executes shell commands in the user's environment + - `execute_shell_commands`: Executes shell commands in the user's environment - `use_aws`: Makes AWS CLI API calls with specified services and operations 2. **Tool Execution Flow**: @@ -68,7 +68,7 @@ The chat implementation includes a robust tool system that allows Amazon Q to in - The conversation continues with the tool results incorporated 3. **Security Considerations**: - - Tools that modify the system (like `fs_write` and `execute_bash`) require user confirmation + - Tools that modify the system (like `fs_write` and `execute_shell_commands`) require user confirmation - The `/acceptall` command can toggle automatic acceptance for the session - Tool responses are limited to prevent excessive output (30KB limit) diff --git a/crates/q_cli/Cargo.toml b/crates/q_cli/Cargo.toml index 7ad77015cd..d22c4f7e42 100644 --- a/crates/q_cli/Cargo.toml +++ b/crates/q_cli/Cargo.toml @@ -35,6 +35,7 @@ clap_complete_fig = "4.4.0" color-eyre = "0.6.2" color-print = "0.3.5" convert_case.workspace = true +console = "0.15.11" crossterm = { version = "0.28.1", features = ["event-stream"] } ctrlc = "3.2.5" dialoguer = { version = "0.11.0", features = ["fuzzy-select"] } @@ -62,6 +63,8 @@ indoc.workspace = true mimalloc.workspace = true owo-colors = "4.2.0" parking_lot.workspace = true +portable-pty = "0.9.0" +filedescriptor = "0.8.3" rand.workspace = true regex.workspace = true rustyline = { version = "15.0.0", features = ["derive"] } diff --git a/crates/q_cli/src/cli/chat/parser.rs b/crates/q_cli/src/cli/chat/parser.rs index dfa287e6b0..78e03ce4e0 100644 --- a/crates/q_cli/src/cli/chat/parser.rs +++ b/crates/q_cli/src/cli/chat/parser.rs @@ -332,7 +332,7 @@ mod tests { async fn test_parse() { let _ = tracing_subscriber::fmt::try_init(); let tool_use_id = "TEST_ID".to_string(); - let tool_name = "execute_bash".to_string(); + let tool_name = "execute_shell_commands".to_string(); let tool_args = serde_json::json!({ "command": "echo hello" }) diff --git a/crates/q_cli/src/cli/chat/tools/execute_bash.rs b/crates/q_cli/src/cli/chat/tools/execute_shell_commands.rs similarity index 52% rename from crates/q_cli/src/cli/chat/tools/execute_bash.rs rename to crates/q_cli/src/cli/chat/tools/execute_shell_commands.rs index 9610d49871..862368d6f0 100644 --- a/crates/q_cli/src/cli/chat/tools/execute_bash.rs +++ b/crates/q_cli/src/cli/chat/tools/execute_shell_commands.rs @@ -1,6 +1,45 @@ use std::collections::VecDeque; -use std::io::Write; -use std::process::Stdio; +use std::io::{ + self, + Write, +}; +use std::os::fd::{ + AsFd, + AsRawFd, + FromRawFd, + RawFd, +}; +use std::path::Path; + +use console::strip_ansi_codes; +use fig_os_shim::Context; +use filedescriptor::FileDescriptor; +use nix::fcntl::{ + FcntlArg, + FdFlag, + OFlag, + fcntl, + open, +}; +use nix::libc; +use nix::pty::{ + Winsize, + grantpt, + posix_openpt, + ptsname, + unlockpt, +}; +use nix::sys::signal::{ + SigHandler, + Signal, + signal, +}; +use nix::sys::stat::Mode; +use portable_pty::unix::close_random_fds; +use tokio::io::unix::AsyncFd; +use tokio::select; +use tokio::sync::mpsc::channel; +nix::ioctl_write_ptr_bad!(ioctl_tiocswinsz, libc::TIOCSWINSZ, Winsize); use crossterm::queue; use crossterm::style::{ @@ -11,10 +50,7 @@ use eyre::{ Context as EyreContext, Result, }; -use fig_os_shim::Context; use serde::Deserialize; -use tokio::io::AsyncBufReadExt; -use tokio::select; use tracing::error; use super::{ @@ -27,11 +63,21 @@ use crate::cli::chat::truncate_safe; const READONLY_COMMANDS: &[&str] = &["ls", "cat", "echo", "pwd", "which", "head", "tail", "find", "grep"]; #[derive(Debug, Clone, Deserialize)] -pub struct ExecuteBash { +pub struct ExecuteShellCommands { pub command: String, } -impl ExecuteBash { +/// Helper function to set the close-on-exec flag for a raw descriptor +fn cloexec(fd: RawFd) -> Result<()> { + let flags = fcntl(fd, FcntlArg::F_GETFD)?; + fcntl( + fd, + FcntlArg::F_SETFD(FdFlag::from_bits_truncate(flags) | FdFlag::FD_CLOEXEC), + )?; + Ok(()) +} + +impl ExecuteShellCommands { pub fn requires_acceptance(&self) -> bool { let Some(args) = shlex::split(&self.command) else { return true; @@ -90,66 +136,164 @@ impl ExecuteBash { } pub async fn invoke(&self, mut updates: impl Write) -> Result { - // We need to maintain a handle on stderr and stdout, but pipe it to the terminal as well - let mut child = tokio::process::Command::new("bash") + // The pseudoterminal must be initialized with O_NONBLOCK since on macOS, the + // it can not be safely set with fcntl() later on. + // https://github.com/pkgw/stund/blob/master/tokio-pty-process/src/lib.rs#L127-L133 + cfg_if::cfg_if! { + if #[cfg(any(target_os = "macos", target_os = "linux"))] { + let oflag = OFlag::O_RDWR | OFlag::O_NONBLOCK; + } else if #[cfg(target_os = "freebsd")] { + let oflag = OFlag::O_RDWR; + } + } + let master_pty = std::sync::Arc::new(posix_openpt(oflag).context("Failed to openpt")?); + + // Allow pseudoterminal pair to be generated + grantpt(&master_pty).context("Failed to grantpt")?; + unlockpt(&master_pty).context("Failed to unlockpt")?; + + // Get the name of the pseudoterminal + // SAFETY: This is done before any threads are spawned, thus it being + // non thread safe is not an issue + let pty_name = { unsafe { ptsname(&master_pty) }? }; + + // This will be the reader + let slave_pty = open(Path::new(&pty_name), OFlag::O_RDWR, Mode::empty())?; + + let winsize = Winsize { + ws_row: 30, + ws_col: 100, + ws_xpixel: 0, + ws_ypixel: 0, + }; + unsafe { ioctl_tiocswinsz(slave_pty, &winsize) }?; + + cloexec(master_pty.as_fd().as_raw_fd())?; + cloexec(slave_pty.as_raw_fd())?; + + let shell: String = std::env::var("SHELL").unwrap_or_else(|_| "bash".to_string()); + + let slave_fd = unsafe { FileDescriptor::from_raw_fd(slave_pty.as_raw_fd()) }; + + let mut base_command = tokio::process::Command::new(&shell); + let command = base_command .arg("-c") + .arg("-l") + .arg("-i") .arg(&self.command) - .stdin(Stdio::inherit()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .wrap_err_with(|| format!("Unable to spawn command '{}'", &self.command))?; + .stdin(slave_fd.as_stdio()?) + .stdout(slave_fd.as_stdio()?) + .stderr(slave_fd.as_stdio()?); + + let pre_exec_fn = move || { + // Clean up a few things before we exec the program + // Clear out any potentially problematic signal + // dispositions that we might have inherited + for signo in [ + Signal::SIGCHLD, + Signal::SIGHUP, + Signal::SIGINT, + Signal::SIGQUIT, + Signal::SIGTERM, + Signal::SIGALRM, + ] { + unsafe { signal(signo, SigHandler::SigDfl) }?; + } - let stdout = child.stdout.take().unwrap(); - let stdout = tokio::io::BufReader::new(stdout); - let mut stdout = stdout.lines(); + // Establish ourselves as a session leader. + nix::unistd::setsid()?; - let stderr = child.stderr.take().unwrap(); - let stderr = tokio::io::BufReader::new(stderr); - let mut stderr = stderr.lines(); + // Clippy wants us to explicitly cast TIOCSCTTY using + // type::from(), but the size and potentially signedness + // are system dependent, which is why we're using `as _`. + // Suppress this lint for this section of code. + { + // Set the pty as the controlling terminal. + // Failure to do this means that delivery of + // SIGWINCH won't happen when we resize the + // terminal, among other undesirable effects. + if unsafe { libc::ioctl(0, libc::TIOCSCTTY as _, 0) == -1 } { + return Err(io::Error::last_os_error()); + } + } + + close_random_fds(); + + Ok(()) + }; + + unsafe { command.pre_exec(pre_exec_fn) }; + + let mut child = command.spawn()?; + + let async_master = AsyncFd::new(master_pty.as_fd().as_raw_fd())?; const LINE_COUNT: usize = 1024; - let mut stdout_buf = VecDeque::with_capacity(LINE_COUNT); - let mut stderr_buf = VecDeque::with_capacity(LINE_COUNT); - let mut stdout_done = false; - let mut stderr_done = false; + let (tx, mut rx) = channel(LINE_COUNT); + let mut buffer = [0u8; LINE_COUNT]; + + tokio::spawn(async move { + loop { + match async_master.readable().await { + Ok(mut guard) => { + let n = match guard.try_io(|inner| { + nix::unistd::read(inner.get_ref().as_raw_fd(), &mut buffer) + .map_err(|e| std::io::Error::from_raw_os_error(e as i32)) + }) { + Ok(Ok(n)) => n, + Ok(Err(e)) => { + print!("{} ", e); + error!(%e, "Read error"); + break; + }, + Err(_) => continue, + }; + + if n == 0 { + break; + } + + let raw_output = &buffer[..n]; + if tx.send(raw_output.to_vec()).await.is_err() { + error!("channel closed"); + break; + } + }, + Err(e) => { + error!(%e, "readable failed"); + break; + }, + } + } + }); + + let mut stdout_lines: VecDeque = VecDeque::with_capacity(LINE_COUNT); + let exit_status = loop { select! { biased; - line = stdout.next_line(), if !stdout_done => match line { - Ok(Some(line)) => { - writeln!(updates, "{line}")?; - if stdout_buf.len() >= LINE_COUNT { - stdout_buf.pop_front(); - } - stdout_buf.push_back(line); - }, - Ok(None) => stdout_done = true, - Err(err) => error!(%err, "Failed to read stdout of child process"), - }, - line = stderr.next_line(), if !stderr_done => match line { - Ok(Some(line)) => { - writeln!(updates, "{line}")?; - if stderr_buf.len() >= LINE_COUNT { - stderr_buf.pop_front(); + Some(line) = rx.recv() => { + updates.write_all(&line)?; + updates.flush()?; + + if let Ok(text) = std::str::from_utf8(&line) { + for subline in text.split_inclusive('\n') { + if stdout_lines.len() >= LINE_COUNT { + stdout_lines.pop_front(); + } + stdout_lines.push_back(strip_ansi_codes(subline).to_string().trim().to_string()); } - stderr_buf.push_back(line); - }, - Ok(None) => stderr_done = true, - Err(err) => error!(%err, "Failed to read stderr of child process"), - }, - exit_status = child.wait() => { - break exit_status; - }, + } + } + status = child.wait() => { + break status; + } }; } .wrap_err_with(|| format!("No exit status for '{}'", &self.command))?; - updates.flush()?; - - let stdout = stdout_buf.into_iter().collect::>().join("\n"); - let stderr = stderr_buf.into_iter().collect::>().join("\n"); + let stdout = stdout_lines.into_iter().collect::(); let output = serde_json::json!({ "exit_status": exit_status.code().unwrap_or(0).to_string(), @@ -162,17 +306,10 @@ impl ExecuteBash { "" } ), - "stderr": format!( - "{}{}", - truncate_safe(&stderr, MAX_TOOL_RESPONSE_SIZE / 3), - if stderr.len() > MAX_TOOL_RESPONSE_SIZE / 3 { - " ... truncated" - } else { - "" - } - ), }); + child.kill().await?; + Ok(InvokeOutput { output: OutputKind::Json(output), }) @@ -206,14 +343,14 @@ mod tests { #[ignore = "todo: fix failing on musl for some reason"] #[tokio::test] - async fn test_execute_bash_tool() { + async fn test_execute_shell_commands_tool() { let mut stdout = std::io::stdout(); // Verifying stdout let v = serde_json::json!({ "command": "echo Hello, world!", }); - let out = serde_json::from_value::(v) + let out = serde_json::from_value::(v) .unwrap() .invoke(&mut stdout) .await @@ -222,25 +359,6 @@ mod tests { if let OutputKind::Json(json) = out.output { assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); assert_eq!(json.get("stdout").unwrap(), "Hello, world!"); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - - // Verifying stderr - let v = serde_json::json!({ - "command": "echo Hello, world! 1>&2", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), "Hello, world!"); } else { panic!("Expected JSON output"); } @@ -250,7 +368,7 @@ mod tests { "command": "exit 1", "interactive": false }); - let out = serde_json::from_value::(v) + let out = serde_json::from_value::(v) .unwrap() .invoke(&mut stdout) .await @@ -258,7 +376,6 @@ mod tests { if let OutputKind::Json(json) = out.output { assert_eq!(json.get("exit_status").unwrap(), &1.to_string()); assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), ""); } else { panic!("Expected JSON output"); } @@ -304,7 +421,7 @@ mod tests { ("find important-dir/ -name '*.txt'", false), ]; for (cmd, expected) in cmds { - let tool = serde_json::from_value::(serde_json::json!({ + let tool = serde_json::from_value::(serde_json::json!({ "command": cmd, })) .unwrap(); diff --git a/crates/q_cli/src/cli/chat/tools/mod.rs b/crates/q_cli/src/cli/chat/tools/mod.rs index 6c004c990b..98e8cf7e6e 100644 --- a/crates/q_cli/src/cli/chat/tools/mod.rs +++ b/crates/q_cli/src/cli/chat/tools/mod.rs @@ -1,4 +1,4 @@ -pub mod execute_bash; +pub mod execute_shell_commands; pub mod fs_read; pub mod fs_write; pub mod gh_issue; @@ -15,7 +15,7 @@ use aws_smithy_types::{ Document, Number as SmithyNumber, }; -use execute_bash::ExecuteBash; +use execute_shell_commands::ExecuteShellCommands; use eyre::Result; use fig_api_client::model::{ ToolResult, @@ -38,7 +38,7 @@ pub const MAX_TOOL_RESPONSE_SIZE: usize = 800000; pub enum Tool { FsRead(FsRead), FsWrite(FsWrite), - ExecuteBash(ExecuteBash), + ExecuteShellCommands(ExecuteShellCommands), UseAws(UseAws), GhIssue(GhIssue), } @@ -49,7 +49,7 @@ impl Tool { match self { Tool::FsRead(_) => "Read from filesystem", Tool::FsWrite(_) => "Write to filesystem", - Tool::ExecuteBash(_) => "Execute shell command", + Tool::ExecuteShellCommands(_) => "Execute shell command", Tool::UseAws(_) => "Use AWS CLI", Tool::GhIssue(_) => "Prepare GitHub issue", } @@ -60,7 +60,9 @@ impl Tool { match self { Tool::FsRead(_) => "Reading from filesystem", Tool::FsWrite(_) => "Writing to filesystem", - Tool::ExecuteBash(execute_bash) => return format!("Executing `{}`", execute_bash.command), + Tool::ExecuteShellCommands(execute_shell_commands) => { + return format!("Executing `{}`", execute_shell_commands.command); + }, Tool::UseAws(_) => "Using AWS CLI", Tool::GhIssue(_) => "Preparing GitHub issue", } @@ -72,7 +74,7 @@ impl Tool { match self { Tool::FsRead(_) => false, Tool::FsWrite(_) => true, - Tool::ExecuteBash(execute_bash) => execute_bash.requires_acceptance(), + Tool::ExecuteShellCommands(execute_shell_commands) => execute_shell_commands.requires_acceptance(), Tool::UseAws(use_aws) => use_aws.requires_acceptance(), Tool::GhIssue(_) => false, } @@ -83,7 +85,7 @@ impl Tool { match self { Tool::FsRead(fs_read) => fs_read.invoke(context, updates).await, Tool::FsWrite(fs_write) => fs_write.invoke(context, updates).await, - Tool::ExecuteBash(execute_bash) => execute_bash.invoke(updates).await, + Tool::ExecuteShellCommands(execute_shell_commands) => execute_shell_commands.invoke(updates).await, Tool::UseAws(use_aws) => use_aws.invoke(context, updates).await, Tool::GhIssue(gh_issue) => gh_issue.invoke(updates).await, } @@ -94,7 +96,7 @@ impl Tool { match self { Tool::FsRead(fs_read) => fs_read.queue_description(ctx, updates).await, Tool::FsWrite(fs_write) => fs_write.queue_description(ctx, updates), - Tool::ExecuteBash(execute_bash) => execute_bash.queue_description(updates), + Tool::ExecuteShellCommands(execute_shell_commands) => execute_shell_commands.queue_description(updates), Tool::UseAws(use_aws) => use_aws.queue_description(updates), Tool::GhIssue(gh_issue) => gh_issue.queue_description(updates), } @@ -105,7 +107,7 @@ impl Tool { match self { Tool::FsRead(fs_read) => fs_read.validate(ctx).await, Tool::FsWrite(fs_write) => fs_write.validate(ctx).await, - Tool::ExecuteBash(execute_bash) => execute_bash.validate(ctx).await, + Tool::ExecuteShellCommands(execute_shell_commands) => execute_shell_commands.validate(ctx).await, Tool::UseAws(use_aws) => use_aws.validate(ctx).await, Tool::GhIssue(gh_issue) => gh_issue.validate(ctx).await, } @@ -127,7 +129,9 @@ impl TryFrom for Tool { Ok(match value.name.as_str() { "fs_read" => Self::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), "fs_write" => Self::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), - "execute_bash" => Self::ExecuteBash(serde_json::from_value::(value.args).map_err(map_err)?), + "execute_shell_commands" => { + Self::ExecuteShellCommands(serde_json::from_value::(value.args).map_err(map_err)?) + }, "use_aws" => Self::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), "report_issue" => Self::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), unknown => { diff --git a/crates/q_cli/src/cli/chat/tools/tool_index.json b/crates/q_cli/src/cli/chat/tools/tool_index.json index 397d856cfa..a1e6b46b43 100644 --- a/crates/q_cli/src/cli/chat/tools/tool_index.json +++ b/crates/q_cli/src/cli/chat/tools/tool_index.json @@ -1,13 +1,13 @@ { - "execute_bash": { - "name": "execute_bash", - "description": "Execute the specified bash command.", + "execute_shell_commands": { + "name": "execute_shell_commands", + "description": "Execute the specified shell command.", "input_schema": { "type": "object", "properties": { "command": { "type": "string", - "description": "Bash command to execute" + "description": "Shell command to execute" } }, "required": [