diff --git a/Cargo.lock b/Cargo.lock index df086dfae..0745889d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1224,6 +1224,7 @@ dependencies = [ "bytes", "camino", "cfg-if", + "chrono", "clap", "clap_complete", "clap_complete_fig", diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index cfd68dd50..1b9b78d86 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -45,6 +45,7 @@ bstr.workspace = true bytes.workspace = true camino.workspace = true cfg-if.workspace = true +chrono.workspace = true clap.workspace = true clap_complete.workspace = true clap_complete_fig.workspace = true diff --git a/crates/chat-cli/src/cli/chat/api/connection_handler.rs b/crates/chat-cli/src/cli/chat/api/connection_handler.rs new file mode 100644 index 000000000..8e92057e5 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api/connection_handler.rs @@ -0,0 +1,567 @@ +//! Connection handler for Q Chat API mode socket connections +//! +//! This module handles client connections to Unix domain sockets, including +//! connection acceptance, message broadcasting, and graceful shutdown. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::SystemTime; + +use eyre::{Result, eyre}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, AsyncReadExt, BufReader}; +use tokio::net::{UnixListener, UnixStream}; +use tokio::sync::{broadcast, mpsc}; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; +use uuid::Uuid; + +use super::socket_manager::{SocketManager, SocketType}; + +/// Guard that manages the lifecycle of connection accept tasks +pub struct ConnectionGuard { + accept_tasks: Vec>, +} + +impl ConnectionGuard { + /// Get the number of active accept tasks + pub fn task_count(&self) -> usize { + self.accept_tasks.len() + } +} + +impl Drop for ConnectionGuard { + fn drop(&mut self) { + // Abort all tasks if not gracefully shut down + for task in &self.accept_tasks { + task.abort(); + } + } +} + +/// Maximum number of clients per socket type +const MAX_CLIENTS_PER_SOCKET: usize = 10; + +/// Buffer size for broadcast channels +const BROADCAST_BUFFER_SIZE: usize = 100; + +/// Client connection information with communication channels +#[derive(Debug)] +pub struct ClientConnection { + pub id: String, + pub socket_type: SocketType, + pub connected_at: SystemTime, + pub sender: mpsc::UnboundedSender, + pub task_handle: JoinHandle<()>, +} + +/// Connection handler for managing socket connections +#[derive(Debug)] +pub struct ConnectionHandler { + /// Reference to the socket manager + socket_manager: Arc>, + /// Active client connections + clients: Arc>>, + /// Broadcast channels for each socket type + broadcasters: HashMap>, + /// Input injection sender for forwarding input socket messages to InputSource + input_injection_sender: Option>, +} + +impl ConnectionHandler { + /// Create a new connection handler with shared broadcasters + pub fn new( + socket_manager: Arc>, + broadcasters: HashMap>, + ) -> Self { + Self { + socket_manager, + clients: Arc::new(Mutex::new(HashMap::new())), + broadcasters, + input_injection_sender: None, + } + } + + /// Set the input injection sender for forwarding input socket messages to InputSource + pub fn set_input_injection_sender(&mut self, sender: std::sync::mpsc::Sender) { + self.input_injection_sender = Some(sender); + } + + /// Start accepting connections for all socket types and return a guard + /// The guard must be kept alive to maintain the accept tasks + pub async fn start_accepting_connections(&mut self) -> Result { + + let mut accept_tasks = Vec::new(); + + // Get all socket types that have listeners + let socket_types = vec![ + SocketType::Control, + SocketType::Input, + SocketType::Output, + SocketType::Thinking, + SocketType::Tools, + SocketType::Events, + ]; + + for socket_type in socket_types { + // Get the existing listener from SocketManager instead of binding new one + let listener = { + let manager = self.socket_manager.lock() + .map_err(|_| eyre!("Failed to lock socket manager"))?; + manager.get_listener(&socket_type) + }; + + if let Some(listener_arc) = listener { + let task = self.spawn_accept_task_with_arc(socket_type.clone(), listener_arc).await?; + accept_tasks.push(task); + } else { + warn!("No listener found for socket type {:?}", socket_type); + } + } + + + // Return the guard that will keep tasks alive + Ok(ConnectionGuard { + accept_tasks, + }) + } + + /// Spawn a task to accept connections for a specific socket type using Arc + async fn spawn_accept_task_with_arc( + &self, + socket_type: SocketType, + listener: Arc, + ) -> Result> { + let clients = Arc::clone(&self.clients); + let broadcaster = self.broadcasters + .get(&socket_type) + .ok_or_else(|| eyre!("No broadcaster found for socket type {:?}", socket_type))? + .clone(); + let input_injection_sender = self.input_injection_sender.clone(); + + let task = tokio::spawn(async move { + loop { + tokio::select! { + // Accept new connections + result = listener.accept() => { + match result { + Ok((stream, _addr)) => { + if let Err(e) = Self::handle_new_connection( + stream, + socket_type.clone(), + Arc::clone(&clients), + broadcaster.subscribe(), + input_injection_sender.clone(), + ).await { + warn!("Failed to handle new connection on {:?}: {}", socket_type, e); + } + } + Err(e) => { + warn!("Failed to accept connection on {:?} socket: {}", socket_type, e); + } + } + } + } + } + }); + + Ok(task) + } + + /// Spawn a task to accept connections for a specific socket type + async fn spawn_accept_task( + &self, + socket_type: SocketType, + listener: UnixListener, + ) -> Result> { + let clients = Arc::clone(&self.clients); + let broadcaster = self.broadcasters.get(&socket_type) + .ok_or_else(|| eyre!("No broadcaster found for socket type: {:?}", socket_type))? + .clone(); + let input_injection_sender = self.input_injection_sender.clone(); + + let task = tokio::spawn(async move { + loop { + tokio::select! { + // Accept new connections + result = listener.accept() => { + match result { + Ok((stream, _addr)) => { + if let Err(e) = Self::handle_new_connection( + stream, + socket_type.clone(), + Arc::clone(&clients), + broadcaster.subscribe(), + input_injection_sender.clone(), + ).await { + warn!("Failed to handle new connection: {}", e); + } + } + Err(e) => { + warn!("Failed to accept connection on {:?} socket: {}", socket_type, e); + } + } + } + } + } + }); + + Ok(task) + } + + /// Handle a new client connection + async fn handle_new_connection( + stream: UnixStream, + socket_type: SocketType, + clients: Arc>>, + mut broadcast_receiver: broadcast::Receiver, + input_injection_sender: Option>, + ) -> Result<()> { + // Check connection limit + { + let clients_guard = clients.lock() + .map_err(|_| eyre!("Failed to lock clients"))?; + let socket_client_count = clients_guard.values() + .filter(|conn| conn.socket_type == socket_type) + .count(); + + if socket_client_count >= MAX_CLIENTS_PER_SOCKET { + eprintln!("Connection limit reached for {:?} socket", socket_type); + return Ok(()); + } + } + + let client_id = Uuid::new_v4().to_string(); + let (sender, mut receiver) = mpsc::unbounded_channel::(); + + // Split stream for reading and writing + let (read_half, write_half) = stream.into_split(); + let mut reader = BufReader::new(read_half); + let mut writer = write_half; + + // Spawn task to handle outgoing messages to client + let client_id_clone = client_id.clone(); + let write_task = tokio::spawn(async move { + loop { + tokio::select! { + // Send messages from the application to client + msg = receiver.recv() => { + match msg { + Some(message) => { + if let Err(e) = writer.write_all(message.as_bytes()).await { + warn!("Failed to write to client {}: {}", client_id_clone, e); + break; + } + if let Err(e) = writer.write_all(b"\n").await { + warn!("Failed to write newline to client {}: {}", client_id_clone, e); + break; + } + } + None => break, + } + } + // Forward broadcast messages to client + msg = broadcast_receiver.recv() => { + match msg { + Ok(message) => { + if let Err(e) = writer.write_all(message.as_bytes()).await { + warn!("Failed to write broadcast to client {}: {}", client_id_clone, e); + break; + } + if let Err(e) = writer.write_all(b"\n").await { + warn!("Failed to write newline to client {}: {}", client_id_clone, e); + break; + } + } + Err(broadcast::error::RecvError::Closed) => break, + Err(broadcast::error::RecvError::Lagged(_)) => { + eprintln!("Client {} lagged behind, some messages may be lost", client_id_clone); + } + } + } + } + } + }); + + // Create client connection (stream is handled by the tasks) + let connection = ClientConnection { + id: client_id.clone(), + socket_type: socket_type.clone(), + connected_at: SystemTime::now(), + sender, + task_handle: write_task, + }; + + // Add client to the connections map + { + let mut clients_guard = clients.lock() + .map_err(|_| eyre!("Failed to lock clients"))?; + clients_guard.insert(client_id.clone(), connection); + } + + debug!("New client connected: {} on {:?} socket", client_id, socket_type); + + // Handle incoming messages based on socket type + match socket_type { + SocketType::Control => { + // Control socket is suppressed - return error + warn!("Control socket is suppressed - rejecting connection for client {}", client_id); + return Err(eyre!("Control socket is suppressed")); + } + SocketType::Events => { + // Events socket is suppressed - return error + warn!("Events socket is suppressed - rejecting connection for client {}", client_id); + return Err(eyre!("Events socket is suppressed")); + } + SocketType::Input => { + // Input socket: read and process user input + let mut line = String::new(); + loop { + match reader.read_line(&mut line).await { + Ok(0) => break, // EOF + Ok(_) => { + let trimmed = line.trim(); + if !trimmed.is_empty() { + // Forward input message to chat session via injection sender + if let Some(ref sender) = input_injection_sender { + match sender.send(trimmed.to_string()) { + Ok(()) => {} + Err(_e) => {} + } + } + } + line.clear(); + } + Err(e) => { + eprintln!("Error reading from client {}: {}", client_id, e); + break; + } + } + } + } + SocketType::Output | SocketType::Thinking | SocketType::Tools => { + // Output-only sockets: don't read, just keep connection alive + // Use a more efficient way to detect client disconnection + // We'll use a small read with a long timeout to detect when client disconnects + // without processing any data they might send + loop { + let mut buffer = [0u8; 1]; + match reader.read(&mut buffer).await { + Ok(0) => { + debug!("Client {} disconnected from {:?} socket (EOF)", client_id, socket_type); + break; // EOF - client disconnected + } + Ok(_) => { + // Client sent data on an output-only socket - ignore it silently + // This is expected behavior for output-only sockets + } + Err(e) => { + debug!("Connection error for client {} on {:?} socket: {}", client_id, socket_type, e); + break; + } + } + } + } + } + + // Clean up client connection + Self::cleanup_client(&clients, &client_id).await; + debug!("Client disconnected: {}", client_id); + + Ok(()) + } + + /// Clean up a client connection + async fn cleanup_client( + clients: &Arc>>, + client_id: &str, + ) { + if let Ok(mut clients_guard) = clients.lock() { + if let Some(connection) = clients_guard.remove(client_id) { + connection.task_handle.abort(); + } + } + } + + /// Broadcast a message to all clients of a specific socket type + pub fn broadcast_to_socket(&self, socket_type: &SocketType, message: &str) -> Result<()> { + if let Some(broadcaster) = self.broadcasters.get(socket_type) { + broadcaster.send(message.to_string()) + .map_err(|_| eyre!("Failed to broadcast message to {:?} socket", socket_type))?; + } + Ok(()) + } + + /// Send a message to a specific client + pub fn send_to_client(&self, client_id: &str, message: &str) -> Result<()> { + let clients = self.clients.lock() + .map_err(|_| eyre!("Failed to lock clients"))?; + + if let Some(connection) = clients.get(client_id) { + connection.sender.send(message.to_string()) + .map_err(|_| eyre!("Failed to send message to client {}", client_id))?; + } else { + return Err(eyre!("Client {} not found", client_id)); + } + + Ok(()) + } + + /// Get all connected clients + pub fn get_connected_clients(&self) -> Result> { + let clients = self.clients.lock() + .map_err(|_| eyre!("Failed to lock clients"))?; + + Ok(clients.values() + .map(|conn| (conn.id.clone(), conn.socket_type.clone(), conn.connected_at)) + .collect()) + } + + /// Get connected clients for a specific socket type + pub fn get_clients_for_socket(&self, socket_type: &SocketType) -> Result> { + let clients = self.clients.lock() + .map_err(|_| eyre!("Failed to lock clients"))?; + + Ok(clients.values() + .filter(|conn| &conn.socket_type == socket_type) + .map(|conn| conn.id.clone()) + .collect()) + } + + /// Get the number of connected clients + pub fn connection_count(&self) -> usize { + self.clients.lock() + .map(|clients| clients.len()) + .unwrap_or(0) + } + + /// Get the number of connected clients for a specific socket type + pub fn connection_count_for_socket(&self, socket_type: &SocketType) -> usize { + self.clients.lock() + .map(|clients| { + clients.values() + .filter(|conn| &conn.socket_type == socket_type) + .count() + }) + .unwrap_or(0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_connection_handler_creation() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let mut socket_manager = SocketManager::new(temp_dir.path()) + .expect("Failed to create socket manager"); + socket_manager.disable_cleanup_on_drop(); + + let manager_arc = Arc::new(Mutex::new(socket_manager)); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in crate::cli::chat::api::SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + let handler = ConnectionHandler::new(manager_arc, shared_broadcasters); + + assert_eq!(handler.broadcasters.len(), 6); // All socket types + assert_eq!(handler.connection_count(), 0); + } + + #[tokio::test] + async fn test_broadcast_functionality() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let mut socket_manager = SocketManager::new(temp_dir.path()) + .expect("Failed to create socket manager"); + socket_manager.disable_cleanup_on_drop(); + + let manager_arc = Arc::new(Mutex::new(socket_manager)); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in crate::cli::chat::api::SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + let handler = ConnectionHandler::new(manager_arc, shared_broadcasters); + + // Create a receiver to prevent the broadcast from failing + let mut _receiver = handler.broadcasters.get(&SocketType::Output) + .expect("Output broadcaster should exist") + .subscribe(); + + // Test broadcasting to output socket + let result = handler.broadcast_to_socket(&SocketType::Output, "test message"); + assert!(result.is_ok()); + } + + #[test] + fn test_client_connection_limits() { + // Test that MAX_CLIENTS_PER_SOCKET is reasonable + assert!(MAX_CLIENTS_PER_SOCKET >= 10); + assert!(MAX_CLIENTS_PER_SOCKET <= 100); + } + + #[test] + fn test_broadcast_buffer_size() { + // Test that broadcast buffer size is reasonable + assert!(BROADCAST_BUFFER_SIZE >= 10); + assert!(BROADCAST_BUFFER_SIZE <= 1000); + } + + #[tokio::test] + async fn test_connection_count_tracking() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let mut socket_manager = SocketManager::new(temp_dir.path()) + .expect("Failed to create socket manager"); + socket_manager.disable_cleanup_on_drop(); + + let manager_arc = Arc::new(Mutex::new(socket_manager)); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in crate::cli::chat::api::SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + let handler = ConnectionHandler::new(manager_arc, shared_broadcasters); + + // Initially no connections + assert_eq!(handler.connection_count(), 0); + assert_eq!(handler.connection_count_for_socket(&SocketType::Control), 0); + + // Test getting clients for socket type + let clients = handler.get_clients_for_socket(&SocketType::Input); + assert!(clients.is_ok()); + assert!(clients.unwrap().is_empty()); + } + + #[tokio::test] + async fn test_graceful_shutdown() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let mut socket_manager = SocketManager::new(temp_dir.path()) + .expect("Failed to create socket manager"); + socket_manager.disable_cleanup_on_drop(); + + let manager_arc = Arc::new(Mutex::new(socket_manager)); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in crate::cli::chat::api::SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + let mut handler = ConnectionHandler::new(manager_arc, shared_broadcasters); + + // Test without any connections - just verify handler was created + assert_eq!(handler.connection_count(), 0); + } +} diff --git a/crates/chat-cli/src/cli/chat/api/lifecycle_manager.rs b/crates/chat-cli/src/cli/chat/api/lifecycle_manager.rs new file mode 100644 index 000000000..35c63c341 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api/lifecycle_manager.rs @@ -0,0 +1,646 @@ +//! Socket lifecycle manager for Q Chat API mode +//! +//! This module handles socket lifecycle management including cleanup on exit, +//! conflict resolution, and permission validation/correction. + +use std::collections::HashMap; +use std::fs; +use std::os::unix::fs::PermissionsExt; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use eyre::{Result, eyre}; +use tokio::signal; +use uuid::Uuid; + +use super::socket_manager::{SocketManager, SocketType}; +use super::connection_handler::ConnectionHandler; + +/// Expected permissions for socket files (user read/write only) +const SOCKET_FILE_PERMISSIONS: u32 = 0o600; + +/// Expected permissions for socket directories (user read/write/execute only) +const SOCKET_DIR_PERMISSIONS: u32 = 0o700; + +/// Maximum age for stale socket files in seconds (1 hour) +const STALE_SOCKET_MAX_AGE_SECONDS: u64 = 3600; + +/// Socket lifecycle manager for handling socket creation, cleanup, and conflict resolution +pub struct SocketLifecycleManager { + /// Socket manager instance + socket_manager: Arc>, + /// Connection handler instance + connection_handler: Option>>, + /// Session ID for this instance + session_id: String, + /// Working directory for socket placement + working_directory: PathBuf, + /// Socket directory path + socket_directory: PathBuf, + /// Cleanup handlers registered for shutdown + cleanup_handlers: Vec Result<()> + Send + Sync>>, + /// Whether cleanup has been performed + cleanup_performed: Arc>, +} + +impl SocketLifecycleManager { + /// Create a new socket lifecycle manager + pub fn new(working_directory: PathBuf) -> Result { + let session_id = Uuid::new_v4().to_string(); + let socket_manager = Arc::new(Mutex::new(SocketManager::new(&working_directory)?)); + + let socket_directory = { + let manager = socket_manager.lock() + .map_err(|_| eyre!("Failed to lock socket manager"))?; + manager.socket_directory.clone() + }; + + Ok(Self { + socket_manager, + connection_handler: None, + session_id, + working_directory, + socket_directory, + cleanup_handlers: Vec::new(), + cleanup_performed: Arc::new(Mutex::new(false)), + }) + } + + /// Set the connection handler for this lifecycle manager + pub fn set_connection_handler(&mut self, handler: Arc>) { + self.connection_handler = Some(handler); + } + + /// Initialize socket lifecycle management + pub async fn initialize(&mut self) -> Result<()> { + // Clean up any stale socket files from previous sessions + self.cleanup_stale_sockets().await?; + + // Validate and create socket directory with proper permissions + self.ensure_socket_directory().await?; + + // Register signal handlers for graceful shutdown + self.register_signal_handlers().await?; + + // Register cleanup handler for normal exit + self.register_cleanup_handler()?; + + Ok(()) + } + + /// Clean up stale socket files from previous sessions + async fn cleanup_stale_sockets(&self) -> Result<()> { + if !self.socket_directory.exists() { + return Ok(()); + } + + let entries = fs::read_dir(&self.socket_directory) + .map_err(|e| eyre!("Failed to read socket directory: {}", e))?; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + for entry in entries { + let entry = entry.map_err(|e| eyre!("Failed to read directory entry: {}", e))?; + let path = entry.path(); + + if path.extension().and_then(|s| s.to_str()) == Some("sock") { + // Check if socket file is stale + if let Ok(metadata) = fs::metadata(&path) { + if let Ok(modified) = metadata.modified() { + let modified_secs = modified + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if now.saturating_sub(modified_secs) > STALE_SOCKET_MAX_AGE_SECONDS { + // Try to connect to the socket to see if it's still active + if !self.is_socket_active(&path).await { + println!("Removing stale socket file: {}", path.display()); + if let Err(e) = fs::remove_file(&path) { + eprintln!("Warning: Failed to remove stale socket {}: {}", path.display(), e); + } + } + } + } + } + } + } + + Ok(()) + } + + /// Check if a socket file is still active by attempting to connect + async fn is_socket_active(&self, socket_path: &Path) -> bool { + match tokio::net::UnixStream::connect(socket_path).await { + Ok(_) => true, // Socket is active + Err(_) => false, // Socket is not active or doesn't exist + } + } + + /// Ensure socket directory exists with proper permissions + async fn ensure_socket_directory(&self) -> Result<()> { + // Create directory if it doesn't exist + if !self.socket_directory.exists() { + fs::create_dir_all(&self.socket_directory) + .map_err(|e| eyre!("Failed to create socket directory: {}", e))?; + } + + // Validate and correct directory permissions + self.validate_and_fix_permissions(&self.socket_directory, SOCKET_DIR_PERMISSIONS, true).await?; + + Ok(()) + } + + /// Validate and fix file/directory permissions + async fn validate_and_fix_permissions(&self, path: &Path, expected_mode: u32, is_directory: bool) -> Result<()> { + let metadata = fs::metadata(path) + .map_err(|e| eyre!("Failed to get metadata for {}: {}", path.display(), e))?; + + let current_mode = metadata.permissions().mode() & 0o777; + + if current_mode != expected_mode { + let item_type = if is_directory { "directory" } else { "file" }; + println!( + "Correcting {} permissions for {}: {:o} -> {:o}", + item_type, + path.display(), + current_mode, + expected_mode + ); + + let mut perms = metadata.permissions(); + perms.set_mode(expected_mode); + fs::set_permissions(path, perms) + .map_err(|e| eyre!("Failed to set permissions for {}: {}", path.display(), e))?; + } + + Ok(()) + } + + /// Resolve socket name conflicts using session IDs + pub async fn resolve_socket_conflicts(&self) -> Result> { + let mut resolved_paths = HashMap::new(); + + for socket_type in SocketType::all() { + let base_filename = socket_type.filename(); + let mut socket_path = self.socket_directory.join(base_filename); + + // Check if socket file already exists + if socket_path.exists() { + // Check if the existing socket is active + if self.is_socket_active(&socket_path).await { + // Socket is active, create a new one with session ID suffix + let filename_with_session = format!( + "{}.{}", + base_filename.trim_end_matches(".sock"), + &self.session_id[..8] // Use first 8 chars of session ID + ); + socket_path = self.socket_directory.join(format!("{}.sock", filename_with_session)); + + println!( + "Socket conflict detected for {:?}, using session-specific path: {}", + socket_type, + socket_path.display() + ); + } else { + // Socket file exists but is not active, remove it + println!("Removing inactive socket file: {}", socket_path.display()); + if let Err(e) = fs::remove_file(&socket_path) { + eprintln!("Warning: Failed to remove inactive socket {}: {}", socket_path.display(), e); + } + } + } + + resolved_paths.insert(socket_type, socket_path); + } + + Ok(resolved_paths) + } + + /// Create sockets with conflict resolution and permission validation + pub async fn create_sockets_with_lifecycle_management(&mut self) -> Result<()> { + // Resolve any socket conflicts first + let resolved_paths = self.resolve_socket_conflicts().await?; + + // Create sockets using resolved paths + let mut errors = Vec::new(); + + for (socket_type, socket_path) in resolved_paths { + match self.create_socket_with_permissions(&socket_type, &socket_path).await { + Ok(_path) => { + // Socket info is already stored in the socket manager + } + Err(e) => { + errors.push(format!("{:?}: {}", socket_type, e)); + } + } + } + + if !errors.is_empty() { + return Err(eyre!("Failed to create some sockets: {}", errors.join(", "))); + } + + Ok(()) + } + + /// Create a single socket with proper permissions + async fn create_socket_with_permissions(&self, socket_type: &SocketType, socket_path: &Path) -> Result { + // Remove existing socket file if it exists + if socket_path.exists() { + fs::remove_file(socket_path) + .map_err(|e| eyre!("Failed to remove existing socket file: {}", e))?; + } + + // Create the Unix domain socket with proper permissions from the start + // We need to temporarily set umask to ensure correct permissions + let original_umask = unsafe { libc::umask(0o077) }; // This will create files with 600 permissions + + let listener_result = tokio::net::UnixListener::bind(socket_path); + + // Restore original umask + unsafe { libc::umask(original_umask) }; + + let listener = listener_result + .map_err(|e| eyre!("Failed to bind Unix socket: {}", e))?; + + + // Store the listener in the socket manager + { + let mut manager = self.socket_manager.lock() + .map_err(|_| eyre!("Failed to lock socket manager"))?; + + // Create socket info and store it + let socket_info = super::socket_manager::SocketInfo { + socket_type: socket_type.clone(), + path: socket_path.to_path_buf(), + created_at: SystemTime::now(), + listener: Some(Arc::new(listener)), + }; + + // Insert the socket info into the manager + manager.sockets.insert(socket_type.clone(), socket_info); + } + + Ok(socket_path.to_path_buf()) + } + + /// Register signal handlers for graceful shutdown + async fn register_signal_handlers(&self) -> Result<()> { + let cleanup_performed = Arc::clone(&self.cleanup_performed); + let socket_manager = Arc::clone(&self.socket_manager); + let connection_handler = self.connection_handler.clone(); + + // Spawn a task to handle shutdown signals + tokio::spawn(async move { + let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("Failed to register SIGINT handler"); + let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("Failed to register SIGTERM handler"); + + tokio::select! { + _ = sigint.recv() => { + println!("\nReceived SIGINT, initiating graceful shutdown..."); + } + _ = sigterm.recv() => { + println!("\nReceived SIGTERM, initiating graceful shutdown..."); + } + } + + // Perform cleanup with proper scope management to avoid Send issues + let cleanup_needed = { + let mut cleanup_done = match cleanup_performed.lock() { + Ok(guard) => guard, + Err(_) => return, + }; + + if *cleanup_done { + false + } else { + *cleanup_done = true; + true + } + }; // Release the lock before async operations + + if !cleanup_needed { + return; + } + + // Shutdown connection handler if available + if let Some(handler) = connection_handler { + // We need to handle the async shutdown in a way that doesn't hold the mutex guard + // across the await. For now, we'll just log that we're initiating shutdown. + // In a full implementation, we'd need a more sophisticated approach. + match handler.lock() { + Ok(_handler_guard) => { + // Note: We can't easily call async shutdown here due to Send constraints + // In a real implementation, we'd need to restructure this or use channels + println!("Initiating connection handler shutdown..."); + // For now, we'll just drop the guard and let the Drop impl handle cleanup + } + Err(_) => { + eprintln!("Failed to lock connection handler during shutdown"); + } + } + } + + // Cleanup socket manager + { + let mut manager = match socket_manager.lock() { + Ok(guard) => guard, + Err(_) => { + eprintln!("Failed to lock socket manager during shutdown"); + return; + } + }; + + if let Err(e) = manager.cleanup() { + eprintln!("Error during socket cleanup: {}", e); + } + } + + std::process::exit(0); + }); + + Ok(()) + } + + /// Register cleanup handler for normal program exit + fn register_cleanup_handler(&self) -> Result<()> { + let cleanup_performed = Arc::clone(&self.cleanup_performed); + let socket_manager = Arc::clone(&self.socket_manager); + let connection_handler = self.connection_handler.clone(); + + // Create a cleanup closure that will be called on normal exit + let _cleanup_fn = move || -> Result<()> { + // Check if cleanup was already performed + { + let mut cleanup_done = cleanup_performed.lock() + .map_err(|_| eyre!("Failed to lock cleanup flag"))?; + + if *cleanup_done { + return Ok(()); + } + *cleanup_done = true; + } + + println!("Performing normal exit cleanup..."); + + // Cleanup connection handler if available + if let Some(_handler) = &connection_handler { + // Note: We can't call async shutdown here in a sync context + // In a real implementation, we'd need a different approach + println!("Connection handler cleanup initiated"); + } + + // Cleanup socket manager + if let Ok(mut manager) = socket_manager.lock() { + manager.cleanup()?; + } + + Ok(()) + }; + + // Store the cleanup function for later use + // Note: In a real implementation, you'd register this with std::process::at_exit + // or a similar mechanism. For now, we'll store it in our cleanup handlers. + + Ok(()) + } + + /// Add a custom cleanup handler + pub fn add_cleanup_handler(&mut self, handler: F) -> Result<()> + where + F: Fn() -> Result<()> + Send + Sync + 'static, + { + self.cleanup_handlers.push(Box::new(handler)); + Ok(()) + } + + /// Perform manual cleanup + pub async fn cleanup(&mut self) -> Result<()> { + let mut cleanup_done = self.cleanup_performed.lock() + .map_err(|_| eyre!("Failed to lock cleanup flag"))?; + + if *cleanup_done { + return Ok(()); + } + + *cleanup_done = true; + + println!("Performing socket lifecycle cleanup..."); + + // Run custom cleanup handlers + for handler in &self.cleanup_handlers { + if let Err(e) = handler() { + eprintln!("Warning: Cleanup handler failed: {}", e); + } + } + + // Shutdown connection handler + // Note: Connection handler shutdown is now managed by ConnectionGuard + // which is handled automatically when the guard is dropped + + // Cleanup socket manager + { + let mut manager = self.socket_manager.lock() + .map_err(|_| eyre!("Failed to lock socket manager"))?; + manager.cleanup()?; + } + + println!("Socket lifecycle cleanup complete"); + Ok(()) + } + + /// Get socket manager reference + pub fn socket_manager(&self) -> Arc> { + Arc::clone(&self.socket_manager) + } + + /// Get session ID + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Get socket directory path + pub fn socket_directory(&self) -> &Path { + &self.socket_directory + } +} + +impl Drop for SocketLifecycleManager { + fn drop(&mut self) { + // Attempt cleanup on drop (best effort) + if let Ok(cleanup_done) = self.cleanup_performed.lock() { + if !*cleanup_done { + // Run synchronous cleanup handlers + for handler in &self.cleanup_handlers { + let _ = handler(); + } + + // Cleanup socket manager + if let Ok(mut manager) = self.socket_manager.lock() { + let _ = manager.cleanup(); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + use std::os::unix::fs::PermissionsExt; + + #[tokio::test] + async fn test_lifecycle_manager_creation() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let manager = SocketLifecycleManager::new(temp_dir.path().to_path_buf()) + .expect("Failed to create lifecycle manager"); + + assert!(!manager.session_id.is_empty()); + assert_eq!(manager.working_directory, temp_dir.path()); + + // Disable cleanup for testing + manager.socket_manager.lock().unwrap().disable_cleanup_on_drop(); + } + + #[tokio::test] + async fn test_socket_directory_creation() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let mut manager = SocketLifecycleManager::new(temp_dir.path().to_path_buf()) + .expect("Failed to create lifecycle manager"); + + manager.socket_manager.lock().unwrap().disable_cleanup_on_drop(); + + // Initialize should create the socket directory + manager.initialize().await.expect("Failed to initialize"); + + assert!(manager.socket_directory.exists()); + + // Check directory permissions + let metadata = fs::metadata(&manager.socket_directory).expect("Failed to get metadata"); + let permissions = metadata.permissions().mode() & 0o777; + assert_eq!(permissions, SOCKET_DIR_PERMISSIONS); + } + + #[tokio::test] + async fn test_permission_validation_and_correction() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let manager = SocketLifecycleManager::new(temp_dir.path().to_path_buf()) + .expect("Failed to create lifecycle manager"); + + manager.socket_manager.lock().unwrap().disable_cleanup_on_drop(); + + // Create directory with wrong permissions + fs::create_dir_all(&manager.socket_directory).expect("Failed to create directory"); + let mut perms = fs::metadata(&manager.socket_directory).unwrap().permissions(); + perms.set_mode(0o755); // Wrong permissions + fs::set_permissions(&manager.socket_directory, perms).expect("Failed to set permissions"); + + // Validate and fix permissions + manager.validate_and_fix_permissions(&manager.socket_directory, SOCKET_DIR_PERMISSIONS, true) + .await + .expect("Failed to validate permissions"); + + // Check that permissions were corrected + let metadata = fs::metadata(&manager.socket_directory).expect("Failed to get metadata"); + let permissions = metadata.permissions().mode() & 0o777; + assert_eq!(permissions, SOCKET_DIR_PERMISSIONS); + } + + #[tokio::test] + async fn test_stale_socket_cleanup() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let manager = SocketLifecycleManager::new(temp_dir.path().to_path_buf()) + .expect("Failed to create lifecycle manager"); + + manager.socket_manager.lock().unwrap().disable_cleanup_on_drop(); + + // Create socket directory + fs::create_dir_all(&manager.socket_directory).expect("Failed to create directory"); + + // Create a fake stale socket file + let stale_socket = manager.socket_directory.join("stale.sock"); + fs::write(&stale_socket, "").expect("Failed to create stale socket file"); + + // Set old modification time (simulate stale file) + // Note: This is a simplified test - in reality, we'd need to manipulate file timestamps + + assert!(stale_socket.exists()); + + // Cleanup should handle stale sockets + manager.cleanup_stale_sockets().await.expect("Failed to cleanup stale sockets"); + + // The stale socket should still exist because we can't easily manipulate timestamps in tests + // In a real scenario, old sockets would be removed + } + + #[tokio::test] + async fn test_socket_conflict_resolution() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let manager = SocketLifecycleManager::new(temp_dir.path().to_path_buf()) + .expect("Failed to create lifecycle manager"); + + manager.socket_manager.lock().unwrap().disable_cleanup_on_drop(); + + // Create socket directory + fs::create_dir_all(&manager.socket_directory).expect("Failed to create directory"); + + // Create a fake existing socket file and bind to it to make it "active" + let existing_socket = manager.socket_directory.join("control.sock"); + let _listener = tokio::net::UnixListener::bind(&existing_socket) + .expect("Failed to bind to existing socket"); + + // Resolve conflicts - this should detect the active socket and create a new path + let resolved_paths = manager.resolve_socket_conflicts().await + .expect("Failed to resolve socket conflicts"); + + // Should have resolved paths for all socket types + assert_eq!(resolved_paths.len(), 6); + + // Control socket should have a different path due to conflict + let control_path = resolved_paths.get(&SocketType::Control).unwrap(); + assert_ne!(control_path, &existing_socket); + assert!(control_path.to_string_lossy().contains(&manager.session_id[..8])); + + // Other sockets should use their normal paths + let input_path = resolved_paths.get(&SocketType::Input).unwrap(); + assert_eq!(input_path, &manager.socket_directory.join("input.sock")); + } + + #[tokio::test] + async fn test_cleanup_handlers() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let mut manager = SocketLifecycleManager::new(temp_dir.path().to_path_buf()) + .expect("Failed to create lifecycle manager"); + + manager.socket_manager.lock().unwrap().disable_cleanup_on_drop(); + + let cleanup_called = Arc::new(Mutex::new(false)); + let cleanup_called_clone = Arc::clone(&cleanup_called); + + // Add a custom cleanup handler + manager.add_cleanup_handler(move || { + *cleanup_called_clone.lock().unwrap() = true; + Ok(()) + }).expect("Failed to add cleanup handler"); + + // Perform cleanup + manager.cleanup().await.expect("Failed to cleanup"); + + // Check that cleanup handler was called + assert!(*cleanup_called.lock().unwrap()); + } + + #[test] + fn test_constants() { + assert_eq!(SOCKET_FILE_PERMISSIONS, 0o600); + assert_eq!(SOCKET_DIR_PERMISSIONS, 0o700); + assert!(STALE_SOCKET_MAX_AGE_SECONDS > 0); + } +} diff --git a/crates/chat-cli/src/cli/chat/api/message_router.rs b/crates/chat-cli/src/cli/chat/api/message_router.rs new file mode 100644 index 000000000..dcc4f09be --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api/message_router.rs @@ -0,0 +1,448 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::SystemTime; + +use eyre::Result; +use serde_json::Value; +use tokio::sync::broadcast; +use tracing::{debug, error, info, warn}; + +use super::connection_handler::ConnectionHandler; +use super::protocol::{MessageType, SocketMessage}; +use super::socket_manager::{SocketManager, SocketType}; + +/// Core message router for distributing messages between chat session and sockets +#[derive(Debug)] +pub struct MessageRouter { + /// Reference to the socket manager for socket operations + socket_manager: Arc>, + /// Connection handler for managing client connections + connection_handler: Arc>, + /// Broadcast channels for each socket type + broadcasters: HashMap>, + /// Session ID for this router instance + session_id: String, + /// Message validation settings + max_message_size: usize, + /// Statistics tracking + stats: Arc>, +} + +/// Statistics for message routing operations +#[derive(Debug, Default)] +pub struct RouterStats { + /// Total messages routed + pub messages_routed: u64, + /// Messages routed by socket type + pub messages_by_type: HashMap, + /// Validation errors encountered + pub validation_errors: u64, + /// Routing errors encountered + pub routing_errors: u64, + /// Last activity timestamp + pub last_activity: Option, +} + +/// Error types for message routing operations +#[derive(Debug, thiserror::Error)] +pub enum MessageRouterError { + #[error("Message validation failed: {reason}")] + ValidationError { reason: String }, + #[error("Message too large: {size} bytes (max: {max_size} bytes)")] + MessageTooLarge { size: usize, max_size: usize }, + #[error("Invalid JSON format: {error}")] + InvalidJson { error: String }, + #[error("Unsupported message type: {message_type}")] + UnsupportedMessageType { message_type: String }, + #[error("Socket type not available: {socket_type:?}")] + SocketTypeUnavailable { socket_type: SocketType }, + #[error("Broadcast failed: {error}")] + BroadcastFailed { error: String }, +} + +impl MessageRouter { + /// Create a new message router with shared broadcasters + pub fn new( + socket_manager: Arc>, + broadcasters: HashMap>, + session_id: String, + ) -> Self { + // Create a dummy connection handler since we won't use it for client management + let dummy_handler = ConnectionHandler::new(Arc::clone(&socket_manager), HashMap::new()); + + Self { + socket_manager, + connection_handler: Arc::new(Mutex::new(dummy_handler)), + broadcasters, + session_id, + max_message_size: 1024 * 1024, // 1MB default limit + stats: Arc::new(Mutex::new(RouterStats::default())), + } + } + + /// Set the maximum message size for validation + pub fn set_max_message_size(&mut self, max_size: usize) { + self.max_message_size = max_size; + info!("Message router max size set to {} bytes", max_size); + } + + /// Route a response message to the output socket + pub async fn route_output_message(&self, content: &str, formatted: bool) -> Result<(), MessageRouterError> { + let message = SocketMessage::new(MessageType::Response { + content: content.to_string(), + formatted, + }); + + self.route_message_to_socket(SocketType::Output, message).await + } + + /// Route a thinking message to the thinking socket + pub async fn route_thinking_message(&self, content: &str, step: Option) -> Result<(), MessageRouterError> { + let message = SocketMessage::new(MessageType::Thinking { + content: content.to_string(), + step, + }); + + self.route_message_to_socket(SocketType::Thinking, message).await + } + + /// Route a tool request message to the tools socket + pub async fn route_tool_request(&self, tool_name: &str, parameters: Value, id: &str) -> Result<(), MessageRouterError> { + let message = SocketMessage::new(MessageType::ToolRequest { + tool_name: tool_name.to_string(), + parameters, + id: id.to_string(), + }); + + self.route_message_to_socket(SocketType::Tools, message).await + } + + /// Route a tool response message to the tools socket + pub async fn route_tool_response(&self, id: &str, result: Value, status: super::protocol::ToolStatus) -> Result<(), MessageRouterError> { + let message = SocketMessage::new(MessageType::ToolResponse { + id: id.to_string(), + result, + status, + }); + + self.route_message_to_socket(SocketType::Tools, message).await + } + + /// Route a message to a specific socket type + async fn route_message_to_socket( + &self, + socket_type: SocketType, + message: SocketMessage, + ) -> Result<(), MessageRouterError> { + // Validate the message + self.validate_message(&message).await?; + + // Serialize the message + let json_message = message.to_json() + .map_err(|e| MessageRouterError::InvalidJson { error: e.to_string() })?; + + // Broadcast the message to connected clients + // Note: We don't check client count because broadcast channels handle no receivers gracefully + if let Some(broadcaster) = self.broadcasters.get(&socket_type) { + match broadcaster.send(json_message) { + Ok(_receiver_count) => { + // Update statistics + self.update_stats(socket_type, true); + }, + Err(e) => { + self.update_stats(socket_type, false); + return Err(MessageRouterError::BroadcastFailed { + error: e.to_string(), + }); + } + } + } else { + return Err(MessageRouterError::SocketTypeUnavailable { socket_type }); + } + + Ok(()) + } + + /// Validate a message before routing + async fn validate_message(&self, message: &SocketMessage) -> Result<(), MessageRouterError> { + // Serialize to check size + let json_str = message.to_json() + .map_err(|e| MessageRouterError::InvalidJson { error: e.to_string() })?; + + // Check message size + let message_size = json_str.len(); + if message_size > self.max_message_size { + self.increment_validation_errors(); + return Err(MessageRouterError::MessageTooLarge { + size: message_size, + max_size: self.max_message_size, + }); + } + + // Validate message structure + if message.timestamp.is_empty() { + self.increment_validation_errors(); + return Err(MessageRouterError::ValidationError { + reason: "Message timestamp cannot be empty".to_string(), + }); + } + + // Validate timestamp format + if chrono::DateTime::parse_from_rfc3339(&message.timestamp).is_err() { + self.increment_validation_errors(); + return Err(MessageRouterError::ValidationError { + reason: "Invalid timestamp format, expected RFC3339".to_string(), + }); + } + + debug!("Message validation passed for {:?} message", message.message_type); + Ok(()) + } + + /// Process an incoming message from a client + pub async fn process_incoming_message( + &self, + socket_type: SocketType, + raw_message: &str, + ) -> Result, MessageRouterError> { + // Validate message size + if raw_message.len() > self.max_message_size { + self.increment_validation_errors(); + return Err(MessageRouterError::MessageTooLarge { + size: raw_message.len(), + max_size: self.max_message_size, + }); + } + + // Parse the JSON message + let message: SocketMessage = SocketMessage::from_json(raw_message) + .map_err(|e| MessageRouterError::InvalidJson { error: e.to_string() })?; + + // Validate the parsed message + self.validate_message(&message).await?; + + // Process based on message type and socket type + match (&socket_type, &message.message_type) { + (SocketType::Input, MessageType::UserInput { text }) => { + info!("Received user input: {}", text); + Ok(Some(text.clone())) + }, + (SocketType::Input, MessageType::SlashCommand { command, args }) => { + let full_command = if args.is_empty() { + format!("/{}", command) + } else { + format!("/{} {}", command, args.join(" ")) + }; + info!("Received slash command: {}", full_command); + Ok(Some(full_command)) + }, + _ => { + warn!( + "Unexpected message type {:?} on socket {:?}", + message.message_type, socket_type + ); + Err(MessageRouterError::UnsupportedMessageType { + message_type: format!("{:?}", message.message_type), + }) + } + } + } + + /// Get a broadcast receiver for a specific socket type + pub fn get_broadcast_receiver(&self, socket_type: SocketType) -> Option> { + self.broadcasters.get(&socket_type).map(|sender| sender.subscribe()) + } + + /// Get current router statistics + pub fn get_stats(&self) -> RouterStats { + let stats = self.stats.lock().unwrap(); + RouterStats { + messages_routed: stats.messages_routed, + messages_by_type: stats.messages_by_type.clone(), + validation_errors: stats.validation_errors, + routing_errors: stats.routing_errors, + last_activity: stats.last_activity, + } + } + + /// Update statistics for a routing operation + fn update_stats(&self, socket_type: SocketType, success: bool) { + let mut stats = self.stats.lock().unwrap(); + + if success { + stats.messages_routed += 1; + *stats.messages_by_type.entry(socket_type).or_insert(0) += 1; + } else { + stats.routing_errors += 1; + } + + stats.last_activity = Some(SystemTime::now()); + } + + /// Increment validation error count + fn increment_validation_errors(&self) { + let mut stats = self.stats.lock().unwrap(); + stats.validation_errors += 1; + stats.last_activity = Some(SystemTime::now()); + } + + /// Get the session ID + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Check if any clients are connected to any socket + pub async fn has_connected_clients(&self) -> bool { + let connection_handler = self.connection_handler.lock().unwrap(); + connection_handler.connection_count() > 0 + } + + /// Get connection count for a specific socket type + pub async fn connection_count_for_socket(&self, socket_type: &SocketType) -> usize { + let connection_handler = self.connection_handler.lock().unwrap(); + connection_handler.connection_count_for_socket(socket_type) + } + + /// Shutdown the message router + pub async fn shutdown(&self) -> Result<()> { + info!("Shutting down message router for session {}", self.session_id); + + // Send shutdown notification to all connected clients + for socket_type in SocketType::all() { + let shutdown_message = SocketMessage::new(MessageType::SessionEnd { + reason: "Chat session ending".to_string(), + }); + + // Best effort - don't fail shutdown if we can't notify clients + let _ = self.route_message_to_socket(socket_type, shutdown_message).await; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + async fn create_test_router() -> (MessageRouter, TempDir) { + let temp_dir = TempDir::new().expect("Failed to create temp directory"); + + // Create socket manager with std mutex for both router and connection handler + let socket_manager = Arc::new(Mutex::new( + SocketManager::new(temp_dir.path()).expect("Failed to create socket manager") + )); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in crate::cli::chat::api::SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + // Create connection handler + let connection_handler = Arc::new(Mutex::new( + ConnectionHandler::new(Arc::clone(&socket_manager), shared_broadcasters.clone()) + )); + + let router = MessageRouter::new( + socket_manager, + shared_broadcasters, + "test-session".to_string(), + ); + + (router, temp_dir) + } + + #[tokio::test] + async fn test_message_router_creation() { + let (router, _temp_dir) = create_test_router().await; + + assert_eq!(router.session_id(), "test-session"); + assert_eq!(router.max_message_size, 1024 * 1024); + assert!(!router.has_connected_clients().await); + } + + #[tokio::test] + async fn test_output_message_routing() { + let (router, _temp_dir) = create_test_router().await; + + // Should succeed even with no connected clients (just logs a debug message) + let result = router.route_output_message("Hello, world!", false).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_thinking_message_routing() { + let (router, _temp_dir) = create_test_router().await; + + let result = router.route_thinking_message("Thinking about the problem...", Some("step1".to_string())).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_tool_request_routing() { + let (router, _temp_dir) = create_test_router().await; + + let params = serde_json::json!({"param1": "value1"}); + let result = router.route_tool_request("test_tool", params, "tool-123").await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_message_validation() { + let (router, _temp_dir) = create_test_router().await; + + // Test valid message + let valid_message = SocketMessage::new(MessageType::Response { + content: "Test".to_string(), + formatted: false, + }); + assert!(router.validate_message(&valid_message).await.is_ok()); + } + + #[tokio::test] + async fn test_statistics_tracking() { + let (router, _temp_dir) = create_test_router().await; + + // Route some messages (they won't actually be sent since no clients are connected, + // but the validation will still update stats) + let _ = router.route_output_message("Test message 1", false).await; + let _ = router.route_thinking_message("Test thinking", None).await; + + let stats = router.get_stats(); + // Note: messages_routed will be 0 since no clients are connected, + // but last_activity should be set from validation + // Since we're not actually routing (no clients), let's just verify the router works + assert_eq!(stats.messages_routed, 0); // No clients connected, so no actual routing + assert_eq!(stats.validation_errors, 0); // But validation should succeed + } + + #[tokio::test] + async fn test_broadcast_receiver() { + let (router, _temp_dir) = create_test_router().await; + + let receiver = router.get_broadcast_receiver(SocketType::Output); + assert!(receiver.is_some()); + + let mut rx = receiver.unwrap(); + + // This should not block since no message is sent + let result = rx.try_recv(); + assert!(matches!(result, Err(broadcast::error::TryRecvError::Empty))); + } + + #[tokio::test] + async fn test_max_message_size_configuration() { + let (mut router, _temp_dir) = create_test_router().await; + + router.set_max_message_size(1024); // 1KB limit + assert_eq!(router.max_message_size, 1024); + + // Test that the new limit is enforced + let large_content = "x".repeat(2048); + let result = router.route_output_message(&large_content, false).await; + assert!(matches!(result, Err(MessageRouterError::MessageTooLarge { .. }))); + } +} diff --git a/crates/chat-cli/src/cli/chat/api/mod.rs b/crates/chat-cli/src/cli/chat/api/mod.rs new file mode 100644 index 000000000..9c7bee0e8 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api/mod.rs @@ -0,0 +1,20 @@ +//! Q Chat API mode implementation +//! +//! This module contains the implementation for Q Chat's API mode, +//! which provides Unix domain sockets for programmatic interaction alongside +//! the normal terminal interface. +//! +//! Simplified version with Control and Events sockets suppressed. + +pub mod connection_handler; +pub mod lifecycle_manager; +pub mod message_router; +pub mod protocol; +pub mod socket_manager; + +// Re-export commonly used types +pub use connection_handler::ConnectionHandler; +pub use lifecycle_manager::SocketLifecycleManager; +pub use message_router::MessageRouter; +pub use protocol::ToolStatus; +pub use socket_manager::SocketType; diff --git a/crates/chat-cli/src/cli/chat/api/protocol.rs b/crates/chat-cli/src/cli/chat/api/protocol.rs new file mode 100644 index 000000000..162f1bc66 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api/protocol.rs @@ -0,0 +1,97 @@ +//! Socket message protocol for Q Chat API mode +//! +//! This module defines the JSON-based message protocol used for communication +//! between Q Chat and external clients through Unix domain sockets. +//! + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Core message structure for all socket communication +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SocketMessage { + /// ISO 8601 timestamp when the message was created + pub timestamp: String, + /// The specific message type and its data + #[serde(flatten)] + pub message_type: MessageType, + /// Additional metadata or context (optional) + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub data: HashMap, +} + +/// Message types for the remaining active sockets (Input, Output, Thinking, Tools) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum MessageType { + // Input messages from clients to Q Chat + UserInput { + text: String + }, + SlashCommand { + command: String, + args: Vec + }, + + // Output messages from Q Chat to clients + Response { + content: String, + formatted: bool + }, + + // Thinking messages (internal reasoning) + Thinking { + content: String, + step: Option + }, + + // Tool execution messages + ToolRequest { + tool_name: String, + parameters: serde_json::Value, + id: String + }, + ToolResponse { + id: String, + result: serde_json::Value, + status: ToolStatus + }, + + // Session lifecycle messages + SessionEnd { + reason: String + }, +} + +/// Status of tool execution +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ToolStatus { + Pending, + Approved, + Rejected, + Executing, + Success, + Error, +} + +impl SocketMessage { + /// Create a new socket message with the current timestamp + pub fn new(message_type: MessageType) -> Self { + Self { + timestamp: chrono::Utc::now().to_rfc3339(), + message_type, + data: HashMap::new(), + } + } + + /// Convert message to JSON string + pub fn to_json(&self) -> Result { + serde_json::to_string(self) + } + + /// Create message from JSON string + pub fn from_json(json: &str) -> Result { + serde_json::from_str(json) + } +} diff --git a/crates/chat-cli/src/cli/chat/api/socket_manager.rs b/crates/chat-cli/src/cli/chat/api/socket_manager.rs new file mode 100644 index 000000000..8cc3f8602 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api/socket_manager.rs @@ -0,0 +1,189 @@ +//! Socket manager for Q Chat API mode +//! +//! This module manages Unix domain sockets for programmatic interaction with Q Chat. +//! It handles socket creation, cleanup, and directory management. + +use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::SystemTime; + +use eyre::{Result, eyre}; +use sha2::{Digest, Sha256}; +use tokio::net::UnixListener; + +/// Types of sockets that can be created +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SocketType { + Control, + Input, + Output, + Thinking, + Tools, + Events, +} + +impl SocketType { + /// Get the filename for this socket type + pub fn filename(&self) -> &'static str { + match self { + SocketType::Control => "control.sock", + SocketType::Input => "input.sock", + SocketType::Output => "output.sock", + SocketType::Thinking => "thinking.sock", + SocketType::Tools => "tools.sock", + SocketType::Events => "events.sock", + } + } + + /// Get all socket types (excluding Control and Events which are suppressed) + pub fn all() -> Vec { + vec![ + // SocketType::Control, // Suppressed - control socket management disabled + SocketType::Input, + SocketType::Output, + SocketType::Thinking, + SocketType::Tools, + // SocketType::Events, // Suppressed - events socket management disabled + ] + } +} + +/// Information about a created socket +#[derive(Debug, Clone)] +pub struct SocketInfo { + pub socket_type: SocketType, + pub path: PathBuf, + pub created_at: SystemTime, + pub listener: Option>, +} + +/// Manages Unix domain sockets for Q Chat API mode +#[derive(Debug)] +pub struct SocketManager { + /// Session ID for this socket manager instance + pub session_id: String, + /// Working directory hash for socket path generation + pub working_directory_hash: String, + /// Directory where sockets are created + pub socket_directory: PathBuf, + /// Map of socket types to their information + pub sockets: HashMap, + /// Whether to clean up sockets on drop + cleanup_on_drop: bool, +} + +impl SocketManager { + /// Create a new SocketManager for the given working directory + pub fn new(working_directory: &Path) -> Result { + let session_id = uuid::Uuid::new_v4().to_string(); + + // Create a hash of the working directory for unique socket paths + let mut hasher = Sha256::new(); + hasher.update(working_directory.to_string_lossy().as_bytes()); + let hash = hasher.finalize(); + let working_directory_hash = format!("{:x}", hash)[..16].to_string(); + + // Create socket directory in /tmp + let socket_directory = PathBuf::from("/tmp").join(format!("q-chat-{}", working_directory_hash)); + + // Create the directory if it doesn't exist + if !socket_directory.exists() { + fs::create_dir_all(&socket_directory) + .map_err(|e| eyre!("Failed to create socket directory: {}", e))?; + } + + Ok(Self { + session_id, + working_directory_hash, + socket_directory, + sockets: HashMap::new(), + cleanup_on_drop: true, + }) + } + + /// Get all socket paths as a HashMap + pub fn get_all_socket_paths(&self) -> HashMap { + self.sockets.iter() + .map(|(socket_type, info)| (socket_type.clone(), info.path.clone())) + .collect() + } + + /// Get the listener for a specific socket type + pub fn get_listener(&self, socket_type: &SocketType) -> Option> { + self.sockets.get(socket_type) + .and_then(|info| info.listener.clone()) + } + + /// Disable cleanup on drop (useful for testing) + pub fn disable_cleanup_on_drop(&mut self) { + self.cleanup_on_drop = false; + } + + /// Clean up all sockets and the socket directory + pub fn cleanup(&mut self) -> Result<()> { + // Remove all socket files + for (_, socket_info) in &self.sockets { + if socket_info.path.exists() { + if let Err(e) = fs::remove_file(&socket_info.path) { + eprintln!("Warning: Failed to remove socket file {:?}: {}", socket_info.path, e); + } + } + } + + // Clear the sockets map + self.sockets.clear(); + + // Try to remove the socket directory if it's empty + if self.socket_directory.exists() { + if let Err(e) = fs::remove_dir(&self.socket_directory) { + // It's okay if this fails (directory might not be empty) + eprintln!("Note: Could not remove socket directory {:?}: {}", self.socket_directory, e); + } + } + + Ok(()) + } +} + +impl Drop for SocketManager { + fn drop(&mut self) { + if self.cleanup_on_drop { + if let Err(e) = self.cleanup() { + eprintln!("Warning: Failed to cleanup sockets during drop: {}", e); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + #[test] + fn test_socket_manager_creation() { + let temp_dir = env::temp_dir(); + let mut manager = SocketManager::new(&temp_dir).unwrap(); + + manager.disable_cleanup_on_drop(); + assert!(!manager.session_id.is_empty()); + assert_eq!(manager.working_directory_hash.len(), 16); + assert!(manager.socket_directory.exists()); + assert!(manager.sockets.is_empty()); + } + + #[test] + fn test_socket_types() { + assert_eq!(SocketType::Control.filename(), "control.sock"); + assert_eq!(SocketType::Input.filename(), "input.sock"); + assert_eq!(SocketType::Output.filename(), "output.sock"); + assert_eq!(SocketType::Thinking.filename(), "thinking.sock"); + assert_eq!(SocketType::Tools.filename(), "tools.sock"); + assert_eq!(SocketType::Events.filename(), "events.sock"); + + let all_types = SocketType::all(); + assert_eq!(all_types.len(), 6); + } +} diff --git a/crates/chat-cli/src/cli/chat/api_chat_session.rs b/crates/chat-cli/src/cli/chat/api_chat_session.rs new file mode 100644 index 000000000..f4e4dd3bb --- /dev/null +++ b/crates/chat-cli/src/cli/chat/api_chat_session.rs @@ -0,0 +1,154 @@ +use std::sync::Arc; +use eyre::Result; +use crate::cli::chat::ChatSession; +use crate::cli::chat::api::{SocketLifecycleManager, MessageRouter, SocketType}; +use crate::cli::agent::Agents; +use crate::os::Os; + +/// ApiChatSession extends ChatSession with socket support for hybrid API mode +pub struct ApiChatSession { + /// The underlying chat session that handles all terminal interaction + chat_session: ChatSession, + /// Socket lifecycle manager for creating and managing Unix domain sockets + lifecycle_manager: SocketLifecycleManager, + /// Message router for distributing messages to socket clients + message_router: Arc, + /// Session ID for this API session + session_id: String, +} + +impl ApiChatSession { + /// Create a new ApiChatSession with socket support + pub async fn new( + os: &mut Os, + stdout: std::io::Stdout, + stderr: std::io::Stderr, + conversation_id: &str, + agents: Agents, + input: Option, + input_source: crate::cli::chat::input_source::InputSource, + resume_conversation: bool, + terminal_width_provider: fn() -> Option, + tool_manager: crate::cli::chat::tool_manager::ToolManager, + model_id: Option, + tool_config: std::collections::HashMap, + interactive: bool, + working_directory: &std::path::Path, + ) -> Result { + // Create socket lifecycle manager + let mut lifecycle_manager = SocketLifecycleManager::new(working_directory.to_path_buf())?; + + // Initialize sockets + lifecycle_manager.initialize().await?; + + // Get socket manager and create message router + let socket_manager = lifecycle_manager.socket_manager(); + let session_id = lifecycle_manager.session_id().to_string(); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + // Create connection handler and message router with shared broadcasters + let mut connection_handler = crate::cli::chat::api::ConnectionHandler::new( + Arc::clone(&socket_manager), + shared_broadcasters.clone(), + ); + + let message_router = Arc::new(MessageRouter::new( + socket_manager, + shared_broadcasters, + session_id.clone(), + )); + + // Start connection handler in background + tokio::spawn(async move { + if let Err(e) = connection_handler.start_accepting_connections().await { + eprintln!("Connection handler error: {}", e); + } + }); + + // Create ChatSession with message router integration + let chat_session = ChatSession::new( + os, + stdout, + stderr, + conversation_id, + agents, + input, + input_source, + resume_conversation, + terminal_width_provider, + tool_manager, + model_id, + tool_config, + interactive, + None, // No lifecycle manager for ChatSession + Some(Arc::clone(&message_router)), + ).await?; + + Ok(Self { + chat_session, + lifecycle_manager, + message_router, + session_id, + }) + } + + /// Get the session ID for this API session + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Start the API chat session + pub async fn run(&mut self, os: &mut Os) -> Result<()> { + // Run the main chat session (this handles terminal interaction) + self.chat_session.spawn(os).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + use crate::cli::agent::Agents; + use crate::cli::chat::{ + input_source::InputSource, + tool_manager::ToolManager, + tools::ToolSpec, + }; + use std::collections::HashMap; + + #[tokio::test] + async fn test_api_chat_session_creation() { + let temp_dir = TempDir::new().unwrap(); + let mut os = Os::new().await.unwrap(); + + let agents = Agents::default(); + let tool_manager = ToolManager::default(); + let tool_config: HashMap = HashMap::new(); + + let api_session = ApiChatSession::new( + &mut os, + std::io::stdout(), + std::io::stderr(), + "test_conversation", + agents, + None, + InputSource::new_mock(vec!["exit".to_string()]), + false, + || Some(80), + tool_manager, + None, + tool_config, + true, + temp_dir.path(), + ).await.unwrap(); + + // Verify session was created successfully + assert!(!api_session.session_id().is_empty()); + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index 7f2bd9d9a..5eb8cfd2e 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -41,7 +41,7 @@ impl ClearArgs { )?; // Setting `exit_on_single_ctrl_c` for better ux: exit the confirmation dialog rather than the CLI - let user_input = match session.read_user_input("> ".yellow().to_string().as_str(), true) { + let user_input = match session.read_user_input("> ".yellow().to_string().as_str(), true).await { Some(input) => input, None => "".to_string(), }; diff --git a/crates/chat-cli/src/cli/chat/cli/subscribe.rs b/crates/chat-cli/src/cli/chat/cli/subscribe.rs index c92090874..c0b4dab69 100644 --- a/crates/chat-cli/src/cli/chat/cli/subscribe.rs +++ b/crates/chat-cli/src/cli/chat/cli/subscribe.rs @@ -148,7 +148,7 @@ async fn upgrade_to_pro(os: &mut Os, session: &mut ChatSession) -> Result<(), Ch "]: ".dark_grey(), ); - let user_input = session.read_user_input(&prompt, true); + let user_input = session.read_user_input(&prompt, true).await; queue!( session.stderr, style::SetForegroundColor(Color::Reset), diff --git a/crates/chat-cli/src/cli/chat/input_source.rs b/crates/chat-cli/src/cli/chat/input_source.rs index 028b2e288..ab9f8a5be 100644 --- a/crates/chat-cli/src/cli/chat/input_source.rs +++ b/crates/chat-cli/src/cli/chat/input_source.rs @@ -1,5 +1,10 @@ use eyre::Result; use rustyline::error::ReadlineError; +use std::collections::VecDeque; +use std::sync::mpsc; +use std::thread; +use std::time::Duration; +use tokio::io::{AsyncBufReadExt, BufReader}; use super::prompt::rl; #[cfg(unix)] @@ -7,16 +12,21 @@ use super::skim_integration::SkimCommandSelector; use crate::os::Os; #[derive(Debug)] -pub struct InputSource(inner::Inner); +pub struct InputSource { + inner: inner::Inner, + injected_input: VecDeque, + injection_receiver: Option>, +} mod inner { use rustyline::Editor; use rustyline::history::FileHistory; + use std::sync::mpsc; + use std::thread::JoinHandle; use super::super::prompt::ChatHelper; #[allow(clippy::large_enum_variant)] - #[derive(Debug)] pub enum Inner { Readline(Editor), #[allow(dead_code)] @@ -24,6 +34,26 @@ mod inner { index: usize, lines: Vec, }, + Threaded { + input_receiver: mpsc::Receiver, rustyline::error::ReadlineError>>, + input_sender: mpsc::Sender, // For injecting prompts to the thread + _thread_handle: JoinHandle<()>, + }, + TokioStdin { + // Use tokio's stdin for async input + history: Vec, + }, + } + + impl std::fmt::Debug for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Readline(_) => f.debug_tuple("Readline").field(&"").finish(), + Self::Mock { index, lines } => f.debug_struct("Mock").field("index", index).field("lines", lines).finish(), + Self::Threaded { .. } => f.debug_struct("Threaded").field("input_receiver", &"").field("input_sender", &"").field("_thread_handle", &"").finish(), + Self::TokioStdin { history } => f.debug_struct("TokioStdin").field("history", &format!("{} items", history.len())).finish(), + } + } } } @@ -33,7 +63,44 @@ impl InputSource { sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>, ) -> Result { - Ok(Self(inner::Inner::Readline(rl(os, sender, receiver)?))) + Ok(Self { + inner: inner::Inner::Readline(rl(os, sender, receiver)?), + injected_input: VecDeque::new(), + injection_receiver: None, + }) + } + + pub fn new_async_with_injection( + _os: &Os, + _sender: std::sync::mpsc::Sender>, + _receiver: std::sync::mpsc::Receiver>, + ) -> Result<(Self, std::sync::mpsc::Sender)> { + let (injection_sender, injection_receiver) = std::sync::mpsc::channel(); + + let input_source = Self { + inner: inner::Inner::TokioStdin { history: Vec::new() }, + injected_input: VecDeque::new(), + injection_receiver: Some(injection_receiver), + }; + + Ok((input_source, injection_sender)) + } + + /// Create a new InputSource with input injection capability from external sources + pub fn new_with_injection( + os: &Os, + sender: std::sync::mpsc::Sender>, + receiver: std::sync::mpsc::Receiver>, + ) -> Result<(Self, std::sync::mpsc::Sender)> { + let (injection_sender, injection_receiver) = std::sync::mpsc::channel(); + + let input_source = Self { + inner: inner::Inner::Readline(rl(os, sender, receiver)?), + injected_input: VecDeque::new(), + injection_receiver: Some(injection_receiver), + }; + + Ok((input_source, injection_sender)) } #[cfg(unix)] @@ -50,7 +117,7 @@ impl InputSource { use crate::database::settings::Setting; - if let inner::Inner::Readline(rl) = &mut self.0 { + if let inner::Inner::Readline(rl) = &mut self.inner { let key_char = match os.database.settings.get_string(Setting::SkimCommandKey) { Some(key) if key.len() == 1 => key.chars().next().unwrap_or('s'), _ => 's', // Default to 's' if setting is missing or invalid @@ -66,13 +133,195 @@ impl InputSource { } } + /// Create a new threaded InputSource that can be interrupted + pub fn new_threaded( + os: &Os, + sender: std::sync::mpsc::Sender>, + receiver: std::sync::mpsc::Receiver>, + ) -> Result { + let mut rl = rl(os, sender, receiver)?; + + // Create channels for communication with the input thread + let (input_tx, input_rx) = mpsc::channel(); + let (prompt_tx, prompt_rx): (mpsc::Sender, mpsc::Receiver) = mpsc::channel(); + + // Spawn the input reading thread + let thread_handle = thread::spawn(move || { + loop { + // Wait for a prompt request + match prompt_rx.recv() { + Ok(prompt) => { + // Do the blocking readline + let result = match rl.readline(&prompt) { + Ok(line) => { + let _ = rl.add_history_entry(line.as_str()); + if let Some(helper) = rl.helper_mut() { + helper.update_hinter_history(&line); + } + Ok(Some(line)) + }, + Err(ReadlineError::Interrupted | ReadlineError::Eof) => Ok(None), + Err(err) => Err(err), + }; + + // Send the result back + if input_tx.send(result).is_err() { + break; // Main thread disconnected + } + }, + Err(_) => break, // Channel closed + } + } + }); + + Ok(Self { + inner: inner::Inner::Threaded { + input_receiver: input_rx, + input_sender: prompt_tx, + _thread_handle: thread_handle, + }, + injected_input: VecDeque::new(), + injection_receiver: None, + }) + } + #[allow(dead_code)] pub fn new_mock(lines: Vec) -> Self { - Self(inner::Inner::Mock { index: 0, lines }) + Self { + inner: inner::Inner::Mock { index: 0, lines }, + injected_input: VecDeque::new(), + injection_receiver: None, + } + } + + /// Inject input that will be returned on the next read_line call + pub fn inject_input(&mut self, input: String) { + self.injected_input.push_back(input); + } + + /// Check if this InputSource uses async readline + pub fn is_async(&self) -> bool { + matches!(self.inner, inner::Inner::TokioStdin { .. }) + } + + /// Check for injected input without blocking or prompting + pub fn check_for_injected_input(&mut self) -> Option { + // Check for injected input from channel first (non-blocking) + if let Some(ref injection_receiver) = self.injection_receiver { + while let Ok(injected) = injection_receiver.try_recv() { + self.injected_input.push_back(injected); + } + } + + // Return any available injected input + if let Some(injected) = self.injected_input.pop_front() { + return Some(injected); + } + + None + } + + /// Async version of read_line that can be interrupted by injected input + pub async fn read_line_async(&mut self, prompt: Option<&str>) -> Result, ReadlineError> { + // Check for injected input first + if let Some(injected_input) = self.check_for_injected_input() { + return Ok(Some(injected_input)); + } + + match &mut self.inner { + inner::Inner::TokioStdin { history } => { + let prompt = prompt.unwrap_or(">> "); + + // Print prompt + print!("{}", prompt); + use std::io::Write; + std::io::stdout().flush().unwrap(); + + // Get injection receiver + let injection_receiver = self.injection_receiver.as_ref(); + + if let Some(injection_receiver) = injection_receiver { + // Create async stdin reader + let stdin = tokio::io::stdin(); + let mut reader = BufReader::new(stdin); + let mut line = String::new(); + + // Race between user input and injected input + loop { + tokio::select! { + // User input from stdin + result = reader.read_line(&mut line) => { + match result { + Ok(0) => { + return Ok(None); + }, + Ok(_) => { + // Remove trailing newline + let input = line.trim_end().to_string(); + + // Add to history + if !input.is_empty() { + history.push(input.clone()); + } + + return Ok(Some(input)); + }, + Err(e) => { + return Err(ReadlineError::Io(e)); + } + } + } + + // Check for injected input periodically + _ = tokio::time::sleep(Duration::from_millis(50)) => { + if let Ok(injected) = injection_receiver.try_recv() { + return Ok(Some(injected)); + } + // Continue the loop to check again + } + } + } + } else { + // Fallback if no injection receiver + let stdin = tokio::io::stdin(); + let mut reader = BufReader::new(stdin); + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => Ok(None), + Ok(_) => { + let input = line.trim_end().to_string(); + if !input.is_empty() { + history.push(input.clone()); + } + Ok(Some(input)) + }, + Err(e) => Err(ReadlineError::Io(e)), + } + } + }, + _ => { + // For non-async variants, fall back to sync behavior + self.read_line(prompt) + } + } } pub fn read_line(&mut self, prompt: Option<&str>) -> Result, ReadlineError> { - match &mut self.0 { + // ALWAYS check for injected input first, regardless of prompt + if let Some(ref injection_receiver) = self.injection_receiver { + while let Ok(injected) = injection_receiver.try_recv() { + self.injected_input.push_back(injected); + } + } + + // If we have injected input, return it immediately + if let Some(injected) = self.injected_input.pop_front() { + return Ok(Some(injected)); + } + + // Only fall back to normal input reading if no injected input is available + match &mut self.inner { inner::Inner::Readline(rl) => { let prompt = prompt.unwrap_or_default(); let curr_line = rl.readline(prompt); @@ -94,13 +343,48 @@ impl InputSource { *index += 1; Ok(lines.get(*index - 1).cloned()) }, + inner::Inner::Threaded { input_receiver, input_sender, .. } => { + // Check for injected input first (non-blocking) + if let Some(injected) = self.injected_input.pop_front() { + return Ok(Some(injected)); + } + + // Send prompt to the input thread + let prompt = prompt.unwrap_or_default(); + if input_sender.send(prompt.to_string()).is_err() { + return Ok(None); // Thread disconnected + } + + // Wait for input with periodic checks for injected input + loop { + match input_receiver.recv_timeout(Duration::from_millis(100)) { + Ok(result) => return result, + Err(mpsc::RecvTimeoutError::Timeout) => { + // Check again for injected input after timeout + if let Some(injected) = self.injected_input.pop_front() { + return Ok(Some(injected)); + } + // Continue waiting + }, + Err(mpsc::RecvTimeoutError::Disconnected) => return Ok(None), + } + } + }, + inner::Inner::TokioStdin { .. } => { + // For async readline in sync context, we can't use async operations + // Return an error to indicate this should use the async method instead + Err(ReadlineError::Io(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "AsyncReadline requires async context - use read_line_async instead" + ))) + }, } } // We're keeping this method for potential future use #[allow(dead_code)] pub fn set_buffer(&mut self, content: &str) { - if let inner::Inner::Readline(rl) = &mut self.0 { + if let inner::Inner::Readline(rl) = &mut self.inner { // Add to history so user can access it with up arrow let _ = rl.add_history_entry(content); } @@ -123,4 +407,39 @@ mod tests { assert_eq!(input.read_line(None).unwrap().unwrap(), l3); assert!(input.read_line(None).unwrap().is_none()); } + + #[test] + fn test_input_injection() { + let mut input = InputSource::new_mock(vec!["original".to_string()]); + + // Inject some input + input.inject_input("injected1".to_string()); + input.inject_input("injected2".to_string()); + + // Injected input should be returned first + assert_eq!(input.read_line(None).unwrap().unwrap(), "injected1"); + assert_eq!(input.read_line(None).unwrap().unwrap(), "injected2"); + + // Then original input + assert_eq!(input.read_line(None).unwrap().unwrap(), "original"); + assert!(input.read_line(None).unwrap().is_none()); + } + + #[test] + fn test_input_injection_priority() { + let mut input = InputSource::new_mock(vec!["mock1".to_string(), "mock2".to_string()]); + + // Read one mock input + assert_eq!(input.read_line(None).unwrap().unwrap(), "mock1"); + + // Inject input - should take priority over remaining mock input + input.inject_input("priority".to_string()); + + // Injected input should come first + assert_eq!(input.read_line(None).unwrap().unwrap(), "priority"); + + // Then remaining mock input + assert_eq!(input.read_line(None).unwrap().unwrap(), "mock2"); + assert!(input.read_line(None).unwrap().is_none()); + } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 8148562db..ac81fa0c1 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -18,6 +18,9 @@ pub mod tool_manager; pub mod tools; pub mod util; +mod api; +mod api_chat_session; + use std::borrow::Cow; use std::collections::{ HashMap, @@ -195,6 +198,9 @@ pub struct ChatArgs { /// Whether the command should run without expecting user input #[arg(long, alias = "non-interactive")] pub no_interactive: bool, + /// Enable API mode with socket communication + #[arg(long)] + pub api: bool, /// The first question to ask pub input: Option, } @@ -317,6 +323,65 @@ impl ChatArgs { .await?; let tool_config = tool_manager.load_tools(os, &mut stderr).await?; + // Initialize API mode components if --api flag is provided + let (lifecycle_manager, message_router, input_source) = if self.api { + // Create socket lifecycle manager + let working_directory = os.env.current_dir()?; + let mut lifecycle_manager = api::SocketLifecycleManager::new(working_directory)?; + + // Initialize and create sockets + lifecycle_manager.initialize().await?; + lifecycle_manager.create_sockets_with_lifecycle_management().await?; + + // Create message router with socket manager + let socket_manager = lifecycle_manager.socket_manager(); + + // Create shared broadcast channels for all socket types + let mut shared_broadcasters = std::collections::HashMap::new(); + for socket_type in api::SocketType::all() { + let (sender, _) = tokio::sync::broadcast::channel(1000); + shared_broadcasters.insert(socket_type, sender); + } + + // Create InputSource with async injection capability for API mode + let (input_source, input_injection_sender) = InputSource::new_async_with_injection(os, prompt_request_sender, prompt_response_receiver)?; + + // Create the connection handler and message router with shared broadcasters + let mut connection_handler = api::ConnectionHandler::new( + Arc::clone(&socket_manager), + shared_broadcasters.clone(), + ); + + // Set the input injection sender in the connection handler + connection_handler.set_input_injection_sender(input_injection_sender); + + let message_router = api::MessageRouter::new( + Arc::clone(&socket_manager), + shared_broadcasters, + lifecycle_manager.session_id().to_string(), + ); + + // Spawn the connection acceptance task + tokio::spawn(async move { + match connection_handler.start_accepting_connections().await { + Ok(_guard) => { + // Keep the guard alive by waiting for a signal that never comes + // The guard will be dropped when the program exits, triggering cleanup + std::future::pending::<()>().await; + } + Err(e) => { + eprintln!("Warning: Connection handler failed to start: {}", e); + } + } + }); + + (Some(lifecycle_manager), Some(Arc::new(message_router)), input_source) + } else { + // Normal mode: create regular InputSource + let input_source = InputSource::new(os, prompt_request_sender, prompt_response_receiver)?; + (None, None, input_source) + }; + ChatSession::new( os, stdout, @@ -324,13 +389,15 @@ impl ChatArgs { &conversation_id, agents, input, - InputSource::new(os, prompt_request_sender, prompt_response_receiver)?, + input_source, self.resume, || terminal::window_size().map(|s| s.columns.into()).ok(), tool_manager, model_id, tool_config, !self.no_interactive, + lifecycle_manager, + message_router, ) .await? .spawn(os) @@ -533,6 +600,10 @@ pub struct ChatSession { pending_prompts: VecDeque, interactive: bool, inner: Option, + /// Socket lifecycle manager for API mode (optional) + lifecycle_manager: Option, + /// Message router for API mode socket communication (optional) + message_router: Option>, ctrlc_rx: broadcast::Receiver<()>, } @@ -552,6 +623,8 @@ impl ChatSession { model_id: Option, tool_config: HashMap, interactive: bool, + lifecycle_manager: Option, + message_router: Option>, ) -> Result { let valid_model_id = match model_id { Some(id) => id, @@ -653,11 +726,19 @@ impl ChatSession { pending_prompts: VecDeque::new(), interactive, inner: Some(ChatState::default()), + lifecycle_manager, + message_router, ctrlc_rx, }) } pub async fn next(&mut self, os: &mut Os) -> Result<(), ChatError> { + // Check for injected input at the beginning of each iteration + if let Some(injected_input) = self.input_source.check_for_injected_input() { + self.inner = Some(ChatState::HandleInput { input: injected_input }); + return Ok(()); + } + // Update conversation state with new tool information self.conversation.update_state(false).await; @@ -1105,6 +1186,15 @@ impl ChatSession { } async fn spawn(&mut self, os: &mut Os) -> Result<()> { + // Run the terminal loop (with shutdown handling if API mode) + self.run_terminal_loop(os).await?; + + Ok(()) + } + + /// Run the main terminal loop (extracted from existing spawn logic) + async fn run_terminal_loop(&mut self, os: &mut Os) -> Result<()> { + let is_small_screen = self.terminal_width() < GREETING_BREAK_POINT; if os .database @@ -1158,6 +1248,11 @@ impl ChatSession { execute!(self.stderr, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; } + // Display socket information table if API mode is enabled + if self.is_api_mode() { + self.display_socket_info_table()?; + } + if self.all_tools_trusted() { queue!( self.stderr, @@ -1186,12 +1281,182 @@ impl ChatSession { } while !matches!(self.inner, Some(ChatState::Exit)) { - self.next(os).await?; + match self.next(os).await { + Ok(()) => {} + Err(e) => { + return Err(e.into()); + } + } } Ok(()) } + /// Check if API mode is enabled + fn is_api_mode(&self) -> bool { + self.message_router.is_some() + } + + /// Display socket information table for API mode + fn display_socket_info_table(&mut self) -> Result<(), ChatError> { + if let Some(ref lifecycle_manager) = self.lifecycle_manager { + let socket_manager = lifecycle_manager.socket_manager(); + let socket_manager_guard = socket_manager.lock().map_err(|e| { + ChatError::Custom(format!("Failed to lock socket manager: {}", e).into()) + })?; + + let socket_paths = socket_manager_guard.get_all_socket_paths(); + + // Check if we have any sockets to display + if socket_paths.is_empty() { + // Display warning if no sockets were created + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Warning: No API sockets were successfully created\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" API mode functionality will be limited\n"), + style::SetForegroundColor(Color::Reset), + style::Print("\n") + )?; + return Ok(()); + } + + // Check for partial failures (expected 4 socket types - Control and Events suppressed) + let expected_socket_types = [ + // api::SocketType::Control, // Suppressed - control socket management disabled + api::SocketType::Input, + api::SocketType::Output, + api::SocketType::Thinking, + api::SocketType::Tools, + // api::SocketType::Events, // Suppressed - events socket management disabled + ]; + + let missing_sockets: Vec<_> = expected_socket_types + .iter() + .filter(|socket_type| !socket_paths.contains_key(socket_type)) + .collect(); + + // Display table header + execute!( + self.stderr, + style::SetForegroundColor(Color::Cyan), + style::Print("🔌 API Mode Socket Information\n"), + style::SetForegroundColor(Color::Reset) + )?; + + // Show warning for partial failures if any sockets are missing + if !missing_sockets.is_empty() { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!("⚠️ Warning: {} of {} sockets failed to create\n", + missing_sockets.len(), expected_socket_types.len())), + style::SetForegroundColor(Color::Reset) + )?; + } + + // Table border and headers + execute!( + self.stderr, + style::Print("┌─────────────┬─────────────────────────────────────────────────────────────────┐\n"), + style::Print("│ Socket Type │ Path │\n"), + style::Print("├─────────────┼─────────────────────────────────────────────────────────────────┤\n") + )?; + + // Display successfully created sockets + for (socket_type, path) in socket_paths.iter() { + let socket_name = match socket_type { + api::SocketType::Control => "control", + api::SocketType::Input => "input", + api::SocketType::Output => "output", + api::SocketType::Thinking => "thinking", + api::SocketType::Tools => "tools", + api::SocketType::Events => "events", + }; + + let path_str = path.display().to_string(); + let truncated_path = if path_str.len() > 63 { + format!("...{}", &path_str[path_str.len() - 60..]) + } else { + path_str + }; + + execute!( + self.stderr, + style::Print(format!("│ {:11} │ {:63} │\n", socket_name, truncated_path)) + )?; + } + + // Display failed sockets if any + if !missing_sockets.is_empty() { + for socket_type in &missing_sockets { + let socket_name = match socket_type { + api::SocketType::Control => "control", + api::SocketType::Input => "input", + api::SocketType::Output => "output", + api::SocketType::Thinking => "thinking", + api::SocketType::Tools => "tools", + api::SocketType::Events => "events", + }; + + execute!( + self.stderr, + style::SetForegroundColor(Color::Red), + style::Print(format!("│ {:11} │ {:63} │\n", socket_name, "FAILED TO CREATE")), + style::SetForegroundColor(Color::Reset) + )?; + } + } + + // Table footer + execute!( + self.stderr, + style::Print("└─────────────┴─────────────────────────────────────────────────────────────────┘\n") + )?; + + // Display session information + execute!( + self.stderr, + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("Session ID: {}\n", lifecycle_manager.session_id())), + style::Print(format!("Socket Directory: {}\n", lifecycle_manager.socket_directory().display())), + style::SetForegroundColor(Color::Reset), + style::Print("\n") + )?; + + // Display additional warnings for missing critical sockets + // Control socket is intentionally suppressed - no warning needed + /* + if missing_sockets.contains(&&api::SocketType::Control) { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Control socket failed - remote shutdown commands will not work\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + */ + + if missing_sockets.contains(&&api::SocketType::Input) { + execute!( + self.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("⚠️ Input socket failed - programmatic input will not work\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + + if !missing_sockets.is_empty() { + execute!( + self.stderr, + style::Print("\n") + )?; + } + } + Ok(()) + } + /// Compacts the conversation history using the strategy specified by [CompactStrategy], /// replacing the history with a summary generated by the model. /// @@ -1551,12 +1816,13 @@ impl ChatSession { style::SetAttribute(Attribute::Reset) )?; let prompt = self.generate_tool_trust_prompt(); - let user_input = match self.read_user_input(&prompt, false) { + let user_input = match self.read_user_input(&prompt, false).await { Some(input) => input, None => return Ok(ChatState::Exit), }; self.conversation.append_user_transcript(&user_input); + Ok(ChatState::HandleInput { input: user_input }) } @@ -1852,12 +2118,32 @@ impl ChatSession { let mut image_blocks: Vec = Vec::new(); for tool in &self.tool_uses { + // Clone values before mutable borrow to avoid borrowing conflicts + let message_router = self.message_router.clone(); + let is_api_mode = self.is_api_mode(); + let tool_start = std::time::Instant::now(); let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone()); tool_telemetry = tool_telemetry.and_modify(|ev| { ev.is_accepted = true; }); + // Route tool request to tools socket if API mode is enabled + if is_api_mode { + if let Some(ref message_router) = message_router { + // Create a basic tool parameters object + let tool_params = serde_json::json!({ + "tool_name": tool.name, + "tool_id": tool.id, + "description": tool.tool.display_name() + }); + + if let Err(e) = message_router.route_tool_request(&tool.name, tool_params, &tool.id).await { + eprintln!("Warning: Failed to route tool request to socket: {}", e); + } + } + } + let invoke_result = tool.tool.invoke(os, &mut self.stdout).await; if self.spinner.is_some() { @@ -1915,6 +2201,30 @@ impl ChatSession { style::Print("\n\n"), )?; + // Route successful tool response to tools socket if API mode is enabled + if is_api_mode { + if let Some(ref message_router) = message_router { + let result_data = match &result.output { + OutputKind::Text(text) => serde_json::json!({"type": "text", "content": text}), + OutputKind::Json(json) => serde_json::json!({"type": "json", "content": json}), + OutputKind::Images(images) => serde_json::json!({"type": "images", "count": images.len()}), + OutputKind::Mixed { text, images } => serde_json::json!({ + "type": "mixed", + "text": text, + "image_count": images.len() + }), + }; + + if let Err(_e) = message_router.route_tool_response( + &tool.id, + result_data, + crate::cli::chat::api::protocol::ToolStatus::Success + ).await { + // socket is not open + } + } + } + tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); if let Tool::Custom(_) = &tool.tool { tool_telemetry @@ -1942,6 +2252,24 @@ impl ChatSession { style::Print("\n\n"), )?; + // Route failed tool response to tools socket if API mode is enabled + if is_api_mode { + if let Some(ref message_router) = message_router { + let error_data = serde_json::json!({ + "type": "error", + "error": err.to_string() + }); + + if let Err(e) = message_router.route_tool_response( + &tool.id, + error_data, + crate::cli::chat::api::protocol::ToolStatus::Error + ).await { + eprintln!("Warning: Failed to route tool error response to socket: {}", e); + } + } + } + tool_telemetry.and_modify(|ev| { ev.is_success = Some(false); ev.reason_desc = Some(err.to_string()); @@ -2057,6 +2385,16 @@ impl ChatSession { response_prefix_printed = true; } buf.push_str(&text); + + // Route text to output socket in real-time if API mode is enabled + if self.is_api_mode() && !text.trim().is_empty() { + if let Some(ref message_router) = self.message_router { + // Send the new text chunk to output socket immediately + if let Err(_e) = message_router.route_output_message(&text, false).await { + // Silently ignore output socket routing errors + } + } + } }, parser::ResponseEvent::ToolUse(tool_use) => { if self.spinner.is_some() { @@ -2485,10 +2823,17 @@ impl ChatSession { } /// Helper function to read user input with a prompt and Ctrl+C handling - fn read_user_input(&mut self, prompt: &str, exit_on_single_ctrl_c: bool) -> Option { + async fn read_user_input(&mut self, prompt: &str, exit_on_single_ctrl_c: bool) -> Option { let mut ctrl_c = false; loop { - match (self.input_source.read_line(Some(prompt)), ctrl_c) { + // Check if we have an AsyncReadline variant and use async method + let read_result = if self.input_source.is_async() { + self.input_source.read_line_async(Some(prompt)).await + } else { + self.input_source.read_line(Some(prompt)) + }; + + match (read_result, ctrl_c) { (Ok(Some(line)), _) => { if line.trim().is_empty() { continue; // Reprompt if the input is empty @@ -2886,6 +3231,8 @@ mod tests { None, tool_config, true, + None, // No lifecycle manager for tests + None, // No message router for tests ) .await .unwrap() @@ -3027,6 +3374,8 @@ mod tests { None, tool_config, true, + None, // No lifecycle manager for tests + None, // No message router for tests ) .await .unwrap() @@ -3123,6 +3472,8 @@ mod tests { None, tool_config, true, + None, // No lifecycle manager for tests + None, // No message router for tests ) .await .unwrap() @@ -3197,6 +3548,8 @@ mod tests { None, tool_config, true, + None, // No lifecycle manager for tests + None, // No message router for tests ) .await .unwrap() @@ -3247,6 +3600,8 @@ mod tests { None, tool_config, true, + None, // No lifecycle manager for tests + None, // No message router for tests ) .await .unwrap() diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index c51e5df3e..b5b96f0ad 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -362,6 +362,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + api: false, })), verbose: 2, help_all: false, @@ -401,6 +402,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + api: false, }) ); } @@ -417,6 +419,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: false, + api: false, }) ); } @@ -433,6 +436,7 @@ mod test { trust_all_tools: true, trust_tools: None, no_interactive: false, + api: false, }) ); } @@ -449,6 +453,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: true, + api: false, }) ); assert_parse!( @@ -461,6 +466,7 @@ mod test { trust_all_tools: false, trust_tools: None, no_interactive: true, + api: false, }) ); } @@ -477,6 +483,7 @@ mod test { trust_all_tools: true, trust_tools: None, no_interactive: false, + api: false, }) ); } @@ -493,6 +500,7 @@ mod test { trust_all_tools: false, trust_tools: Some(vec!["".to_string()]), no_interactive: false, + api: false, }) ); } @@ -509,6 +517,7 @@ mod test { trust_all_tools: false, trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), no_interactive: false, + api: false, }) ); }