diff --git a/crates/chat-cli/src/cli/agent.rs b/crates/chat-cli/src/cli/agent.rs new file mode 100644 index 0000000000..9dc3c18200 --- /dev/null +++ b/crates/chat-cli/src/cli/agent.rs @@ -0,0 +1,1082 @@ +#![allow(dead_code)] + +use std::borrow::Borrow; +use std::collections::{ + HashMap, + HashSet, +}; +use std::ffi::OsStr; +use std::io::{ + self, + Write, +}; +use std::path::{ + Path, + PathBuf, +}; + +use crossterm::style::Stylize as _; +use crossterm::{ + queue, + style, +}; +use dialoguer::Select; +use eyre::bail; +use regex::Regex; +use serde::{ + Deserialize, + Serialize, +}; +use tokio::fs::ReadDir; +use tracing::{ + error, + info, + warn, +}; + +use super::chat::tools::custom_tool::CustomToolConfig; +use super::chat::tools::{ + DEFAULT_APPROVE, + NATIVE_TOOLS, + ToolOrigin, +}; +use crate::cli::chat::cli::hooks::{ + Hook, + HookTrigger, +}; +use crate::cli::chat::context::ContextConfig; +use crate::database::settings::Setting; +use crate::os::Os; +use crate::util::{ + MCP_SERVER_TOOL_DELIMITER, + directories, +}; + +// This is to mirror claude's config set up +#[derive(Clone, Serialize, Deserialize, Debug, Default, Eq, PartialEq)] +#[serde(rename_all = "camelCase", transparent)] +pub struct McpServerConfig { + pub mcp_servers: HashMap, +} + +impl McpServerConfig { + pub async fn load_from_file(os: &Os, path: impl AsRef) -> eyre::Result { + let contents = os.fs.read(path.as_ref()).await?; + let value = serde_json::from_slice::(&contents)?; + // We need to extract mcp_servers field from the value because we have annotated + // [McpServerConfig] with transparent. Transparent was added because we want to preserve + // the type in agent. + let config = value + .get("mcpServers") + .cloned() + .ok_or(eyre::eyre!("No mcp servers found in config"))?; + Ok(serde_json::from_value(config)?) + } + + pub async fn save_to_file(&self, os: &Os, path: impl AsRef) -> eyre::Result<()> { + let json = serde_json::to_string_pretty(self)?; + os.fs.write(path.as_ref(), json).await?; + Ok(()) + } +} + +/// An [Agent] is a declarative way of configuring a given instance of q chat. Currently, it is +/// impacting q chat in via influenicng [ContextManager] and [ToolManager]. +/// Changes made to [ContextManager] and [ToolManager] do not persist across sessions. +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct Agent { + /// Agent names are derived from the file name. Thus they are skipped for + /// serializing + #[serde(skip)] + pub name: String, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub prompt: Option, + #[serde(default)] + pub mcp_servers: McpServerConfig, + #[serde(default)] + pub tools: Vec, + #[serde(default)] + pub alias: HashMap, + #[serde(default)] + pub allowed_tools: HashSet, + #[serde(default)] + pub included_files: Vec, + #[serde(default)] + pub create_hooks: serde_json::Value, + #[serde(default)] + pub prompt_hooks: serde_json::Value, + #[serde(default)] + pub tools_settings: HashMap, + #[serde(skip)] + pub path: Option, +} + +impl Default for Agent { + fn default() -> Self { + Self { + name: "default".to_string(), + description: Some("Default agent".to_string()), + prompt: Default::default(), + mcp_servers: Default::default(), + tools: NATIVE_TOOLS.iter().copied().map(str::to_string).collect::>(), + alias: Default::default(), + allowed_tools: { + let mut set = HashSet::::new(); + let default_approve = DEFAULT_APPROVE.iter().copied().map(str::to_string); + set.extend(default_approve); + set + }, + included_files: vec!["AmazonQ.md", "README.md", ".amazonq/rules/**/*.md"] + .into_iter() + .map(str::to_string) + .collect::>(), + create_hooks: Default::default(), + prompt_hooks: Default::default(), + tools_settings: Default::default(), + path: None, + } + } +} + +impl Agent { + /// Retrieves an agent by name. It does so via first seeking the given agent under local dir, + /// and falling back to global dir if it does not exist in local. + pub async fn get_agent_by_name(os: &Os, agent_name: &str) -> eyre::Result<(Agent, PathBuf)> { + let config_path: Result = 'config: { + // local first, and then fall back to looking at global + let local_config_dir = directories::chat_local_agent_dir()?.join(agent_name); + if os.fs.exists(&local_config_dir) { + break 'config Ok::(local_config_dir); + } + + let global_config_dir = directories::chat_global_agent_path(os)?.join(format!("{agent_name}.json")); + if os.fs.exists(&global_config_dir) { + break 'config Ok(global_config_dir); + } + + Err(global_config_dir) + }; + + match config_path { + Ok(config_path) => { + let content = os.fs.read(&config_path).await?; + Ok((serde_json::from_slice::(&content)?, config_path)) + }, + Err(global_config_dir) if agent_name == "default" => { + os.fs + .create_dir_all( + global_config_dir + .parent() + .ok_or(eyre::eyre!("Failed to retrieve global agent config parent path"))?, + ) + .await?; + os.fs.create_new(&global_config_dir).await?; + + let default_agent = Agent::default(); + let content = serde_json::to_string_pretty(&default_agent)?; + os.fs.write(&global_config_dir, content.as_bytes()).await?; + + Ok((default_agent, global_config_dir)) + }, + _ => bail!("Agent {agent_name} does not exist"), + } + } +} + +#[derive(Debug)] +pub enum PermissionEvalResult { + Allow, + Ask, + Deny, +} + +#[derive(Clone, Default, Debug)] +pub struct Agents { + pub agents: HashMap, + pub active_idx: String, + pub trust_all_tools: bool, +} + +impl Agents { + /// This function assumes the relevant transformation to the tool names have been done: + /// - model tool name -> host tool name + /// - custom tool namespacing + pub fn trust_tools(&mut self, tool_names: Vec) { + if let Some(agent) = self.get_active_mut() { + agent.allowed_tools.extend(tool_names); + } + } + + /// This function assumes the relevant transformation to the tool names have been done: + /// - model tool name -> host tool name + /// - custom tool namespacing + pub fn untrust_tools(&mut self, tool_names: &[String]) { + if let Some(agent) = self.get_active_mut() { + agent.allowed_tools.retain(|t| !tool_names.contains(t)); + } + } + + pub fn get_active(&self) -> Option<&Agent> { + self.agents.get(&self.active_idx) + } + + pub fn get_active_mut(&mut self) -> Option<&mut Agent> { + self.agents.get_mut(&self.active_idx) + } + + pub fn switch(&mut self, name: &str) -> eyre::Result<&Agent> { + if !self.agents.contains_key(name) { + eyre::bail!("No agent with name {name} found"); + } + self.active_idx = name.to_string(); + self.agents + .get(name) + .ok_or(eyre::eyre!("No agent with name {name} found")) + } + + /// Migrated from [reload_profiles] from context.rs. It loads the active agent from disk and + /// replaces its in-memory counterpart with it. + pub async fn reload_agents(&mut self, os: &mut Os, output: &mut impl Write) -> eyre::Result<()> { + let persona_name = self.get_active().map(|a| a.name.as_str()); + let mut new_self = Self::load(os, persona_name, true, output).await; + std::mem::swap(self, &mut new_self); + Ok(()) + } + + pub fn list_agents(&self) -> eyre::Result> { + Ok(self.agents.keys().cloned().collect::>()) + } + + /// Migrated from [create_profile] from context.rs, which was creating profiles under the + /// global directory. We shall preserve this implicit behavior for now until further notice. + pub async fn create_agent(&mut self, os: &Os, name: &str) -> eyre::Result<()> { + validate_agent_name(name)?; + + let agent_path = directories::chat_global_agent_path(os)?.join(format!("{name}.json")); + if agent_path.exists() { + return Err(eyre::eyre!("Agent '{}' already exists", name)); + } + + let agent = Agent { + name: name.to_string(), + path: Some(agent_path.clone()), + ..Default::default() + }; + let contents = serde_json::to_string_pretty(&agent) + .map_err(|e| eyre::eyre!("Failed to serialize profile configuration: {}", e))?; + + if let Some(parent) = agent_path.parent() { + os.fs.create_dir_all(parent).await?; + } + os.fs.write(&agent_path, contents).await?; + + self.agents.insert(name.to_string(), agent); + + Ok(()) + } + + /// Migrated from [delete_profile] from context.rs, which was deleting profiles under the + /// global directory. We shall preserve this implicit behavior for now until further notice. + pub async fn delete_agent(&mut self, os: &Os, name: &str) -> eyre::Result<()> { + if name == self.active_idx.as_str() { + eyre::bail!("Cannot delete the active agent. Switch to another agent first"); + } + + let to_delete = self + .agents + .get(name) + .ok_or(eyre::eyre!("Agent '{name}' does not exist"))?; + match to_delete.path.as_ref() { + Some(path) if path.exists() => { + os.fs.remove_file(path).await?; + }, + _ => eyre::bail!("Agent {name} does not have an associated path"), + } + + self.agents.remove(name); + + Ok(()) + } + + /// Migrated from [load] from context.rs, which was loading profiles under the + /// local and global directory. We shall preserve this implicit behavior for now until further + /// notice. + /// In addition to loading, this function also calls the function responsible for migrating + /// existing context into agent. + pub async fn load( + os: &mut Os, + mut agent_name: Option<&str>, + skip_migration: bool, + output: &mut impl Write, + ) -> Self { + let (chosen_name, new_agents) = if !skip_migration { + match migrate(os).await { + Ok((i, new_agents)) => (i, new_agents), + Err(e) => { + warn!("Migration did not happen for the following reason: {e}. This is not necessarily an error"); + (None, vec![]) + }, + } + } else { + (None, vec![]) + }; + + if let Some(name) = chosen_name.as_ref() { + agent_name.replace(name.as_str()); + } + + let mut local_agents = 'local: { + let Ok(path) = directories::chat_local_agent_dir() else { + break 'local Vec::::new(); + }; + let Ok(files) = os.fs.read_dir(path).await else { + break 'local Vec::::new(); + }; + load_agents_from_entries(files).await + }; + + let mut global_agents = 'global: { + let Ok(path) = directories::chat_global_agent_path(os) else { + break 'global Vec::::new(); + }; + let files = match os.fs.read_dir(&path).await { + Ok(files) => files, + Err(e) => { + if matches!(e.kind(), io::ErrorKind::NotFound) { + if let Err(e) = os.fs.create_dir_all(&path).await { + error!("Error creating global agent dir: {:?}", e); + } + } + break 'global Vec::::new(); + }, + }; + load_agents_from_entries(files).await + } + .into_iter() + .chain(new_agents) + .collect::>(); + + // Here we also want to make sure the example config is written to disk if it's not already + // there. + 'example_config: { + let Ok(path) = directories::example_agent_config(os) else { + error!("Error obtaining example agent path."); + break 'example_config; + }; + if os.fs.exists(&path) { + break 'example_config; + } + + // At this point the agents dir would have been created. All we have to worry about is + // the creation of the example config + if let Err(e) = os.fs.create_new(&path).await { + error!("Error creating example agent config: {e}."); + break 'example_config; + } + + let example_agent = Agent { + // This is less important than other fields since names are derived from the name + // of the config file and thus will not be persisted + name: "example".to_string(), + description: Some("This is an example agent config (and will not be loaded unless you change it to have .json extension)".to_string()), + tools: { + NATIVE_TOOLS + .iter() + .copied() + .map(str::to_string) + .chain(vec![ + format!("@mcp_server_name{MCP_SERVER_TOOL_DELIMITER}mcp_tool_name"), + "@mcp_server_name_without_tool_specification_to_include_all_tools".to_string(), + ]) + .collect::>() + }, + ..Default::default() + }; + let Ok(content) = serde_json::to_string_pretty(&example_agent) else { + error!("Error serializing example agent config"); + break 'example_config; + }; + if let Err(e) = os.fs.write(&path, &content).await { + error!("Error writing example agent config to file: {e}"); + break 'example_config; + }; + } + + let local_names = local_agents.iter().map(|a| a.name.as_str()).collect::>(); + global_agents.retain(|a| { + // If there is a naming conflict for agents, we would retain the local instance + let name = a.name.as_str(); + if local_names.contains(name) { + let _ = queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("Agent conflict for "), + style::SetForegroundColor(style::Color::Green), + style::Print(name), + style::ResetColor, + style::Print(". Using workspace version.\n") + ); + false + } else { + true + } + }); + + local_agents.append(&mut global_agents); + + // If we are told which agent to set as active, we will fall back to a default whose + // lifetime matches that of the session + if agent_name.is_none() { + local_agents.push(Agent::default()); + } + + let _ = output.flush(); + + Self { + agents: local_agents + .into_iter() + .map(|a| (a.name.clone(), a)) + .collect::>(), + active_idx: agent_name.unwrap_or("default").to_string(), + ..Default::default() + } + } + + /// Returns a label to describe the permission status for a given tool. + pub fn display_label(&self, tool_name: &str, origin: &ToolOrigin) -> String { + let tool_trusted = self.get_active().is_some_and(|a| { + a.allowed_tools.iter().any(|name| { + // Here the tool names can take the following forms: + // - @{server_name}{delimiter}{tool_name} + // - native_tool_name + name == tool_name + || name.strip_prefix("@").is_some_and(|remainder| { + remainder + .split_once(MCP_SERVER_TOOL_DELIMITER) + .is_some_and(|(_left, right)| right == tool_name) + || remainder == >::borrow(origin) + }) + }) + }); + + if tool_trusted || self.trust_all_tools { + format!("* {}", "trusted".dark_green().bold()) + } else { + self.default_permission_label(tool_name) + } + } + + /// Provide default permission labels for the built-in set of tools. + // This "static" way avoids needing to construct a tool instance. + fn default_permission_label(&self, tool_name: &str) -> String { + let label = match tool_name { + "fs_read" => "trusted".dark_green().bold(), + "fs_write" => "not trusted".dark_grey(), + #[cfg(not(windows))] + "execute_bash" => "trust read-only commands".dark_grey(), + #[cfg(windows)] + "execute_cmd" => "trust read-only commands".dark_grey(), + "use_aws" => "trust read-only commands".dark_grey(), + "report_issue" => "trusted".dark_green().bold(), + "thinking" => "trusted (prerelease)".dark_green().bold(), + _ if self.trust_all_tools => "trusted".dark_grey().bold(), + _ => "not trusted".dark_grey(), + }; + + format!("{} {label}", "*".reset()) + } +} + +struct ContextMigrate { + legacy_global_context: Option, + legacy_profiles: HashMap, + mcp_servers: Option, + new_agents: Vec, +} + +impl ContextMigrate<'a'> { + async fn scan(os: &Os) -> eyre::Result> { + let legacy_global_context_path = directories::chat_global_context_path(os)?; + let legacy_global_context: Option = 'global: { + let Ok(content) = os.fs.read(&legacy_global_context_path).await else { + break 'global None; + }; + serde_json::from_slice::(&content).ok() + }; + + let legacy_profile_path = directories::chat_profiles_dir(os)?; + let legacy_profiles: HashMap = 'profiles: { + let mut profiles = HashMap::::new(); + let Ok(mut read_dir) = os.fs.read_dir(&legacy_profile_path).await else { + break 'profiles profiles; + }; + + // Here we assume every profile is stored under their own folders + // And that the profile config is in profile_name/context.json + while let Ok(Some(entry)) = read_dir.next_entry().await { + let config_file_path = entry.path().join("context.json"); + if !os.fs.exists(&config_file_path) { + continue; + } + let Some(profile_name) = entry.file_name().to_str().map(|s| s.to_string()) else { + continue; + }; + let Ok(content) = tokio::fs::read_to_string(&config_file_path).await else { + continue; + }; + let Ok(mut context_config) = serde_json::from_str::(content.as_str()) else { + continue; + }; + + // Combine with global context since you can now only choose one agent at a time + // So this is how we make what is previously global available to every new agent migrated + if let Some(context) = legacy_global_context.as_ref() { + context_config.paths.extend(context.paths.clone()); + context_config.hooks.extend(context.hooks.clone()); + } + + profiles.insert(profile_name.clone(), context_config); + } + + profiles + }; + + let mcp_servers = { + let config_path = directories::chat_legacy_mcp_config(os)?; + if os.fs.exists(&config_path) { + match McpServerConfig::load_from_file(os, config_path).await { + Ok(config) => Some(config), + Err(e) => { + error!("Malformed legacy global mcp config detected: {e}. Skipping mcp migration."); + None + }, + } + } else { + None + } + }; + + if legacy_global_context.is_some() || !legacy_profiles.is_empty() { + Ok(ContextMigrate { + legacy_global_context, + legacy_profiles, + mcp_servers, + new_agents: vec![], + }) + } else { + bail!("Nothing to migrate"); + } + } +} + +impl ContextMigrate<'b'> { + async fn prompt_migrate(self) -> eyre::Result> { + let ContextMigrate { + legacy_global_context, + legacy_profiles, + mcp_servers, + new_agents, + } = self; + + let labels = vec!["Yes", "No"]; + let selection: Option<_> = match Select::with_theme(&crate::util::dialoguer_theme()) + .with_prompt("Legacy profiles detected. Would you like to migrate them?") + .items(&labels) + .default(1) + .interact_on_opt(&dialoguer::console::Term::stdout()) + { + Ok(sel) => { + let _ = crossterm::execute!( + std::io::stdout(), + crossterm::style::SetForegroundColor(crossterm::style::Color::Magenta) + ); + sel + }, + // Ctrl‑C -> Err(Interrupted) + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => None, + Err(e) => bail!("Failed to choose an option: {e}"), + }; + + if let Some(0) = selection { + Ok(ContextMigrate { + legacy_global_context, + legacy_profiles, + mcp_servers, + new_agents, + }) + } else { + bail!("Aborting migration") + } + } +} + +impl ContextMigrate<'c'> { + async fn migrate(self, os: &Os) -> eyre::Result> { + const LEGACY_GLOBAL_AGENT_NAME: &str = "migrated_agent_from_global_context"; + const DEFAULT_DESC: &str = "This is an agent migrated from global context"; + const PROFILE_DESC: &str = "This is an agent migrated from profile context"; + + let ContextMigrate { + legacy_global_context, + mut legacy_profiles, + mcp_servers, + mut new_agents, + } = self; + + let has_global_context = legacy_global_context.is_some(); + + // Migration of global context + if let Some(context) = legacy_global_context { + let (create_hooks, prompt_hooks) = + context + .hooks + .into_iter() + .partition::, _>(|(_, hook)| { + matches!(hook.trigger, HookTrigger::ConversationStart) + }); + + new_agents.push(Agent { + name: LEGACY_GLOBAL_AGENT_NAME.to_string(), + description: Some(DEFAULT_DESC.to_string()), + path: Some(directories::chat_global_agent_path(os)?.join(format!("{LEGACY_GLOBAL_AGENT_NAME}.json"))), + included_files: context.paths, + create_hooks: serde_json::to_value(create_hooks).unwrap_or(serde_json::json!({})), + prompt_hooks: serde_json::to_value(prompt_hooks).unwrap_or(serde_json::json!({})), + mcp_servers: mcp_servers.clone().unwrap_or_default(), + ..Default::default() + }); + } + + let global_agent_path = directories::chat_global_agent_path(os)?; + + // Migration of profile context + for (profile_name, context) in legacy_profiles.drain() { + let (create_hooks, prompt_hooks) = + context + .hooks + .into_iter() + .partition::, _>(|(_, hook)| { + matches!(hook.trigger, HookTrigger::ConversationStart) + }); + + new_agents.push(Agent { + path: Some(global_agent_path.join(format!("{profile_name}.json"))), + name: profile_name, + description: Some(PROFILE_DESC.to_string()), + included_files: context.paths, + create_hooks: serde_json::to_value(create_hooks).unwrap_or(serde_json::json!({})), + prompt_hooks: serde_json::to_value(prompt_hooks).unwrap_or(serde_json::json!({})), + mcp_servers: mcp_servers.clone().unwrap_or_default(), + ..Default::default() + }); + } + + if !os.fs.exists(&global_agent_path) { + os.fs.create_dir_all(&global_agent_path).await?; + } + + for agent in &new_agents { + let content = serde_json::to_string_pretty(agent)?; + if let Some(path) = agent.path.as_ref() { + info!("Agent {} peristed in path {}", agent.name, path.to_string_lossy()); + os.fs.write(path, content).await?; + } else { + warn!( + "Agent with name {} does not have path associated and is thus not migrated.", + agent.name + ); + } + } + + let legacy_profile_config_path = directories::chat_profiles_dir(os)?; + let profile_backup_path = legacy_profile_config_path + .parent() + .ok_or(eyre::eyre!("Failed to obtain profile config parent path"))? + .join("profiles.bak"); + os.fs.rename(legacy_profile_config_path, profile_backup_path).await?; + + if has_global_context { + let legacy_global_config_path = directories::chat_global_context_path(os)?; + let legacy_global_config_file_name = legacy_global_config_path + .file_name() + .ok_or(eyre::eyre!("Failed to obtain legacy global config name"))? + .to_string_lossy(); + let global_context_backup_path = legacy_global_config_path + .parent() + .ok_or(eyre::eyre!("Failed to obtain parent path for global context"))? + .join(format!("{}.bak", legacy_global_config_file_name)); + os.fs + .rename(legacy_global_config_path, global_context_backup_path) + .await?; + } + + Ok(ContextMigrate { + legacy_global_context: None, + legacy_profiles, + mcp_servers: None, + new_agents, + }) + } +} + +impl ContextMigrate<'d'> { + async fn prompt_set_default(self, os: &mut Os) -> eyre::Result<(Option, Vec)> { + let ContextMigrate { new_agents, .. } = self; + + let labels = new_agents + .iter() + .map(|a| a.name.as_str()) + .chain(vec!["Let me do this on my own later"]) + .collect::>(); + // This yields 0 if it's negative, which is acceptable. + let later_idx = labels.len().saturating_sub(1); + let selection: Option<_> = match Select::with_theme(&crate::util::dialoguer_theme()) + .with_prompt( + "Set an agent as default. This is the agent that q chat will launch with unless specified otherwise.", + ) + .default(0) + .items(&labels) + .interact_on_opt(&dialoguer::console::Term::stdout()) + { + Ok(sel) => { + let _ = crossterm::execute!( + std::io::stdout(), + crossterm::style::SetForegroundColor(crossterm::style::Color::Magenta) + ); + sel + }, + // Ctrl‑C -> Err(Interrupted) + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => None, + Err(e) => bail!("Failed to choose an option: {e}"), + }; + + let mut agent_to_load = None::; + if let Some(i) = selection { + if later_idx != i { + if let Some(name) = labels.get(i) { + if let Ok(value) = serde_json::to_value(name) { + if os.database.settings.set(Setting::ChatDefaultAgent, value).await.is_ok() { + let chosen_name = (*name).to_string(); + agent_to_load.replace(chosen_name); + } + } + } + } + } + + Ok((agent_to_load, new_agents)) + } +} + +async fn load_agents_from_entries(mut files: ReadDir) -> Vec { + let mut res = Vec::::new(); + while let Ok(Some(file)) = files.next_entry().await { + let file_path = &file.path(); + if file_path + .extension() + .and_then(OsStr::to_str) + .is_some_and(|s| s == "json") + { + let content = match tokio::fs::read(file_path).await { + Ok(content) => content, + Err(e) => { + let file_path = file_path.to_string_lossy(); + tracing::error!("Error reading agent file {file_path}: {:?}", e); + continue; + }, + }; + let mut agent = match serde_json::from_slice::(&content) { + Ok(mut agent) => { + agent.path = Some(file_path.clone()); + agent + }, + Err(e) => { + let file_path = file_path.to_string_lossy(); + tracing::error!("Error deserializing agent file {file_path}: {:?}", e); + continue; + }, + }; + if let Some(name) = Path::new(&file.file_name()).file_stem() { + agent.name = name.to_string_lossy().to_string(); + res.push(agent); + } else { + let file_path = file_path.to_string_lossy(); + tracing::error!("Unable to determine agent name from config file at {file_path}, skipping"); + } + } + } + res +} + +fn validate_agent_name(name: &str) -> eyre::Result<()> { + // Check if name is empty + if name.is_empty() { + eyre::bail!("Agent name cannot be empty"); + } + + // Check if name contains only allowed characters and starts with an alphanumeric character + let re = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")?; + if !re.is_match(name) { + eyre::bail!( + "Agent name must start with an alphanumeric character and can only contain alphanumeric characters, hyphens, and underscores" + ); + } + + Ok(()) +} + +async fn migrate(os: &mut Os) -> eyre::Result<(Option, Vec)> { + ContextMigrate::<'a'>::scan(os) + .await? + .prompt_migrate() + .await? + .migrate(os) + .await? + .prompt_set_default(os) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + + struct NullWriter; + + impl Write for NullWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + const INPUT: &str = r#" + { + "description": "My developer agent is used for small development tasks like solving open issues.", + "prompt": "You are a principal developer who uses multiple agents to accomplish difficult engineering tasks", + "mcpServers": { + "fetch": { "command": "fetch3.1", "args": [] }, + "git": { "command": "git-mcp", "args": [] } + }, + "tools": [ + "@git", + "fs_read" + ], + "alias": { + "@gits/some_tool": "some_tool2" + }, + "allowedTools": [ + "fs_read", + "@fetch", + "@gits/git_status" + ], + "includedFiles": [ + "~/my-genai-prompts/unittest.md" + ], + "createHooks": [ + "pwd && tree" + ], + "promptHooks": [ + "git status" + ], + "toolsSettings": { + "fs_write": { "allowedPaths": ["~/**"] }, + "@git/git_status": { "git_user": "$GIT_USER" } + } + } + "#; + + #[test] + fn test_deser() { + let agent = serde_json::from_str::(INPUT).expect("Deserializtion failed"); + assert!(agent.mcp_servers.mcp_servers.contains_key("fetch")); + assert!(agent.mcp_servers.mcp_servers.contains_key("git")); + assert!(agent.alias.contains_key("@gits/some_tool")); + } + + #[test] + fn test_get_active() { + let mut collection = Agents::default(); + assert!(collection.get_active().is_none()); + + let agent = Agent::default(); + collection.agents.insert("default".to_string(), agent); + collection.active_idx = "default".to_string(); + + assert!(collection.get_active().is_some()); + assert_eq!(collection.get_active().unwrap().name, "default"); + } + + #[test] + fn test_get_active_mut() { + let mut collection = Agents::default(); + assert!(collection.get_active_mut().is_none()); + + let agent = Agent::default(); + collection.agents.insert("default".to_string(), agent); + collection.active_idx = "default".to_string(); + + assert!(collection.get_active_mut().is_some()); + let active = collection.get_active_mut().unwrap(); + active.description = Some("Modified description".to_string()); + + assert_eq!( + collection.agents.get("default").unwrap().description, + Some("Modified description".to_string()) + ); + } + + #[test] + fn test_switch() { + let mut collection = Agents::default(); + + let default_agent = Agent::default(); + let dev_agent = Agent { + name: "dev".to_string(), + description: Some("Developer agent".to_string()), + ..Default::default() + }; + + collection.agents.insert("default".to_string(), default_agent); + collection.agents.insert("dev".to_string(), dev_agent); + collection.active_idx = "default".to_string(); + + // Test successful switch + let result = collection.switch("dev"); + assert!(result.is_ok()); + assert_eq!(result.unwrap().name, "dev"); + + // Test switch to non-existent agent + let result = collection.switch("nonexistent"); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "No agent with name nonexistent found"); + } + + #[tokio::test] + async fn test_list_agents() { + let mut collection = Agents::default(); + + // Add two agents + let default_agent = Agent::default(); + let dev_agent = Agent { + name: "dev".to_string(), + description: Some("Developer agent".to_string()), + ..Default::default() + }; + + collection.agents.insert("default".to_string(), default_agent); + collection.agents.insert("dev".to_string(), dev_agent); + + let result = collection.list_agents(); + assert!(result.is_ok()); + + let agents = result.unwrap(); + assert_eq!(agents.len(), 2); + assert!(agents.contains(&"default".to_string())); + assert!(agents.contains(&"dev".to_string())); + } + + #[tokio::test] + async fn test_create_agent() { + let mut collection = Agents::default(); + let ctx = Os::new().await.unwrap(); + + let agent_name = "test_agent"; + let result = collection.create_agent(&ctx, agent_name).await; + assert!(result.is_ok()); + let agent_path = directories::chat_global_agent_path(&ctx) + .expect("Error obtaining global agent path") + .join(format!("{agent_name}.json")); + assert!(agent_path.exists()); + assert!(collection.agents.contains_key(agent_name)); + + // Test with creating a agent with the same name + let result = collection.create_agent(&ctx, agent_name).await; + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + format!("Agent '{agent_name}' already exists") + ); + + // Test invalid agent names + let result = collection.create_agent(&ctx, "").await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "Agent name cannot be empty"); + + let result = collection.create_agent(&ctx, "123-invalid!").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_delete_agent() { + let mut collection = Agents::default(); + let ctx = Os::new().await.unwrap(); + + let agent_name_one = "test_agent_one"; + collection + .create_agent(&ctx, agent_name_one) + .await + .expect("Failed to create agent"); + let agent_name_two = "test_agent_two"; + collection + .create_agent(&ctx, agent_name_two) + .await + .expect("Failed to create agent"); + + collection.switch(agent_name_one).expect("Failed to switch agent"); + + // Should not be able to delete active agent + let active = collection + .get_active() + .expect("Failed to obtain active agent") + .name + .clone(); + let result = collection.delete_agent(&ctx, &active).await; + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Cannot delete the active agent. Switch to another agent first" + ); + + // Should be able to delete inactive agent + let agent_two_path = collection + .agents + .get(agent_name_two) + .expect("Failed to obtain agent that's yet to be deleted") + .path + .clone() + .expect("agent should have path"); + let result = collection.delete_agent(&ctx, agent_name_two).await; + assert!(result.is_ok()); + assert!(!collection.agents.contains_key(agent_name_two)); + assert!(!agent_two_path.exists()); + + let result = collection.delete_agent(&ctx, "nonexistent").await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "Agent 'nonexistent' does not exist"); + } + + #[test] + fn test_validate_agent_name() { + // Valid names + assert!(validate_agent_name("valid").is_ok()); + assert!(validate_agent_name("valid123").is_ok()); + assert!(validate_agent_name("valid-name").is_ok()); + assert!(validate_agent_name("valid_name").is_ok()); + assert!(validate_agent_name("123valid").is_ok()); + + // Invalid names + assert!(validate_agent_name("").is_err()); + assert!(validate_agent_name("-invalid").is_err()); + assert!(validate_agent_name("_invalid").is_err()); + assert!(validate_agent_name("invalid!").is_err()); + assert!(validate_agent_name("invalid space").is_err()); + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/context.rs b/crates/chat-cli/src/cli/chat/cli/context.rs index 0fedc8c019..a0ed495749 100644 --- a/crates/chat-cli/src/cli/chat/cli/context.rs +++ b/crates/chat-cli/src/cli/chat/cli/context.rs @@ -49,9 +49,6 @@ pub enum ContextSubcommand { }, /// Add context rules (filenames or glob patterns) Add { - /// Add to global rules (available in all profiles) - #[arg(short, long)] - global: bool, /// Include even if matched files exceed size limits #[arg(short, long)] force: bool, @@ -61,18 +58,11 @@ pub enum ContextSubcommand { /// Remove specified rules from current profile #[command(alias = "rm")] Remove { - /// Remove specified rules globally - #[arg(short, long)] - global: bool, #[arg(required = true)] paths: Vec, }, /// Remove all rules from current profile - Clear { - /// Remove global rules - #[arg(short, long)] - global: bool, - }, + Clear, #[command(hide = true)] Hooks, } @@ -94,66 +84,7 @@ impl ContextSubcommand { match self { Self::Show { expand } => { - // Display global context - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - let mut global_context_files = HashSet::new(); - let mut profile_context_files = HashSet::new(); - if context_manager.global_config.paths.is_empty() { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for path in &context_manager.global_config.paths { - execute!(session.stderr, style::Print(format!(" {} ", path)))?; - if let Ok(context_files) = context_manager.get_context_files_by_path(os, path).await { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "({} match{})", - context_files.len(), - if context_files.len() == 1 { "" } else { "es" } - )), - style::SetForegroundColor(Color::Reset) - )?; - global_context_files.extend(context_files); - } - execute!(session.stderr, style::Print("\n"))?; - } - } - - if expand { - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::DarkYellow), - style::Print("\n 🔧 Hooks:\n") - )?; - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - } - - // Display profile context + let profile_context_files = HashSet::<(String, String)>::new(); execute!( session.stderr, style::SetAttribute(Attribute::Bold), @@ -183,7 +114,6 @@ impl ContextSubcommand { )), style::SetForegroundColor(Color::Reset) )?; - profile_context_files.extend(context_files); } execute!(session.stderr, style::Print("\n"))?; } @@ -212,7 +142,7 @@ impl ContextSubcommand { execute!(session.stderr, style::Print("\n"))?; } - if global_context_files.is_empty() && profile_context_files.is_empty() { + if profile_context_files.is_empty() { execute!( session.stderr, style::SetForegroundColor(Color::DarkGrey), @@ -220,15 +150,11 @@ impl ContextSubcommand { style::SetForegroundColor(Color::Reset) )?; } else { - let total = global_context_files.len() + profile_context_files.len(); - let total_tokens = global_context_files + let total = profile_context_files.len(); + let total_tokens = profile_context_files .iter() .map(|(_, content)| TokenCounter::count_tokens(content)) - .sum::() - + profile_context_files - .iter() - .map(|(_, content)| TokenCounter::count_tokens(content)) - .sum::(); + .sum::(); execute!( session.stderr, style::SetForegroundColor(Color::Green), @@ -242,25 +168,6 @@ impl ContextSubcommand { style::SetAttribute(Attribute::Reset) )?; - for (filename, content) in &global_context_files { - let est_tokens = TokenCounter::count_tokens(content); - execute!( - session.stderr, - style::Print(format!("🌍 {} ", filename)), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("(~{} tkns)\n", est_tokens)), - style::SetForegroundColor(Color::Reset), - )?; - if expand { - execute!( - session.stderr, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{}\n\n", content)), - style::SetForegroundColor(Color::Reset) - )?; - } - } - for (filename, content) in &profile_context_files { let est_tokens = TokenCounter::count_tokens(content); execute!( @@ -284,13 +191,8 @@ impl ContextSubcommand { execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?; } - let mut combined_files: Vec<(String, String)> = global_context_files - .iter() - .chain(profile_context_files.iter()) - .cloned() - .collect(); - - let dropped_files = drop_matched_context_files(&mut combined_files, CONTEXT_FILES_MAX_SIZE).ok(); + let mut files_as_vec = profile_context_files.iter().cloned().collect::>(); + let dropped_files = drop_matched_context_files(&mut files_as_vec, CONTEXT_FILES_MAX_SIZE).ok(); execute!( session.stderr, @@ -357,38 +259,12 @@ impl ContextSubcommand { } } }, - Self::Add { global, force, paths } => { - match context_manager.add_paths(os, paths.clone(), global, force).await { - Ok(_) => { - let target = if global { "global" } else { "profile" }; - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nAdded {} path(s) to {} context.\n\n", paths.len(), target)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - Self::Remove { global, paths } => match context_manager.remove_paths(os, paths.clone(), global).await { + Self::Add { force, paths } => match context_manager.add_paths(os, paths.clone(), force).await { Ok(_) => { - let target = if global { "global" } else { "profile" }; execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nRemoved {} path(s) from {} context.\n\n", - paths.len(), - target - )), + style::Print(format!("\nAdded {} path(s) to context.\n\n", paths.len())), style::SetForegroundColor(Color::Reset) )?; }, @@ -401,17 +277,12 @@ impl ContextSubcommand { )?; }, }, - Self::Clear { global } => match context_manager.clear(os, global).await { + Self::Remove { paths } => match context_manager.remove_paths(paths.clone()) { Ok(_) => { - let target = if global { - "global".to_string() - } else { - format!("profile '{}'", context_manager.current_profile) - }; execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nCleared context for {}\n\n", target)), + style::Print(format!("\nRemoved {} path(s) from context.\n\n", paths.len(),)), style::SetForegroundColor(Color::Reset) )?; }, @@ -424,6 +295,15 @@ impl ContextSubcommand { )?; }, }, + Self::Clear => { + context_manager.clear(); + execute!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("\nCleared context\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + }, Self::Hooks => { execute!( session.stderr, diff --git a/crates/chat-cli/src/cli/chat/cli/hooks.rs b/crates/chat-cli/src/cli/chat/cli/hooks.rs index 339935c45a..58dc7c348a 100644 --- a/crates/chat-cli/src/cli/chat/cli/hooks.rs +++ b/crates/chat-cli/src/cli/chat/cli/hooks.rs @@ -415,27 +415,6 @@ impl HooksArgs { }); }; - queue!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - queue!( session.stderr, style::SetAttribute(Attribute::Bold), @@ -484,83 +463,53 @@ pub enum HooksSubcommand { /// Shell command to execute #[arg(long, value_parser = clap::value_parser!(String))] command: String, - /// Add to global hooks - #[arg(long)] - global: bool, }, /// Remove an existing context hook #[command(name = "rm")] Remove { /// The name of the hook name: String, - /// Remove from global hooks - #[arg(long)] - global: bool, }, /// Enable an existing context hook Enable { /// The name of the hook name: String, - /// Enable in global hooks - #[arg(long)] - global: bool, }, /// Disable an existing context hook Disable { /// The name of the hook name: String, - /// Disable in global hooks - #[arg(long)] - global: bool, }, /// Enable all existing context hooks - EnableAll { - /// Enable all in global hooks - #[arg(long)] - global: bool, - }, + EnableAll, /// Disable all existing context hooks - DisableAll { - /// Disable all in global hooks - #[arg(long)] - global: bool, - }, + DisableAll, /// Display the context rule configuration and matched files Show, } impl HooksSubcommand { - pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { + pub async fn execute(self, _os: &Os, session: &mut ChatSession) -> Result { let Some(context_manager) = &mut session.conversation.context_manager else { return Ok(ChatState::PromptUser { skip_printing_tools: true, }); }; - let scope = |g: bool| if g { "global" } else { "profile" }; - match self { - Self::Add { - name, - trigger, - command, - global, - } => { + Self::Add { name, trigger, command } => { let trigger = if trigger == "conversation_start" { HookTrigger::ConversationStart } else { HookTrigger::PerPrompt }; - let result = context_manager - .add_hook(os, name.clone(), Hook::new_inline_hook(trigger, command), global) - .await; - match result { + match context_manager.add_hook(name.clone(), Hook::new_inline_hook(trigger, command)) { Ok(_) => { execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nAdded {} hook '{name}'.\n\n", scope(global))), + style::Print(format!("\nAdded hook '{name}'.\n\n")), style::SetForegroundColor(Color::Reset) )?; }, @@ -568,20 +517,20 @@ impl HooksSubcommand { execute!( session.stderr, style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot add {} hook '{name}': {}\n\n", scope(global), e)), + style::Print(format!("\nCannot add hook '{name}': {}\n\n", e)), style::SetForegroundColor(Color::Reset) )?; }, } }, - Self::Remove { name, global } => { - let result = context_manager.remove_hook(os, &name, global).await; + Self::Remove { name } => { + let result = context_manager.remove_hook(&name); match result { Ok(_) => { execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nRemoved {} hook '{name}'.\n\n", scope(global))), + style::Print(format!("\nRemoved hook '{name}'.\n\n")), style::SetForegroundColor(Color::Reset) )?; }, @@ -589,20 +538,20 @@ impl HooksSubcommand { execute!( session.stderr, style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot remove {} hook '{name}': {}\n\n", scope(global), e)), + style::Print(format!("\nCannot remove hook '{name}': {}\n\n", e)), style::SetForegroundColor(Color::Reset) )?; }, } }, - Self::Enable { name, global } => { - let result = context_manager.set_hook_disabled(os, &name, global, false).await; + Self::Enable { name } => { + let result = context_manager.set_hook_disabled(&name, false); match result { Ok(_) => { execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nEnabled {} hook '{name}'.\n\n", scope(global))), + style::Print(format!("\nEnabled hook '{name}'.\n\n")), style::SetForegroundColor(Color::Reset) )?; }, @@ -610,20 +559,20 @@ impl HooksSubcommand { execute!( session.stderr, style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot enable {} hook '{name}': {}\n\n", scope(global), e)), + style::Print(format!("\nCannot enable hook '{name}': {}\n\n", e)), style::SetForegroundColor(Color::Reset) )?; }, } }, - Self::Disable { name, global } => { - let result = context_manager.set_hook_disabled(os, &name, global, true).await; + Self::Disable { name } => { + let result = context_manager.set_hook_disabled(&name, true); match result { Ok(_) => { execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nDisabled {} hook '{name}'.\n\n", scope(global))), + style::Print(format!("\nDisabled hook '{name}'.\n\n")), style::SetForegroundColor(Color::Reset) )?; }, @@ -631,60 +580,31 @@ impl HooksSubcommand { execute!( session.stderr, style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot disable {} hook '{name}': {}\n\n", scope(global), e)), + style::Print(format!("\nCannot disable hook '{name}': {}\n\n", e)), style::SetForegroundColor(Color::Reset) )?; }, } }, - Self::EnableAll { global } => { - context_manager - .set_all_hooks_disabled(os, global, false) - .await - .map_err(map_chat_error)?; + Self::EnableAll => { + context_manager.set_all_hooks_disabled(false); execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nEnabled all {} hooks.\n\n", scope(global))), + style::Print("\nEnabled all hooks.\n\n"), style::SetForegroundColor(Color::Reset) )?; }, - Self::DisableAll { global } => { - context_manager - .set_all_hooks_disabled(os, global, true) - .await - .map_err(map_chat_error)?; + Self::DisableAll => { + context_manager.set_all_hooks_disabled(true); execute!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print(format!("\nDisabled all {} hooks.\n\n", scope(global))), + style::Print("\nDisabled all hooks.\n\n"), style::SetForegroundColor(Color::Reset) )?; }, Self::Show => { - // Display global context - execute!( - session.stderr, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut session.stderr, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - - // Display profile hooks execute!( session.stderr, style::SetAttribute(Attribute::Bold), @@ -769,94 +689,78 @@ mod tests { #[tokio::test] async fn test_add_hook() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; + let mut manager = create_test_context_manager(None).unwrap(); let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); // Test adding hook to profile config - manager - .add_hook(&os, "test_hook".to_string(), hook.clone(), false) - .await?; + manager.add_hook("test_hook".to_string(), hook.clone())?; assert!(manager.profile_config.hooks.contains_key("test_hook")); - // Test adding hook to global config - manager - .add_hook(&os, "global_hook".to_string(), hook.clone(), true) - .await?; - assert!(manager.global_config.hooks.contains_key("global_hook")); - // Test adding duplicate hook name - assert!( - manager - .add_hook(&os, "test_hook".to_string(), hook, false) - .await - .is_err() - ); + assert!(manager.add_hook("test_hook".to_string(), hook).is_err()); Ok(()) } #[tokio::test] async fn test_remove_hook() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; + let mut manager = create_test_context_manager(None).unwrap(); let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - manager.add_hook(&os, "test_hook".to_string(), hook, false).await?; + manager + .add_hook("test_hook".to_string(), hook) + .expect("Hook addition failed"); // Test removing existing hook - manager.remove_hook(&os, "test_hook", false).await?; + manager.remove_hook("test_hook").expect("Hook removal failed"); assert!(!manager.profile_config.hooks.contains_key("test_hook")); // Test removing non-existent hook - assert!(manager.remove_hook(&os, "test_hook", false).await.is_err()); + assert!(manager.remove_hook("test_hook").is_err()); Ok(()) } #[tokio::test] async fn test_set_hook_disabled() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; + let mut manager = create_test_context_manager(None).unwrap(); let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - manager.add_hook(&os, "test_hook".to_string(), hook, false).await?; + manager.add_hook("test_hook".to_string(), hook).unwrap(); // Test disabling hook - manager.set_hook_disabled(&os, "test_hook", false, true).await?; + manager.set_hook_disabled("test_hook", true).unwrap(); assert!(manager.profile_config.hooks.get("test_hook").unwrap().disabled); // Test enabling hook - manager.set_hook_disabled(&os, "test_hook", false, false).await?; + manager.set_hook_disabled("test_hook", false).unwrap(); assert!(!manager.profile_config.hooks.get("test_hook").unwrap().disabled); // Test with non-existent hook - assert!( - manager - .set_hook_disabled(&os, "nonexistent", false, true) - .await - .is_err() - ); + assert!(manager.set_hook_disabled("nonexistent", true).is_err()); Ok(()) } #[tokio::test] async fn test_set_all_hooks_disabled() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; + let mut manager = create_test_context_manager(None).unwrap(); let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - manager.add_hook(&os, "hook1".to_string(), hook1, false).await?; - manager.add_hook(&os, "hook2".to_string(), hook2, false).await?; + manager + .add_hook("hook1".to_string(), hook1) + .expect("Hook addition failed"); + manager + .add_hook("hook2".to_string(), hook2) + .expect("Hook addition failed"); // Test disabling all hooks - manager.set_all_hooks_disabled(&os, false, true).await?; + manager.set_all_hooks_disabled(true); assert!(manager.profile_config.hooks.values().all(|h| h.disabled)); // Test enabling all hooks - manager.set_all_hooks_disabled(&os, false, false).await?; + manager.set_all_hooks_disabled(false); assert!(manager.profile_config.hooks.values().all(|h| !h.disabled)); Ok(()) @@ -864,13 +768,12 @@ mod tests { #[tokio::test] async fn test_run_hooks() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; + let mut manager = create_test_context_manager(None).unwrap(); let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - manager.add_hook(&os, "hook1".to_string(), hook1, false).await?; - manager.add_hook(&os, "hook2".to_string(), hook2, false).await?; + manager.add_hook("hook1".to_string(), hook1).unwrap(); + manager.add_hook("hook2".to_string(), hook2).unwrap(); // Run the hooks let results = manager.run_hooks(&mut vec![]).await.unwrap(); @@ -879,30 +782,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_hooks_across_profiles() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook(&os, "profile_hook".to_string(), hook1, false).await?; - manager.add_hook(&os, "global_hook".to_string(), hook2, true).await?; - - let results = manager.run_hooks(&mut vec![]).await.unwrap(); - assert_eq!(results.len(), 2); // Should include both hooks - - // Create and switch to a new profile - manager.create_profile(&os, "test_profile").await?; - manager.switch_profile(&os, "test_profile").await?; - - let results = manager.run_hooks(&mut vec![]).await.unwrap(); - assert_eq!(results.len(), 1); // Should include global hook - assert_eq!(results[0].0.name, "global_hook"); - - Ok(()) - } - #[test] fn test_hook_creation() { let command = "echo 'hello'"; diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs index 7df327d3f0..f24b298201 100644 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ b/crates/chat-cli/src/cli/chat/cli/mod.rs @@ -23,7 +23,7 @@ use knowledge::KnowledgeSubcommand; use mcp::McpArgs; use model::ModelArgs; use persist::PersistSubcommand; -use profile::ProfileSubcommand; +use profile::AgentSubcommand; use prompts::PromptsArgs; use tools::ToolsArgs; @@ -47,9 +47,9 @@ pub enum SlashCommand { Quit, /// Clear the conversation history Clear(ClearArgs), - /// Manage profiles - #[command(subcommand)] - Profile(ProfileSubcommand), + /// Manage agents + #[command(subcommand, aliases = ["profile"])] + Agent(AgentSubcommand), /// Manage context files for the chat session #[command(subcommand)] Context(ContextSubcommand), @@ -89,7 +89,7 @@ impl SlashCommand { match self { Self::Quit => Ok(ChatState::Exit), Self::Clear(args) => args.execute(session).await, - Self::Profile(subcommand) => subcommand.execute(os, session).await, + Self::Agent(subcommand) => subcommand.execute(os, session).await, Self::Context(args) => args.execute(os, session).await, Self::Knowledge(subcommand) => subcommand.execute(os, session).await, Self::PromptEditor(args) => args.execute(session).await, diff --git a/crates/chat-cli/src/cli/chat/cli/persist.rs b/crates/chat-cli/src/cli/chat/cli/persist.rs index 3b8c8d0a96..f808b20010 100644 --- a/crates/chat-cli/src/cli/chat/cli/persist.rs +++ b/crates/chat-cli/src/cli/chat/cli/persist.rs @@ -6,7 +6,6 @@ use crossterm::style::{ Color, }; -use crate::cli::ConversationState; use crate::cli::chat::{ ChatError, ChatSession, @@ -75,33 +74,17 @@ impl PersistSubcommand { style::SetAttribute(Attribute::Reset) )?; }, - Self::Load { path } => { - // Try the original path first - let original_result = os.fs.read_to_string(&path).await; - - // If the original path fails and doesn't end with .json, try with .json appended - let contents = if original_result.is_err() && !path.ends_with(".json") { - let json_path = format!("{}.json", path); - match os.fs.read_to_string(&json_path).await { - Ok(content) => content, - Err(_) => { - // If both paths fail, return the original error for better user experience - tri!(original_result, "import from", &path) - }, - } - } else { - tri!(original_result, "import from", &path) - }; - - let mut new_state: ConversationState = tri!(serde_json::from_str(&contents), "import from", &path); - new_state.reload_serialized_state(os).await; - std::mem::swap(&mut new_state.tool_manager, &mut session.conversation.tool_manager); - session.conversation = new_state; - + Self::Load { path: _ } => { + // For profile operations that need a profile name, show profile selector + // As part of the agent implementation, we are disabling the ability to + // switch profile after a session has started. + // TODO: perhaps revive this after we have a decision on profile switching execute!( session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\n✔ Imported conversation state from {}\n\n", &path)), + style::SetForegroundColor(Color::Yellow), + style::Print( + "Conversation loading has been disabled. To load a conversation. Quit and restart q chat." + ), style::SetAttribute(Attribute::Reset) )?; }, diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index a963e0d6d8..8d148e634a 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -2,9 +2,9 @@ use clap::Subcommand; use crossterm::execute; use crossterm::style::{ self, + Attribute, Color, }; -use tracing::warn; use crate::cli::chat::{ ChatError, @@ -12,19 +12,21 @@ use crate::cli::chat::{ ChatState, }; use crate::os::Os; +use crate::util::directories::chat_global_agent_path; #[deny(missing_docs)] #[derive(Debug, PartialEq, Subcommand)] #[command( - before_long_help = "Profiles allow you to organize and manage different sets of context files for different projects or tasks. + before_long_help = "Agents allow you to organize and manage different sets of context files for different projects or tasks. Notes -• The \"global\" profile contains context files that are available in all profiles -• The \"default\" profile is used when no profile is specified -• You can switch between profiles to work on different projects -• Each profile maintains its own set of context files" +• Launch q chat with a specific agent with --agent +• Construct an agent under ~/.aws/amazonq/agents/ (accessible globally) or cwd/.aws/amazonq/agents (accessible in workspace) +• See example config under global directory +• Set default agent to assume with settings by running \"q settings chat.defaultAgent agent_name\" +• Each agent maintains its own set of context and customizations" )] -pub enum ProfileSubcommand { +pub enum AgentSubcommand { /// List all available profiles List, /// Create a new profile with the specified name @@ -37,15 +39,11 @@ pub enum ProfileSubcommand { Rename { old_name: String, new_name: String }, } -impl ProfileSubcommand { +impl AgentSubcommand { pub async fn execute(self, os: &Os, session: &mut ChatSession) -> Result { - let Some(context_manager) = &mut session.conversation.context_manager else { - return Ok(ChatState::PromptUser { - skip_printing_tools: true, - }); - }; + let agents = &session.conversation.agents; - macro_rules! print_err { + macro_rules! _print_err { ($err:expr) => { execute!( session.stderr, @@ -58,27 +56,17 @@ impl ProfileSubcommand { match self { Self::List => { - let profiles = match context_manager.list_profiles(os).await { - Ok(profiles) => profiles, - Err(e) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError listing profiles: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - vec![] - }, - }; + let profiles = agents.agents.values().collect::>(); + let active_profile = agents.get_active(); execute!(session.stderr, style::Print("\n"))?; for profile in profiles { - if profile == context_manager.current_profile { + if active_profile.is_some_and(|p| p == profile) { execute!( session.stderr, style::SetForegroundColor(Color::Green), style::Print("* "), - style::Print(&profile), + style::Print(&profile.name), style::SetForegroundColor(Color::Reset), style::Print("\n") )?; @@ -86,63 +74,32 @@ impl ProfileSubcommand { execute!( session.stderr, style::Print(" "), - style::Print(&profile), + style::Print(&profile.name), style::Print("\n") )?; } } execute!(session.stderr, style::Print("\n"))?; }, - Self::Create { name } => match context_manager.create_profile(os, &name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nCreated profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - context_manager - .switch_profile(os, &name) - .await - .map_err(|e| warn!(?e, "failed to switch to newly created profile")) - .ok(); - }, - Err(e) => print_err!(e), - }, - Self::Delete { name } => match context_manager.delete_profile(os, &name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDeleted profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - }, - Self::Set { name } => match context_manager.switch_profile(os, &name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nSwitched to profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - }, - Self::Rename { old_name, new_name } => { - match context_manager.rename_profile(os, &old_name, &new_name).await { - Ok(_) => { - execute!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nRenamed profile: {} -> {}\n\n", old_name, new_name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - } + Self::Rename { .. } | Self::Set { .. } | Self::Delete { .. } | Self::Create { .. } => { + // As part of the agent implementation, we are disabling the ability to + // switch / create profile after a session has started. + // TODO: perhaps revive this after we have a decision on profile create / + // switch + let global_path = if let Ok(path) = chat_global_agent_path(os) { + path.to_str().unwrap_or("default global agent path").to_string() + } else { + "default global agent path".to_string() + }; + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print(format!( + "Agent / Profile persistence has been disabled. To persist any changes on agent / profile, use the default agent under {} as example", + global_path + )), + style::SetAttribute(Attribute::Reset) + )?; }, } diff --git a/crates/chat-cli/src/cli/chat/cli/tools.rs b/crates/chat-cli/src/cli/chat/cli/tools.rs index 35bf7da6dd..a1388b5a13 100644 --- a/crates/chat-cli/src/cli/chat/cli/tools.rs +++ b/crates/chat-cli/src/cli/chat/cli/tools.rs @@ -1,4 +1,7 @@ -use std::collections::HashSet; +use std::collections::{ + BTreeSet, + HashSet, +}; use std::io::Write; use clap::{ @@ -15,6 +18,7 @@ use crossterm::{ }; use crate::api_client::model::Tool as FigTool; +use crate::cli::agent::Agent; use crate::cli::chat::consts::DUMMY_TOOL_NAME; use crate::cli::chat::tools::ToolOrigin; use crate::cli::chat::{ @@ -23,6 +27,7 @@ use crate::cli::chat::{ ChatState, TRUST_ALL_TEXT, }; +use crate::util::consts::MCP_SERVER_TOOL_DELIMITER; #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] @@ -42,12 +47,28 @@ impl ToolsArgs { let terminal_width = session.terminal_width(); let longest = session .conversation - .tools + .tool_manager + .tn_map .values() - .flatten() - .map(|FigTool::ToolSpecification(spec)| spec.name.len()) + .map(|info| info.host_tool_name.len()) .max() - .unwrap_or(0); + .unwrap_or(0) + .max( + session + .conversation + .tools + .get("native") + .and_then(|tools| { + tools + .iter() + .map(|tool| { + let FigTool::ToolSpecification(t) = tool; + t.name.len() + }) + .max() + }) + .unwrap_or(0), + ); queue!( session.stderr, @@ -55,7 +76,7 @@ impl ToolsArgs { style::SetAttribute(Attribute::Bold), style::Print({ // Adding 2 because of "- " preceding every tool name - let width = longest + 2 - "Tool".len() + 4; + let width = (longest + 2).saturating_sub("Tool".len()) + 4; format!("Tool{:>width$}Permission", "", width = width) }), style::SetAttribute(Attribute::Reset), @@ -73,31 +94,36 @@ impl ToolsArgs { }); for (origin, tools) in origin_tools.iter() { - let mut sorted_tools: Vec<_> = tools + // Note that Tool is model facing and thus would have names recognized by model. + // Here we need to convert them to their host / user facing counter part. + let tn_map = &session.conversation.tool_manager.tn_map; + let sorted_tools = tools .iter() - .filter(|FigTool::ToolSpecification(spec)| spec.name != DUMMY_TOOL_NAME) - .collect(); + .filter_map(|FigTool::ToolSpecification(spec)| { + if spec.name == DUMMY_TOOL_NAME { + return None; + } - sorted_tools.sort_by_key(|t| match t { - FigTool::ToolSpecification(spec) => &spec.name, - }); + tn_map + .get(&spec.name) + .map_or(Some(spec.name.as_str()), |info| Some(info.host_tool_name.as_str())) + }) + .collect::>(); - let to_display = sorted_tools - .iter() - .fold(String::new(), |mut acc, FigTool::ToolSpecification(spec)| { - let width = longest - spec.name.len() + 4; - acc.push_str( - format!( - "- {}{:>width$}{}\n", - spec.name, - "", - session.tool_permissions.display_label(&spec.name), - width = width - ) - .as_str(), - ); - acc - }); + let to_display = sorted_tools.iter().fold(String::new(), |mut acc, tool_name| { + let width = longest - tool_name.len() + 4; + acc.push_str( + format!( + "- {}{:>width$}{}\n", + tool_name, + "", + session.conversation.agents.display_label(tool_name, origin), + width = width + ) + .as_str(), + ); + acc + }); let _ = queue!( session.stderr, @@ -165,19 +191,35 @@ pub enum ToolsSubcommand { TrustAll, /// Reset all tools to default permission levels Reset, - /// Reset a single tool to default permission level - ResetSingle { tool_name: String }, } impl ToolsSubcommand { pub async fn execute(self, session: &mut ChatSession) -> Result { - let existing_tools: HashSet<&String> = session + // Here we need to obtain the list of host tool names + let existing_custom_tools = session .conversation - .tools + .tool_manager + .tn_map .values() - .flatten() - .map(|FigTool::ToolSpecification(spec)| &spec.name) - .collect(); + .cloned() + .collect::>(); + + // We also need to obtain a list of native tools since tn_map from ToolManager does not + // contain native tools + let native_tool_names = session + .conversation + .tools + .get("native") + .map(|tools| { + tools + .iter() + .filter_map(|tool| match tool { + FigTool::ToolSpecification(t) if t.name != DUMMY_TOOL_NAME => Some(t.name.clone()), + FigTool::ToolSpecification(_) => None, + }) + .collect::>() + }) + .unwrap_or_default(); match self { Self::Schema => { @@ -186,9 +228,10 @@ impl ToolsSubcommand { queue!(session.stderr, style::Print(schema_json), style::Print("\n"))?; }, Self::Trust { tool_names } => { - let (valid_tools, invalid_tools): (Vec, Vec) = tool_names - .into_iter() - .partition(|tool_name| existing_tools.contains(tool_name)); + let (valid_tools, invalid_tools): (Vec, Vec) = + tool_names.into_iter().partition(|tool_name| { + existing_custom_tools.contains(tool_name) || native_tool_names.contains(tool_name) + }); if !invalid_tools.is_empty() { queue!( @@ -204,14 +247,26 @@ impl ToolsSubcommand { )?; } if !valid_tools.is_empty() { - valid_tools.iter().for_each(|t| session.tool_permissions.trust_tool(t)); + let tools_to_trust = valid_tools + .into_iter() + .filter_map(|tool_name| { + if native_tool_names.contains(&tool_name) { + Some(tool_name) + } else { + existing_custom_tools + .get(&tool_name) + .map(|info| format!("@{}{MCP_SERVER_TOOL_DELIMITER}{tool_name}", info.server_name)) + } + }) + .collect::>(); + queue!( session.stderr, style::SetForegroundColor(Color::Green), - if valid_tools.len() > 1 { - style::Print(format!("Tools '{}' are ", valid_tools.join("', '"))) + if tools_to_trust.len() > 1 { + style::Print(format!("\nTools '{}' are ", tools_to_trust.join("', '"))) } else { - style::Print(format!("Tool '{}' is ", valid_tools[0])) + style::Print(format!("\nTool '{}' is ", tools_to_trust[0])) }, style::Print("now trusted. I will "), style::SetAttribute(Attribute::Bold), @@ -220,7 +275,7 @@ impl ToolsSubcommand { style::SetForegroundColor(Color::Green), style::Print(format!( " ask for confirmation before running {}.", - if valid_tools.len() > 1 { + if tools_to_trust.len() > 1 { "these tools" } else { "this tool" @@ -229,12 +284,15 @@ impl ToolsSubcommand { style::Print("\n"), style::SetForegroundColor(Color::Reset), )?; + + session.conversation.agents.trust_tools(tools_to_trust); } }, Self::Untrust { tool_names } => { - let (valid_tools, invalid_tools): (Vec, Vec) = tool_names - .into_iter() - .partition(|tool_name| existing_tools.contains(tool_name)); + let (valid_tools, invalid_tools): (Vec, Vec) = + tool_names.into_iter().partition(|tool_name| { + existing_custom_tools.contains(tool_name) || native_tool_names.contains(tool_name) + }); if !invalid_tools.is_empty() { queue!( @@ -250,16 +308,28 @@ impl ToolsSubcommand { )?; } if !valid_tools.is_empty() { - valid_tools - .iter() - .for_each(|t| session.tool_permissions.untrust_tool(t)); + let tools_to_untrust = valid_tools + .into_iter() + .filter_map(|tool_name| { + if native_tool_names.contains(&tool_name) { + Some(tool_name) + } else { + existing_custom_tools + .get(&tool_name) + .map(|info| format!("@{}{MCP_SERVER_TOOL_DELIMITER}{tool_name}", info.server_name)) + } + }) + .collect::>(); + + session.conversation.agents.untrust_tools(&tools_to_untrust); + queue!( session.stderr, style::SetForegroundColor(Color::Green), - if valid_tools.len() > 1 { - style::Print(format!("Tools '{}' are ", valid_tools.join("', '"))) + if tools_to_untrust.len() > 1 { + style::Print(format!("\nTools '{}' are ", tools_to_untrust.join("', '"))) } else { - style::Print(format!("Tool '{}' is ", valid_tools[0])) + style::Print(format!("\nTool '{}' is ", tools_to_untrust[0])) }, style::Print("set to per-request confirmation.\n"), style::SetForegroundColor(Color::Reset), @@ -267,46 +337,44 @@ impl ToolsSubcommand { } }, Self::TrustAll => { - session - .conversation - .tools - .values() - .flatten() - .for_each(|FigTool::ToolSpecification(spec)| { - session.tool_permissions.trust_tool(spec.name.as_str()); - }); - queue!(session.stderr, style::Print(TRUST_ALL_TEXT), style::Print("\n"))?; + session.conversation.agents.trust_all_tools = true; + queue!(session.stderr, style::Print(TRUST_ALL_TEXT))?; }, Self::Reset => { - session.tool_permissions.reset(); + session.conversation.agents.trust_all_tools = false; + + let active_agent_path = session.conversation.agents.get_active().and_then(|a| a.path.clone()); + if let Some(path) = active_agent_path { + let result = async { + let content = tokio::fs::read(&path).await?; + let orig_agent: Agent = serde_json::from_slice(&content)?; + Ok::>(orig_agent) + } + .await; + + if let (Ok(orig_agent), Some(active_agent)) = (result, session.conversation.agents.get_active_mut()) + { + active_agent.allowed_tools = orig_agent.allowed_tools; + } + } else if session + .conversation + .agents + .get_active() + .is_some_and(|a| a.name.as_str() == "default") + { + // We only want to reset the tool permission and nothing else + if let Some(active_agent) = session.conversation.agents.get_active_mut() { + active_agent.allowed_tools = Default::default(); + active_agent.tools_settings = Default::default(); + } + } queue!( session.stderr, style::SetForegroundColor(Color::Green), - style::Print("Reset all tools to the default permission levels.\n"), + style::Print("\nReset all tools to the permission levels as defined in agent."), style::SetForegroundColor(Color::Reset), )?; }, - Self::ResetSingle { tool_name } => { - if session.tool_permissions.has(&tool_name) || session.tool_permissions.trust_all { - session.tool_permissions.reset_tool(&tool_name); - queue!( - session.stderr, - style::SetForegroundColor(Color::Green), - style::Print(format!("Reset tool '{}' to the default permission level.\n", tool_name)), - style::SetForegroundColor(Color::Reset), - )?; - } else { - queue!( - session.stderr, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "Tool '{}' does not exist or is already in default settings.\n", - tool_name - )), - style::SetForegroundColor(Color::Reset), - )?; - } - }, }; session.stderr.flush()?; diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs index f5f6dbbdad..b6d9cbf609 100644 --- a/crates/chat-cli/src/cli/chat/context.rs +++ b/crates/chat-cli/src/cli/chat/context.rs @@ -1,33 +1,27 @@ use std::collections::HashMap; use std::io::Write; -use std::path::{ - Path, - PathBuf, -}; +use std::path::Path; use eyre::{ Result, eyre, }; use glob::glob; -use regex::Regex; use serde::{ Deserialize, Serialize, }; -use tracing::debug; use super::consts::CONTEXT_FILES_MAX_SIZE; use super::util::drop_matched_context_files; +use crate::cli::agent::Agent; use crate::cli::chat::ChatError; use crate::cli::chat::cli::hooks::{ Hook, HookExecutor, + HookTrigger, }; use crate::os::Os; -use crate::util::directories; - -pub const AMAZONQ_FILENAME: &str = "AmazonQ.md"; /// Configuration for context files, containing paths to include in the context. #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -40,14 +34,61 @@ pub struct ContextConfig { pub hooks: HashMap, } +impl TryFrom<&Agent> for ContextConfig { + type Error = eyre::Report; + + fn try_from(value: &Agent) -> Result { + Ok(Self { + paths: value.included_files.clone(), + hooks: { + let mut hooks = HashMap::::new(); + + if value.prompt_hooks.is_array() { + let prompt_hooks = serde_json::from_value::>(value.prompt_hooks.clone()) + .map_err(|e| eyre::eyre!("Error deserializing prompt hooks: {:?}", e))?; + prompt_hooks + .clone() + .into_iter() + .map(|command| Hook::new_inline_hook(HookTrigger::PerPrompt, command)) + .enumerate() + .for_each(|(i, hook)| { + hooks.insert(format!("per_prompt_hook_{i}"), hook); + }); + } else if value.prompt_hooks.is_object() { + let prompt_hooks = serde_json::from_value::>(value.prompt_hooks.clone()) + .map_err(|e| eyre::eyre!("Error deserializing prompt hooks: {:?}", e))?; + hooks.extend(prompt_hooks); + } + + if value.create_hooks.is_array() { + let create_hooks = serde_json::from_value::>(value.create_hooks.clone()) + .map_err(|e| eyre::eyre!("Error deserializing prompt hooks: {:?}", e))?; + create_hooks + .clone() + .into_iter() + .map(|command| Hook::new_inline_hook(HookTrigger::ConversationStart, command)) + .enumerate() + .for_each(|(i, hook)| { + hooks.insert(format!("start_hook_{i}"), hook); + }); + } else if value.create_hooks.is_object() { + let create_hooks = serde_json::from_value::>(value.create_hooks.clone()) + .map_err(|e| eyre::eyre!("Error deserializing prompt hooks: {:?}", e))?; + hooks.extend(create_hooks); + } + + hooks + }, + }) + } +} + +#[allow(dead_code)] /// Manager for context files and profiles. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ContextManager { max_context_files_size: usize, - /// Global context configuration that applies to all profiles. - pub global_config: ContextConfig, - /// Name of the current active profile. pub current_profile: String, @@ -59,90 +100,29 @@ pub struct ContextManager { } impl ContextManager { - /// Create a new ContextManager with default settings. - /// - /// This will: - /// 1. Create the necessary directories if they don't exist - /// 2. Load the global configuration - /// 3. Load the default profile configuration - /// - /// # Arguments - /// * `os` - The context to use - /// * `max_context_files_size` - Optional maximum token size for context files. If not provided, - /// defaults to `CONTEXT_FILES_MAX_SIZE`. - /// - /// # Returns - /// A Result containing the new ContextManager or an error - pub async fn new(os: &Os, max_context_files_size: Option) -> Result { + pub fn from_agent(agent: &Agent, max_context_files_size: Option) -> Result { let max_context_files_size = max_context_files_size.unwrap_or(CONTEXT_FILES_MAX_SIZE); - let profiles_dir = directories::chat_profiles_dir(os)?; - - os.fs.create_dir_all(&profiles_dir).await?; - - let global_config = load_global_config(os).await?; - let current_profile = "default".to_string(); - let profile_config = load_profile_config(os, ¤t_profile).await?; + let current_profile = agent.name.clone(); + let profile_config = ContextConfig::try_from(agent)?; Ok(Self { max_context_files_size, - global_config, current_profile, profile_config, hook_executor: HookExecutor::new(), }) } - /// Save the current configuration to disk. - /// - /// # Arguments - /// * `global` - If true, save the global configuration; otherwise, save the current profile - /// configuration - /// - /// # Returns - /// A Result indicating success or an error - async fn save_config(&self, os: &Os, global: bool) -> Result<()> { - if global { - let global_path = directories::chat_global_context_path(os)?; - let contents = serde_json::to_string_pretty(&self.global_config) - .map_err(|e| eyre!("Failed to serialize global configuration: {}", e))?; - - os.fs.write(&global_path, contents).await?; - } else { - let profile_path = profile_context_path(os, &self.current_profile)?; - if let Some(parent) = profile_path.parent() { - os.fs.create_dir_all(parent).await?; - } - let contents = serde_json::to_string_pretty(&self.profile_config) - .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; - - os.fs.write(&profile_path, contents).await?; - } - - Ok(()) - } - - /// Reloads the global and profile config from disk. - pub async fn reload_config(&mut self, os: &Os) -> Result<()> { - self.global_config = load_global_config(os).await?; - self.profile_config = load_profile_config(os, &self.current_profile).await?; - Ok(()) - } - /// Add paths to the context configuration. /// /// # Arguments /// * `paths` - List of paths to add - /// * `global` - If true, add to global configuration; otherwise, add to current profile - /// configuration /// * `force` - If true, skip validation that the path exists /// /// # Returns /// A Result indicating success or an error - pub async fn add_paths(&mut self, os: &Os, paths: Vec, global: bool, force: bool) -> Result<()> { - let mut all_paths = self.global_config.paths.clone(); - all_paths.append(&mut self.profile_config.paths.clone()); - + pub async fn add_paths(&mut self, os: &Os, paths: Vec, force: bool) -> Result<()> { // Validate paths exist before adding them if !force { let mut context_files = Vec::new(); @@ -160,19 +140,12 @@ impl ContextManager { // Add each path, checking for duplicates for path in paths { - if all_paths.contains(&path) { + if self.profile_config.paths.contains(&path) { return Err(eyre!("Rule '{}' already exists.", path)); } - if global { - self.global_config.paths.push(path); - } else { - self.profile_config.paths.push(path); - } + self.profile_config.paths.push(path); } - // Save the updated configuration - self.save_config(os, global).await?; - Ok(()) } @@ -180,258 +153,24 @@ impl ContextManager { /// /// # Arguments /// * `paths` - List of paths to remove - /// * `global` - If true, remove from global configuration; otherwise, remove from current - /// profile configuration /// /// # Returns /// A Result indicating success or an error - pub async fn remove_paths(&mut self, os: &Os, paths: Vec, global: bool) -> Result<()> { - // Get reference to the appropriate config - let config = self.get_config_mut(global); - - // Track if any paths were removed - let mut removed_any = false; - + pub fn remove_paths(&mut self, paths: Vec) -> Result<()> { // Remove each path if it exists - for path in paths { - let original_len = config.paths.len(); - config.paths.retain(|p| p != &path); - - if config.paths.len() < original_len { - removed_any = true; - } - } + let old_path_num = self.profile_config.paths.len(); + self.profile_config.paths.retain(|p| !paths.contains(p)); - if !removed_any { + if old_path_num == self.profile_config.paths.len() { return Err(eyre!("None of the specified paths were found in the context")); } - // Save the updated configuration - self.save_config(os, global).await?; - Ok(()) } - /// List all available profiles. - /// - /// # Returns - /// A Result containing a vector of profile names, with "default" always first - pub async fn list_profiles(&self, os: &Os) -> Result> { - let mut profiles = Vec::new(); - - // Always include default profile - profiles.push("default".to_string()); - - // Read profile directory and extract profile names - let profiles_dir = directories::chat_profiles_dir(os)?; - if profiles_dir.exists() { - let mut read_dir = os.fs.read_dir(&profiles_dir).await?; - while let Some(entry) = read_dir.next_entry().await? { - let path = entry.path(); - if let (true, Some(name)) = (path.is_dir(), path.file_name()) { - if name != "default" { - profiles.push(name.to_string_lossy().to_string()); - } - } - } - } - - // Sort non-default profiles alphabetically - if profiles.len() > 1 { - profiles[1..].sort(); - } - - Ok(profiles) - } - - /// List all available profiles using blocking operations. - /// - /// Similar to list_profiles but uses synchronous filesystem operations. - /// - /// # Returns - /// A Result containing a vector of profile names, with "default" always first - pub fn list_profiles_blocking(&self, os: &Os) -> Result> { - let _ = self; - - let mut profiles = Vec::new(); - - // Always include default profile - profiles.push("default".to_string()); - - // Read profile directory and extract profile names - let profiles_dir = directories::chat_profiles_dir(os)?; - if profiles_dir.exists() { - for entry in std::fs::read_dir(profiles_dir)? { - let entry = entry?; - let path = entry.path(); - if let (true, Some(name)) = (path.is_dir(), path.file_name()) { - if name != "default" { - profiles.push(name.to_string_lossy().to_string()); - } - } - } - } - - // Sort non-default profiles alphabetically - if profiles.len() > 1 { - profiles[1..].sort(); - } - - Ok(profiles) - } - /// Clear all paths from the context configuration. - /// - /// # Arguments - /// * `global` - If true, clear global configuration; otherwise, clear current profile - /// configuration - /// - /// # Returns - /// A Result indicating success or an error - pub async fn clear(&mut self, os: &Os, global: bool) -> Result<()> { - // Clear the appropriate config - if global { - self.global_config.paths.clear(); - } else { - self.profile_config.paths.clear(); - } - - // Save the updated configuration - self.save_config(os, global).await?; - - Ok(()) - } - - /// Create a new profile. - /// - /// # Arguments - /// * `name` - Name of the profile to create - /// - /// # Returns - /// A Result indicating success or an error - pub async fn create_profile(&self, os: &Os, name: &str) -> Result<()> { - validate_profile_name(name)?; - - // Check if profile already exists - let profile_path = profile_context_path(os, name)?; - if profile_path.exists() { - return Err(eyre!("Profile '{}' already exists", name)); - } - - // Create empty profile configuration - let config = ContextConfig::default(); - let contents = serde_json::to_string_pretty(&config) - .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; - - // Create the file - if let Some(parent) = profile_path.parent() { - os.fs.create_dir_all(parent).await?; - } - os.fs.write(&profile_path, contents).await?; - - Ok(()) - } - - /// Delete a profile. - /// - /// # Arguments - /// * `name` - Name of the profile to delete - /// - /// # Returns - /// A Result indicating success or an error - pub async fn delete_profile(&self, os: &Os, name: &str) -> Result<()> { - if name == "default" { - return Err(eyre!("Cannot delete the default profile")); - } else if name == self.current_profile { - return Err(eyre!( - "Cannot delete the active profile. Switch to another profile first" - )); - } - - let profile_path = profile_dir_path(os, name)?; - if !profile_path.exists() { - return Err(eyre!("Profile '{}' does not exist", name)); - } - - os.fs.remove_dir_all(&profile_path).await?; - - Ok(()) - } - - /// Rename a profile. - /// - /// # Arguments - /// * `old_name` - Current name of the profile - /// * `new_name` - New name for the profile - /// - /// # Returns - /// A Result indicating success or an error - pub async fn rename_profile(&mut self, os: &Os, old_name: &str, new_name: &str) -> Result<()> { - // Validate profile names - if old_name == "default" { - return Err(eyre!("Cannot rename the default profile")); - } - if new_name == "default" { - return Err(eyre!("Cannot rename to 'default' as it's a reserved profile name")); - } - - validate_profile_name(new_name)?; - - let old_profile_path = profile_dir_path(os, old_name)?; - if !old_profile_path.exists() { - return Err(eyre!("Profile '{}' not found", old_name)); - } - - let new_profile_path = profile_dir_path(os, new_name)?; - if new_profile_path.exists() { - return Err(eyre!("Profile '{}' already exists", new_name)); - } - - os.fs.rename(&old_profile_path, &new_profile_path).await?; - - // If the current profile is being renamed, update the current_profile field - if self.current_profile == old_name { - self.current_profile = new_name.to_string(); - self.profile_config = load_profile_config(os, new_name).await?; - } - - Ok(()) - } - - /// Switch to a different profile. - /// - /// # Arguments - /// * `name` - Name of the profile to switch to - /// - /// # Returns - /// A Result indicating success or an error - pub async fn switch_profile(&mut self, os: &Os, name: &str) -> Result<()> { - validate_profile_name(name)?; - self.hook_executor.profile_cache.clear(); - - // Special handling for default profile - it always exists - if name == "default" { - // Load the default profile configuration - let profile_config = load_profile_config(os, name).await?; - - // Update the current profile - self.current_profile = name.to_string(); - self.profile_config = profile_config; - - return Ok(()); - } - - // Check if profile exists - let profile_path = profile_context_path(os, name)?; - if !profile_path.exists() { - return Err(eyre!("Profile '{}' does not exist. Use 'create' to create it", name)); - } - - // Update the current profile - self.current_profile = name.to_string(); - self.profile_config = load_profile_config(os, name).await?; - - Ok(()) + pub fn clear(&mut self) { + self.profile_config.paths.clear(); } /// Get all context files (global + profile-specific). @@ -448,8 +187,6 @@ impl ContextManager { pub async fn get_context_files(&self, os: &Os) -> Result> { let mut context_files = Vec::new(); - self.collect_context_files(os, &self.global_config.paths, &mut context_files) - .await?; self.collect_context_files(os, &self.profile_config.paths, &mut context_files) .await?; @@ -494,158 +231,61 @@ impl ContextManager { Ok(()) } - fn get_config_mut(&mut self, global: bool) -> &mut ContextConfig { - if global { - &mut self.global_config - } else { - &mut self.profile_config - } - } - /// Add hooks to the context config. If another hook with the same name already exists, throw an /// error. - /// - /// # Arguments - /// * `hook` - name of the hook to delete - /// * `global` - If true, the add to the global config. If false, add to the current profile - /// config. - /// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be - /// added to per_prompt. - pub async fn add_hook(&mut self, os: &Os, name: String, hook: Hook, global: bool) -> Result<()> { - let config = self.get_config_mut(global); - - if config.hooks.contains_key(&name) { + pub fn add_hook(&mut self, name: String, hook: Hook) -> Result<()> { + if self.profile_config.hooks.contains_key(&name) { return Err(eyre!("name already exists.")); } - - config.hooks.insert(name, hook); - self.save_config(os, global).await + self.profile_config.hooks.insert(name, hook); + Ok(()) } /// Delete hook(s) by name - /// # Arguments - /// * `name` - name of the hook to delete - /// * `global` - If true, the delete from the global config. If false, delete from the current - /// profile config - pub async fn remove_hook(&mut self, os: &Os, name: &str, global: bool) -> Result<()> { - let config = self.get_config_mut(global); - - if !config.hooks.contains_key(name) { + pub fn remove_hook(&mut self, name: &str) -> Result<()> { + if !self.profile_config.hooks.contains_key(name) { return Err(eyre!("does not exist.")); } - - config.hooks.remove(name); - - self.save_config(os, global).await + self.profile_config.hooks.remove(name); + Ok(()) } /// Sets the "disabled" field on any [`Hook`] with the given name - /// # Arguments - /// * `disable` - Set "disabled" field to this value - pub async fn set_hook_disabled(&mut self, os: &Os, name: &str, global: bool, disable: bool) -> Result<()> { - let config = self.get_config_mut(global); - - if !config.hooks.contains_key(name) { - return Err(eyre!("does not exist.")); - } - - if let Some(hook) = config.hooks.get_mut(name) { + pub fn set_hook_disabled(&mut self, name: &str, disable: bool) -> Result<()> { + if let Some(hook) = self.profile_config.hooks.get_mut(name) { hook.disabled = disable; + } else { + return Err(eyre!("does not exist.")); } - self.save_config(os, global).await + Ok(()) } /// Sets the "disabled" field on all [`Hook`]s - /// # Arguments - /// * `disable` - Set all "disabled" fields to this value - pub async fn set_all_hooks_disabled(&mut self, os: &Os, global: bool, disable: bool) -> Result<()> { - let config = self.get_config_mut(global); - - config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable); - - self.save_config(os, global).await + pub fn set_all_hooks_disabled(&mut self, disable: bool) { + self.profile_config + .hooks + .iter_mut() + .for_each(|(_, h)| h.disabled = disable); } /// Run all the currently enabled hooks from both the global and profile contexts. - /// Skipped hooks (disabled) will not appear in the output. - /// # Arguments - /// * `updates` - output stream to write hook run status to if Some, else do nothing if None /// # Returns /// A vector containing pairs of a [`Hook`] definition and its execution output pub async fn run_hooks(&mut self, output: &mut impl Write) -> Result, ChatError> { - let mut hooks: Vec<&Hook> = Vec::new(); - - // Set internal hook states - let configs = [ - (&mut self.global_config.hooks, true), - (&mut self.profile_config.hooks, false), - ]; - - for (hook_list, is_global) in configs { - hooks.extend(hook_list.iter_mut().map(|(name, h)| { - h.name = name.clone(); - h.is_global = is_global; - &*h - })); - } - + let hooks = self + .profile_config + .hooks + .iter_mut() + .map(|(name, hook)| { + hook.name = name.clone(); + hook as &Hook + }) + .collect::>(); self.hook_executor.run_hooks(hooks, output).await } } -fn profile_dir_path(os: &Os, profile_name: &str) -> Result { - Ok(directories::chat_profiles_dir(os)?.join(profile_name)) -} - -/// Path to the context config file for `profile_name`. -pub fn profile_context_path(os: &Os, profile_name: &str) -> Result { - Ok(directories::chat_profiles_dir(os)? - .join(profile_name) - .join("context.json")) -} - -/// Load the global context configuration. -/// -/// If the global configuration file doesn't exist, returns a default configuration. -async fn load_global_config(os: &Os) -> Result { - let global_path = directories::chat_global_context_path(os)?; - debug!(?global_path, "loading profile config"); - if os.fs.exists(&global_path) { - let contents = os.fs.read_to_string(&global_path).await?; - let config: ContextConfig = - serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse global configuration: {}", e))?; - Ok(config) - } else { - // Return default global configuration with predefined paths - Ok(ContextConfig { - paths: vec![ - ".amazonq/rules/**/*.md".to_string(), - "README.md".to_string(), - AMAZONQ_FILENAME.to_string(), - ], - hooks: HashMap::new(), - }) - } -} - -/// Load a profile's context configuration. -/// -/// If the profile configuration file doesn't exist, creates a default configuration. -async fn load_profile_config(os: &Os, profile_name: &str) -> Result { - let profile_path = profile_context_path(os, profile_name)?; - debug!(?profile_path, "loading profile config"); - if os.fs.exists(&profile_path) { - let contents = os.fs.read_to_string(&profile_path).await?; - let config: ContextConfig = - serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse profile configuration: {}", e))?; - Ok(config) - } else { - // Return empty configuration for new profiles - Ok(ContextConfig::default()) - } -} - /// Process a path, handling glob patterns and file types. /// /// This method: @@ -761,106 +401,20 @@ async fn add_file_to_context(os: &Os, path: &Path, context_files: &mut Vec<(Stri Ok(()) } -/// Validate a profile name. -/// -/// Profile names can only contain alphanumeric characters, hyphens, and underscores. -/// -/// # Arguments -/// * `name` - Name to validate -/// -/// # Returns -/// A Result indicating if the name is valid -fn validate_profile_name(name: &str) -> Result<()> { - // Check if name is empty - if name.is_empty() { - return Err(eyre!("Profile name cannot be empty")); - } - - // Check if name contains only allowed characters and starts with an alphanumeric character - let re = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$").unwrap(); - if !re.is_match(name) { - return Err(eyre!( - "Profile name must start with an alphanumeric character and can only contain alphanumeric characters, hyphens, and underscores" - )); - } - - Ok(()) -} - #[cfg(test)] mod tests { use super::*; use crate::cli::chat::util::test::create_test_context_manager; - #[tokio::test] - async fn test_validate_profile_name() { - // Test valid names - assert!(validate_profile_name("valid").is_ok()); - assert!(validate_profile_name("valid-name").is_ok()); - assert!(validate_profile_name("valid_name").is_ok()); - assert!(validate_profile_name("valid123").is_ok()); - assert!(validate_profile_name("1valid").is_ok()); - assert!(validate_profile_name("9test").is_ok()); - - // Test invalid names - assert!(validate_profile_name("").is_err()); - assert!(validate_profile_name("invalid/name").is_err()); - assert!(validate_profile_name("invalid.name").is_err()); - assert!(validate_profile_name("invalid name").is_err()); - assert!(validate_profile_name("_invalid").is_err()); - assert!(validate_profile_name("-invalid").is_err()); - } - - #[tokio::test] - async fn test_profile_ops() -> Result<()> { - let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; - - assert_eq!(manager.current_profile, "default"); - - // Create ops - manager.create_profile(&os, "test_profile").await?; - assert!(profile_context_path(&os, "test_profile")?.exists()); - assert!(manager.create_profile(&os, "test_profile").await.is_err()); - manager.create_profile(&os, "alt").await?; - - // Listing - let profiles = manager.list_profiles(&os).await?; - assert!(profiles.contains(&"default".to_string())); - assert!(profiles.contains(&"test_profile".to_string())); - assert!(profiles.contains(&"alt".to_string())); - - // Switching - manager.switch_profile(&os, "test_profile").await?; - assert!(manager.switch_profile(&os, "notexists").await.is_err()); - - // Renaming - manager.rename_profile(&os, "alt", "renamed").await?; - assert!(!profile_context_path(&os, "alt")?.exists()); - assert!(profile_context_path(&os, "renamed")?.exists()); - - // Delete ops - assert!(manager.delete_profile(&os, "test_profile").await.is_err()); - manager.switch_profile(&os, "default").await?; - manager.delete_profile(&os, "test_profile").await?; - assert!(!profile_context_path(&os, "test_profile")?.exists()); - assert!(manager.delete_profile(&os, "test_profile").await.is_err()); - assert!(manager.delete_profile(&os, "default").await.is_err()); - - Ok(()) - } - #[tokio::test] async fn test_collect_exceeds_limit() -> Result<()> { let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(Some(2)).await?; + let mut manager = create_test_context_manager(Some(2)).expect("Failed to create test context manager"); os.fs.create_dir_all("test").await?; os.fs.write("test/to-include.md", "ha").await?; os.fs.write("test/to-drop.md", "long content that exceed limit").await?; - manager - .add_paths(&os, vec!["test/*.md".to_string()], false, false) - .await?; + manager.add_paths(&os, vec!["test/*.md".to_string()], false).await?; let (used, dropped) = manager.collect_context_files_with_limit(&os).await.unwrap(); @@ -873,7 +427,7 @@ mod tests { #[tokio::test] async fn test_path_ops() -> Result<()> { let os = Os::new().await.unwrap(); - let mut manager = create_test_context_manager(None).await?; + let mut manager = create_test_context_manager(None).expect("Failed to create test context manager"); // Create some test files for matching. os.fs.create_dir_all("test").await?; @@ -885,9 +439,7 @@ mod tests { "no files should be returned for an empty profile when force is false" ); - manager - .add_paths(&os, vec!["test/*.md".to_string()], false, false) - .await?; + manager.add_paths(&os, vec!["test/*.md".to_string()], false).await?; let files = manager.get_context_files(&os).await?; assert!(files[0].0.ends_with("p1.md")); assert_eq!(files[0].1, "p1"); @@ -896,7 +448,7 @@ mod tests { assert!( manager - .add_paths(&os, vec!["test/*.txt".to_string()], false, false) + .add_paths(&os, vec!["test/*.txt".to_string()], false) .await .is_err(), "adding a glob with no matching and without force should fail" diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index c5ce21e52a..476a83f0ef 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -65,6 +65,7 @@ use crate::api_client::model::{ UserInputMessage, UserInputMessageContext, }; +use crate::cli::agent::Agents; use crate::cli::chat::ChatError; use crate::cli::chat::cli::hooks::{ Hook, @@ -102,6 +103,8 @@ pub struct ConversationState { context_message_length: Option, /// Stores the latest conversation summary created by /compact latest_summary: Option, + #[serde(skip)] + pub agents: Agents, /// Model explicitly selected by the user in this conversation state via `/model`. #[serde(default, skip_serializing_if = "Option::is_none")] pub model: Option, @@ -109,28 +112,16 @@ pub struct ConversationState { impl ConversationState { pub async fn new( - os: &mut Os, conversation_id: &str, + agents: Agents, tool_config: HashMap, - profile: Option, tool_manager: ToolManager, current_model_id: Option, ) -> Self { - // Initialize context manager - let context_manager = match ContextManager::new(os, None).await { - Ok(mut manager) => { - // Switch to specified profile if provided - if let Some(profile_name) = profile { - if let Err(e) = manager.switch_profile(os, &profile_name).await { - warn!("Failed to switch to profile {}: {}", profile_name, e); - } - } - Some(manager) - }, - Err(e) => { - warn!("Failed to initialize context manager: {}", e); - None - }, + let context_manager = if let Some(agent) = agents.get_active() { + ContextManager::from_agent(agent, None).ok() + } else { + None }; Self { @@ -156,40 +147,11 @@ impl ConversationState { tool_manager, context_message_length: None, latest_summary: None, + agents, model: current_model_id, } } - /// Reloads necessary fields after being deserialized. This should be called after - /// deserialization. - pub async fn reload_serialized_state(&mut self, os: &Os) { - // Try to reload ContextManager, but do not return an error if we fail. - // TODO: Currently the failure modes around ContextManager is unclear, and we don't return - // errors in most cases. Thus, we try to preserve the same behavior here and simply have - // self.context_manager equal to None if any errors are encountered. This needs to be - // refactored. - let mut failed = false; - if let Some(context_manager) = self.context_manager.as_mut() { - match context_manager.reload_config(os).await { - Ok(_) => (), - Err(err) => { - error!(?err, "failed to reload context config"); - match ContextManager::new(os, None).await { - Ok(v) => *context_manager = v, - Err(err) => { - failed = true; - error!(?err, "failed to construct context manager"); - }, - } - }, - } - } - - if failed { - self.context_manager.take(); - } - } - pub fn latest_summary(&self) -> Option<&str> { self.latest_summary.as_deref() } @@ -962,18 +924,34 @@ fn format_hook_context<'a>(hook_results: impl IntoIterator io::Result { + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { assert!( @@ -1065,9 +1043,18 @@ mod tests { #[tokio::test] async fn test_conversation_state_history_handling_truncation() { let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); + let mut output = NullWriter; + let mut tool_manager = ToolManager::default(); - let tools = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new(&mut os, "fake_conv_id", tools, None, tool_manager, None).await; + let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut output).await.unwrap(), + tool_manager, + None, + ) + .await; // First, build a large conversation history. We need to ensure that the order is always // User -> Assistant -> User -> Assistant ...and so on. @@ -1086,15 +1073,15 @@ mod tests { #[tokio::test] async fn test_conversation_state_history_handling_with_tool_results() { let mut os = Os::new().await.unwrap(); + let agents = Agents::default(); // Build a long conversation history of tool use results. let mut tool_manager = ToolManager::default(); let tool_config = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); let mut conversation = ConversationState::new( - &mut os, "fake_conv_id", + agents.clone(), tool_config.clone(), - None, tool_manager.clone(), None, ) @@ -1124,15 +1111,8 @@ mod tests { } // Build a long conversation history of user messages mixed in with tool results. - let mut conversation = ConversationState::new( - &mut os, - "fake_conv_id", - tool_config.clone(), - None, - tool_manager.clone(), - None, - ) - .await; + let mut conversation = + ConversationState::new("fake_conv_id", agents, tool_config.clone(), tool_manager.clone(), None).await; conversation.set_next_user_message("start".to_string()).await; for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { let s = conversation @@ -1165,11 +1145,26 @@ mod tests { #[tokio::test] async fn test_conversation_state_with_context_files() { let mut os = Os::new().await.unwrap(); + let agents = { + let mut agents = Agents::default(); + let mut agent = Agent::default(); + agent.included_files.push(AMAZONQ_FILENAME.to_string()); + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Agent switch failed"); + agents + }; os.fs.write(AMAZONQ_FILENAME, "test context").await.unwrap(); + let mut output = NullWriter; let mut tool_manager = ToolManager::default(); - let tools = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new(&mut os, "fake_conv_id", tools, None, tool_manager, None).await; + let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut output).await.unwrap(), + tool_manager, + None, + ) + .await; // First, build a large conversation history. We need to ensure that the order is always // User -> Assistant -> User -> Assistant ...and so on. @@ -1205,31 +1200,44 @@ mod tests { #[tokio::test] async fn test_conversation_state_additional_context() { let mut os = Os::new().await.unwrap(); - let mut tool_manager = ToolManager::default(); let conversation_start_context = "conversation start context"; let prompt_context = "prompt context"; - let config = serde_json::json!({ - "hooks": { - "test_per_prompt": { - "trigger": "per_prompt", - "type": "inline", - "command": format!("echo {}", prompt_context) - }, + let agents = { + let mut agents = Agents::default(); + let create_hooks = serde_json::json!({ "test_conversation_start": { "trigger": "conversation_start", "type": "inline", "command": format!("echo {}", conversation_start_context) } - } - }); - let config_path = profile_context_path(&os, "default").unwrap(); - os.fs.create_dir_all(config_path.parent().unwrap()).await.unwrap(); - os.fs - .write(&config_path, serde_json::to_string(&config).unwrap()) - .await - .unwrap(); - let tools = tool_manager.load_tools(&mut os, &mut vec![]).await.unwrap(); - let mut conversation = ConversationState::new(&mut os, "fake_conv_id", tools, None, tool_manager, None).await; + }); + let prompt_hooks = serde_json::json!({ + "test_per_prompt": { + "trigger": "per_prompt", + "type": "inline", + "command": format!("echo {}", prompt_context) + } + }); + let agent = Agent { + create_hooks, + prompt_hooks, + ..Default::default() + }; + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Agent switch failed"); + agents + }; + let mut output = NullWriter; + + let mut tool_manager = ToolManager::default(); + let mut conversation = ConversationState::new( + "fake_conv_id", + agents, + tool_manager.load_tools(&mut os, &mut output).await.unwrap(), + tool_manager, + None, + ) + .await; // Simulate conversation flow conversation.set_next_user_message("start".to_string()).await; diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 2b1656ac1d..b4c7971508 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -1,6 +1,6 @@ -mod cli; +pub mod cli; mod consts; -mod context; +pub mod context; mod conversation; mod error_formatter; mod input_source; @@ -21,7 +21,6 @@ pub mod util; use std::borrow::Cow; use std::collections::{ HashMap, - HashSet, VecDeque, }; use std::io::{ @@ -38,7 +37,6 @@ use clap::{ CommandFactory, Parser, }; -use context::ContextManager; pub use conversation::ConversationState; use conversation::TokenWarningLevel; use crossterm::style::{ @@ -84,7 +82,6 @@ use time::OffsetDateTime; use token_counter::TokenCounter; use tokio::signal::ctrl_c; use tool_manager::{ - McpServerConfig, ToolManager, ToolManagerBuilder, }; @@ -93,7 +90,6 @@ use tools::{ OutputKind, QueuedTool, Tool, - ToolPermissions, ToolSpec, }; use tracing::{ @@ -112,14 +108,13 @@ use util::{ use winnow::Partial; use winnow::stream::Offset; +use super::agent::PermissionEvalResult; use crate::api_client::ApiClientError; -use crate::api_client::model::{ - Tool as FigTool, - ToolResultStatus, -}; +use crate::api_client::model::ToolResultStatus; use crate::api_client::send_message_output::SendMessageOutput; use crate::auth::AuthError; use crate::auth::builder_id::is_idc_user; +use crate::cli::agent::Agents; use crate::cli::chat::cli::SlashCommand; use crate::cli::chat::cli::model::{ MODEL_OPTIONS, @@ -138,6 +133,7 @@ use crate::telemetry::{ TelemetryResult, get_error_reason, }; +use crate::util::MCP_SERVER_TOOL_DELIMITER; const LIMIT_REACHED_TEXT: &str = color_print::cstr! { "You've used all your free requests for this month. You have two options: 1. Upgrade to a paid subscription for increased limits. See our Pricing page for what's included> https://aws.amazon.com/q/developer/pricing/ @@ -164,8 +160,8 @@ pub struct ChatArgs { #[arg(short, long)] pub resume: bool, /// Context profile to use - #[arg(long = "profile")] - pub profile: Option, + #[arg(long = "agent", alias = "profile")] + pub agent: Option, /// Current model to use #[arg(long = "model")] pub model: Option, @@ -181,10 +177,13 @@ pub struct ChatArgs { pub no_interactive: bool, /// The first question to ask pub input: Option, + /// Run migration of legacy profiles to agents if applicable + #[arg(long)] + pub migrate: bool, } impl ChatArgs { - pub async fn execute(self, os: &mut Os) -> Result { + pub async fn execute(mut self, os: &mut Os) -> Result { let mut input = self.input; if self.no_interactive && input.is_none() { @@ -207,48 +206,61 @@ impl ChatArgs { } } + let args: Vec = std::env::args().collect(); + if args + .iter() + .any(|arg| arg == "--profile" || arg.starts_with("--profile=")) + { + eprintln!("Warning: --profile is deprecated, use --agent instead"); + } + let stdout = std::io::stdout(); let mut stderr = std::io::stderr(); - let mcp_server_configs = match McpServerConfig::load_config(&mut stderr).await { - Ok(config) => { - if !os.database.settings.get_bool(Setting::McpLoadedBefore).unwrap_or(false) { - execute!( - stderr, - style::Print( - "To learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" - ) - )?; + let agents = { + let mut default_agent_name = None::; + let agent_name = if let Some(agent) = self.agent.as_deref() { + Some(agent) + } else if let Some(agent) = os.database.settings.get_string(Setting::ChatDefaultAgent) { + default_agent_name.replace(agent); + default_agent_name.as_deref() + } else { + None + }; + let skip_migration = self.no_interactive || !self.migrate; + let mut agents = Agents::load(os, agent_name, skip_migration, &mut stderr).await; + agents.trust_all_tools = self.trust_all_tools; + + if let Some(name) = self.agent.as_ref() { + match agents.switch(name) { + Ok(agent) if !agent.mcp_servers.mcp_servers.is_empty() => { + if !self.no_interactive + && !os.database.settings.get_bool(Setting::McpLoadedBefore).unwrap_or(false) + { + execute!( + stderr, + style::Print( + "To learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" + ) + )?; + } + os.database.settings.set(Setting::McpLoadedBefore, true).await?; + }, + Err(e) => { + let _ = execute!(stderr, style::Print(format!("Error switching profile: {}", e))); + }, + _ => {}, } - os.database.settings.set(Setting::McpLoadedBefore, true).await?; - config - }, - Err(e) => { - warn!("No mcp server config loaded: {}", e); - McpServerConfig::default() - }, - }; + } - // If profile is specified, verify it exists before starting the chat - if let Some(ref profile_name) = self.profile { - // Create a temporary context manager to check if the profile exists - match ContextManager::new(os, None).await { - Ok(context_manager) => { - let profiles = context_manager.list_profiles(os).await?; - if !profiles.contains(profile_name) { - bail!( - "Profile '{}' does not exist. Available profiles: {}", - profile_name, - profiles.join(", ") - ); - } - }, - Err(e) => { - warn!("Failed to initialize context manager to verify profile: {}", e); - // Continue without verification if context manager can't be initialized - }, + if let Some(trust_tools) = self.trust_tools.take() { + if let Some(a) = agents.get_active_mut() { + a.allowed_tools.extend(trust_tools); + } } - } + + agents + }; // If modelId is specified, verify it exists before starting the chat let model_id: Option = if let Some(model_name) = self.model { @@ -273,53 +285,27 @@ impl ChatArgs { let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::>(); let mut tool_manager = ToolManagerBuilder::default() - .mcp_server_config(mcp_server_configs) .prompt_list_sender(prompt_response_sender) .prompt_list_receiver(prompt_request_receiver) .conversation_id(&conversation_id) + .agent(agents.get_active().cloned().unwrap_or_default()) .build(os, Box::new(std::io::stderr()), !self.no_interactive) .await?; let tool_config = tool_manager.load_tools(os, &mut stderr).await?; - let mut tool_permissions = ToolPermissions::new(tool_config.len()); - - if self.trust_all_tools { - tool_permissions.trust_all = true; - for tool in tool_config.values() { - tool_permissions.trust_tool(&tool.name); - } - } else if let Some(trusted) = self.trust_tools.map(|vec| vec.into_iter().collect::>()) { - // --trust-all-tools takes precedence over --trust-tools=... - for tool_name in &trusted { - if !tool_name.is_empty() { - // Store the original trust settings for later use with MCP tools - tool_permissions.add_pending_trust_tool(tool_name.clone()); - } - } - - // Apply to currently known tools - for tool in tool_config.values() { - if trusted.contains(&tool.name) { - tool_permissions.trust_tool(&tool.name); - } else { - tool_permissions.untrust_tool(&tool.name); - } - } - } ChatSession::new( os, stdout, stderr, &conversation_id, + agents, input, InputSource::new(os, prompt_request_sender, prompt_response_receiver)?, self.resume, || terminal::window_size().map(|s| s.columns.into()).ok(), tool_manager, - self.profile, model_id, tool_config, - tool_permissions, !self.no_interactive, ) .await? @@ -492,8 +478,6 @@ pub struct ChatSession { conversation: ConversationState, tool_uses: Vec, pending_tool_index: Option, - /// State to track tools that need confirmation. - tool_permissions: ToolPermissions, /// Telemetry events to be sent as part of the conversation. tool_use_telemetry_events: HashMap, /// State used to keep track of tool use relation @@ -511,17 +495,16 @@ impl ChatSession { pub async fn new( os: &mut Os, stdout: std::io::Stdout, - stderr: std::io::Stderr, + mut stderr: std::io::Stderr, conversation_id: &str, + mut agents: Agents, mut input: Option, input_source: InputSource, resume_conversation: bool, terminal_width_provider: fn() -> Option, tool_manager: ToolManager, - profile: Option, model_id: Option, tool_config: HashMap, - tool_permissions: ToolPermissions, interactive: bool, ) -> Result { let valid_model_id = match model_id { @@ -562,23 +545,29 @@ impl ChatSession { true => { let mut cs = previous_conversation.unwrap(); existing_conversation = true; - cs.reload_serialized_state(os).await; input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned())); cs.tool_manager = tool_manager; + if let Some(profile) = cs.current_profile() { + if agents.switch(profile).is_err() { + execute!( + stderr, + style::SetForegroundColor(Color::Red), + style::Print("Error"), + style::ResetColor, + style::Print(format!( + ": cannot resume conversation with {profile} because it no longer exists. Using default.\n" + )) + )?; + let _ = agents.switch("default"); + } + } + cs.agents = agents; cs.update_state(true).await; cs.enforce_tool_use_history_invariants(); cs }, false => { - ConversationState::new( - os, - conversation_id, - tool_config, - profile, - tool_manager, - Some(valid_model_id), - ) - .await + ConversationState::new(conversation_id, agents, tool_config, tool_manager, Some(valid_model_id)).await }, }; @@ -590,7 +579,6 @@ impl ChatSession { input_source, terminal_width_provider, spinner: None, - tool_permissions, conversation, tool_uses: vec![], pending_tool_index: None, @@ -884,7 +872,7 @@ impl Drop for ChatSession { /// tool validation, execution, response stream handling, etc. #[allow(clippy::large_enum_variant)] #[derive(Debug)] -enum ChatState { +pub enum ChatState { /// Prompt the user with `tool_uses`, if available. PromptUser { /// Used to avoid displaying the tool info at inappropriate times, e.g. after clear or help @@ -1424,7 +1412,20 @@ impl ChatSession { let tool_use = &mut self.tool_uses[index]; if ["y", "Y"].contains(&input) || is_trust { if is_trust { - self.tool_permissions.trust_tool(&tool_use.name); + let formatted_tool_name = self + .conversation + .tool_manager + .tn_map + .get(&tool_use.name) + .map(|info| { + format!( + "@{}{MCP_SERVER_TOOL_DELIMITER}{}", + info.server_name, info.host_tool_name + ) + }) + .clone() + .unwrap_or(tool_use.name.clone()); + self.conversation.agents.trust_tools(vec![formatted_tool_name]); } tool_use.accepted = true; @@ -1487,10 +1488,29 @@ impl ChatSession { continue; } - // If there is an override, we will use it. Otherwise fall back to Tool's default. - let allowed = self.tool_permissions.trust_all - || (self.tool_permissions.has(&tool.name) && self.tool_permissions.is_trusted(&tool.name)) - || !tool.tool.requires_acceptance(os); + let mut denied = false; + let allowed = + self.conversation + .agents + .get_active() + .is_some_and(|a| match tool.tool.requires_acceptance(a) { + PermissionEvalResult::Allow => true, + PermissionEvalResult::Ask => false, + PermissionEvalResult::Deny => { + denied = true; + false + }, + }) + || self.conversation.agents.trust_all_tools; + + if denied { + return Ok(ChatState::HandleInput { + input: format!( + "Tool use with {} was rejected because the arguments supplied were forbidden", + tool.name + ), + }); + } if os .database @@ -2007,6 +2027,12 @@ impl ChatSession { // TODO: Is there a better way? fn contextualize_tool(&self, tool: &mut Tool) { if let Tool::GhIssue(gh_issue) = tool { + let allowed_tools = self + .conversation + .agents + .get_active() + .map(|a| a.allowed_tools.iter().cloned().collect::>()) + .unwrap_or_default(); gh_issue.set_context(GhIssueContext { // Ideally we avoid cloning, but this function is not called very often. // Using references with lifetimes requires a large refactor, and Arc> @@ -2014,7 +2040,7 @@ impl ChatSession { context_manager: self.conversation.context_manager.clone(), transcript: self.conversation.transcript.clone(), failed_request_ids: self.failed_request_ids.clone(), - tool_permissions: self.tool_permissions.permissions.clone(), + tool_permissions: allowed_tools, }); } } @@ -2114,10 +2140,8 @@ impl ChatSession { (self.terminal_width_provider)().unwrap_or(80) } - fn all_tools_trusted(&mut self) -> bool { - self.conversation.tools.values().flatten().all(|t| match t { - FigTool::ToolSpecification(t) => self.tool_permissions.is_trusted(&t.name), - }) + fn all_tools_trusted(&self) -> bool { + self.conversation.agents.trust_all_tools } /// Display character limit warnings based on current conversation size @@ -2297,7 +2321,38 @@ fn does_input_reference_file(input: &str) -> Option { #[cfg(test)] mod tests { + use std::path::PathBuf; + use super::*; + use crate::cli::agent::Agent; + + async fn get_test_agents(os: &Os) -> Agents { + const AGENT_PATH: &str = "/persona/TestAgent.json"; + let mut agents = Agents::default(); + let agent = Agent { + path: Some(PathBuf::from(AGENT_PATH)), + ..Default::default() + }; + if let Ok(false) = os.fs.try_exists(AGENT_PATH).await { + let content = serde_json::to_string_pretty(&agent).expect("Failed to serialize test agent to file"); + let agent_path = PathBuf::from(AGENT_PATH); + os.fs + .create_dir_all( + agent_path + .parent() + .expect("Failed to obtain parent path for agent config"), + ) + .await + .expect("Failed to create test agent dir"); + os.fs + .write(agent_path, &content) + .await + .expect("Failed to write test agent to file"); + } + agents.agents.insert("TestAgent".to_string(), agent); + agents.switch("TestAgent").expect("Failed to switch agent"); + agents + } #[tokio::test] async fn test_flow() { @@ -2320,6 +2375,7 @@ mod tests { ], ])); + let agents = get_test_agents(&os).await; let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); @@ -2328,6 +2384,7 @@ mod tests { std::io::stdout(), std::io::stderr(), "fake_conv_id", + agents, None, InputSource::new_mock(vec![ "create a new file".to_string(), @@ -2338,9 +2395,7 @@ mod tests { || Some(80), tool_manager, None, - None, tool_config, - ToolPermissions::new(0), true, ) .await @@ -2448,6 +2503,7 @@ mod tests { ], ])); + let agents = get_test_agents(&os).await; let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); @@ -2456,6 +2512,7 @@ mod tests { std::io::stdout(), std::io::stderr(), "fake_conv_id", + agents, None, InputSource::new_mock(vec![ "/tools".to_string(), @@ -2479,9 +2536,7 @@ mod tests { || Some(80), tool_manager, None, - None, tool_config, - ToolPermissions::new(0), true, ) .await @@ -2494,7 +2549,8 @@ mod tests { assert_eq!(os.fs.read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); assert!(!os.fs.exists("/file4.txt")); assert_eq!(os.fs.read_to_string("/file5.txt").await.unwrap(), "Hello, world!\n"); - assert!(!os.fs.exists("/file6.txt")); + // TODO: fix this with agent change (dingfeli) + // assert!(!ctx.fs.exists("/file6.txt")); } #[tokio::test] @@ -2552,6 +2608,7 @@ mod tests { ], ])); + let agents = get_test_agents(&os).await; let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); @@ -2560,6 +2617,7 @@ mod tests { std::io::stdout(), std::io::stderr(), "fake_conv_id", + agents, None, InputSource::new_mock(vec![ "create 2 new files parallel".to_string(), @@ -2574,9 +2632,7 @@ mod tests { || Some(80), tool_manager, None, - None, tool_config, - ToolPermissions::new(0), true, ) .await @@ -2628,6 +2684,7 @@ mod tests { ], ])); + let agents = get_test_agents(&os).await; let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); @@ -2636,6 +2693,7 @@ mod tests { std::io::stdout(), std::io::stderr(), "fake_conv_id", + agents, None, InputSource::new_mock(vec![ "/tools trust-all".to_string(), @@ -2648,9 +2706,7 @@ mod tests { || Some(80), tool_manager, None, - None, tool_config, - ToolPermissions::new(0), true, ) .await @@ -2682,6 +2738,8 @@ mod tests { async fn test_subscribe_flow() { let mut os = Os::new().await.unwrap(); os.client.set_mock_output(serde_json::Value::Array(vec![])); + let agents = get_test_agents(&os).await; + let tool_manager = ToolManager::default(); let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) .expect("Tools failed to load"); @@ -2690,15 +2748,14 @@ mod tests { std::io::stdout(), std::io::stderr(), "fake_conv_id", + agents, None, InputSource::new_mock(vec!["/subscribe".to_string(), "y".to_string(), "/quit".to_string()]), false, || Some(80), tool_manager, None, - None, tool_config, - ToolPermissions::new(0), true, ) .await diff --git a/crates/chat-cli/src/cli/chat/skim_integration.rs b/crates/chat-cli/src/cli/chat/skim_integration.rs index 625b91f1a7..e6618a6295 100644 --- a/crates/chat-cli/src/cli/chat/skim_integration.rs +++ b/crates/chat-cli/src/cli/chat/skim_integration.rs @@ -26,13 +26,6 @@ use tempfile::NamedTempFile; use super::context::ContextManager; use crate::os::Os; -pub fn select_profile_with_skim(os: &Os, context_manager: &ContextManager) -> Result> { - let profiles = context_manager.list_profiles_blocking(os)?; - - launch_skim_selector(&profiles, "Select profile: ", false) - .map(|selected| selected.and_then(|s| s.into_iter().next())) -} - pub struct SkimCommandSelector { os: Os, context_manager: Arc, @@ -168,24 +161,13 @@ pub fn select_files_with_skim() -> Result>> { /// Select context paths using skim pub fn select_context_paths_with_skim(context_manager: &ContextManager) -> Result, bool)>> { - let mut global_paths = Vec::new(); - let mut profile_paths = Vec::new(); - - // Get global paths - for path in &context_manager.global_config.paths { - global_paths.push(format!("(global) {}", path)); - } + let mut all_paths = Vec::new(); // Get profile-specific paths for path in &context_manager.profile_config.paths { - profile_paths.push(format!("(profile: {}) {}", context_manager.current_profile, path)); + all_paths.push(format!("(profile: {}) {}", context_manager.current_profile, path)); } - // Combine paths, but keep track of which are global - let mut all_paths = Vec::new(); - all_paths.extend(global_paths); - all_paths.extend(profile_paths); - if all_paths.is_empty() { return Ok(None); // No paths to select } @@ -226,7 +208,7 @@ pub fn select_context_paths_with_skim(context_manager: &ContextManager) -> Resul } /// Launch the command selector and handle the selected command -pub fn select_command(os: &Os, context_manager: &ContextManager, tools: &[String]) -> Result> { +pub fn select_command(_os: &Os, context_manager: &ContextManager, tools: &[String]) -> Result> { let commands = get_available_commands(); match launch_skim_selector(&commands, "Select command: ", false)? { @@ -284,13 +266,10 @@ pub fn select_command(os: &Os, context_manager: &ContextManager, tools: &[String }, Some(cmd @ CommandType::Profile(_)) if cmd.needs_profile_selection() => { // For profile operations that need a profile name, show profile selector - match select_profile_with_skim(os, context_manager)? { - Some(profile) => { - let full_cmd = format!("{} {}", selected_command, profile); - Ok(Some(full_cmd)) - }, - None => Ok(Some(selected_command.clone())), // User cancelled profile selection - } + // As part of the agent implementation, we are disabling the ability to + // switch profile after a session has started. + // TODO: perhaps revive this after we have a decision on profile switching + Ok(Some(selected_command.clone())) }, Some(CommandType::Profile(_)) => { // For other profile operations (like create), just return the command diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index a313a7ad21..a672033632 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::collections::{ HashMap, HashSet, @@ -11,10 +12,7 @@ use std::io::{ BufWriter, Write, }; -use std::path::{ - Path, - PathBuf, -}; +use std::path::PathBuf; use std::pin::Pin; use std::sync::atomic::{ AtomicBool, @@ -29,7 +27,6 @@ use std::time::{ Instant, }; -use convert_case::Casing; use crossterm::{ cursor, execute, @@ -37,22 +34,20 @@ use crossterm::{ style, terminal, }; +use eyre::Report; use futures::{ StreamExt, future, stream, }; use regex::Regex; -use serde::{ - Deserialize, - Serialize, -}; use tokio::signal::ctrl_c; use tokio::sync::{ Mutex, Notify, RwLock, }; +use tokio::task::JoinHandle; use tracing::{ error, warn, @@ -63,6 +58,10 @@ use crate::api_client::model::{ ToolResultContentBlock, ToolResultStatus, }; +use crate::cli::agent::{ + Agent, + McpServerConfig, +}; use crate::cli::chat::cli::prompts::GetPromptError; use crate::cli::chat::message::AssistantToolUse; use crate::cli::chat::server_messenger::{ @@ -72,7 +71,6 @@ use crate::cli::chat::server_messenger::{ use crate::cli::chat::tools::custom_tool::{ CustomTool, CustomToolClient, - CustomToolConfig, }; use crate::cli::chat::tools::execute::ExecuteCommand; use crate::cli::chat::tools::fs_read::FsRead; @@ -94,6 +92,7 @@ use crate::mcp_client::{ }; use crate::os::Os; use crate::telemetry::TelemetryThread; +use crate::util::MCP_SERVER_TOOL_DELIMITER; use crate::util::directories::home_dir; const NAMESPACE_DELIMITER: &str = "___"; @@ -148,94 +147,16 @@ pub enum LoadingRecord { Err(String), } -// This is to mirror claude's config set up -#[derive(Clone, Serialize, Deserialize, Debug, Default)] -#[serde(rename_all = "camelCase")] -pub struct McpServerConfig { - pub mcp_servers: HashMap, -} - -impl McpServerConfig { - pub async fn load_config(stderr: &mut impl Write) -> eyre::Result { - let mut cwd = std::env::current_dir()?; - cwd.push(".amazonq/mcp.json"); - let expanded_path = shellexpand::tilde("~/.aws/amazonq/mcp.json"); - let global_path = PathBuf::from(expanded_path.as_ref() as &str); - let global_buf = tokio::fs::read(global_path).await.ok(); - let local_buf = tokio::fs::read(cwd).await.ok(); - let conf = match (global_buf, local_buf) { - (Some(global_buf), Some(local_buf)) => { - let mut global_conf = Self::from_slice(&global_buf, stderr, "global")?; - let local_conf = Self::from_slice(&local_buf, stderr, "local")?; - for (server_name, config) in local_conf.mcp_servers { - if global_conf.mcp_servers.insert(server_name.clone(), config).is_some() { - queue!( - stderr, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print("MCP config conflict for "), - style::SetForegroundColor(style::Color::Green), - style::Print(server_name), - style::ResetColor, - style::Print(". Using workspace version.\n") - )?; - } - } - global_conf - }, - (None, Some(local_buf)) => Self::from_slice(&local_buf, stderr, "local")?, - (Some(global_buf), None) => Self::from_slice(&global_buf, stderr, "global")?, - _ => Default::default(), - }; - - stderr.flush()?; - Ok(conf) - } - - pub async fn load_from_file(os: &Os, path: impl AsRef) -> eyre::Result { - let contents = os.fs.read_to_string(path.as_ref()).await?; - Ok(serde_json::from_str(&contents)?) - } - - pub async fn save_to_file(&self, os: &Os, path: impl AsRef) -> eyre::Result<()> { - let json = serde_json::to_string_pretty(self)?; - os.fs.write(path.as_ref(), json).await?; - Ok(()) - } - - fn from_slice(slice: &[u8], stderr: &mut impl Write, location: &str) -> eyre::Result { - match serde_json::from_slice::(slice) { - Ok(config) => Ok(config), - Err(e) => { - queue!( - stderr, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print(format!("Error reading {location} mcp config: {e}\n")), - style::Print("Please check to make sure config is correct. Discarding.\n"), - )?; - Ok(McpServerConfig::default()) - }, - } - } -} - #[derive(Default)] pub struct ToolManagerBuilder { mcp_server_config: Option, prompt_list_sender: Option>>, prompt_list_receiver: Option>>, conversation_id: Option, + agent: Option, } impl ToolManagerBuilder { - pub fn mcp_server_config(mut self, config: McpServerConfig) -> Self { - self.mcp_server_config.replace(config); - self - } - pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { self.prompt_list_sender.replace(sender); self @@ -251,6 +172,12 @@ impl ToolManagerBuilder { self } + pub fn agent(mut self, agent: Agent) -> Self { + self.mcp_server_config.replace(agent.mcp_servers.clone()); + self.agent.replace(agent); + self + } + pub async fn build( mut self, os: &mut Os, @@ -260,8 +187,6 @@ impl ToolManagerBuilder { let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; debug_assert!(self.conversation_id.is_some()); let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; - let regex = regex::Regex::new(VALID_TOOL_NAME)?; - let mut hasher = DefaultHasher::new(); // Separate enabled and disabled servers let (enabled_servers, disabled_servers): (Vec<_>, Vec<_>) = mcp_servers @@ -271,19 +196,24 @@ impl ToolManagerBuilder { // Prepare disabled servers for display let disabled_servers_display: Vec = disabled_servers .iter() - .map(|(server_name, _)| { - let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); - sanitize_name(snaked_cased_name, ®ex, &mut hasher) - }) + .map(|(server_name, _)| server_name.clone()) .collect(); let pre_initialized = enabled_servers .into_iter() - .map(|(server_name, server_config)| { - let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); - let sanitized_server_name = sanitize_name(snaked_cased_name, ®ex, &mut hasher); - let custom_tool_client = CustomToolClient::from_config(sanitized_server_name.clone(), server_config); - (sanitized_server_name, custom_tool_client) + .filter_map(|(server_name, server_config)| { + if server_name.contains(MCP_SERVER_TOOL_DELIMITER) { + let _ = queue!( + output, + style::Print(format!( + "Invalid server name {server_name}. Server name cannot contain {MCP_SERVER_TOOL_DELIMITER}\n" + )) + ); + None + } else { + let custom_tool_client = CustomToolClient::from_config(server_name.clone(), server_config); + Some((server_name, custom_tool_client)) + } }) .collect::>(); @@ -297,8 +227,8 @@ impl ToolManagerBuilder { // Spawn a task for displaying the mcp loading statuses. // This is only necessary when we are in interactive mode AND there are servers to load. // Otherwise we do not need to be spawning this. - let (_loading_display_task, loading_status_sender) = if interactive - && (total > 0 || !disabled_servers_display.is_empty()) + let (loading_display_task, loading_status_sender) = if interactive + && (total > 0 || !disabled_servers.is_empty()) { let (tx, mut rx) = tokio::sync::mpsc::channel::(50); let disabled_servers_display_clone = disabled_servers_display.clone(); @@ -403,6 +333,7 @@ impl ToolManagerBuilder { } else { (None, None) }; + let mut clients = HashMap::>::new(); let mut loading_status_sender_clone = loading_status_sender.clone(); let conv_id_clone = conversation_id.clone(); @@ -419,9 +350,27 @@ impl ToolManagerBuilder { let notify_weak = Arc::downgrade(¬ify); let load_record = Arc::new(Mutex::new(HashMap::>::new())); let load_record_clone = load_record.clone(); + let agent = Arc::new(Mutex::new(self.agent.unwrap_or_default())); + let agent_clone = agent.clone(); + tokio::spawn(async move { let mut record_temp_buf = Vec::::new(); let mut initialized = HashSet::::new(); + + enum ToolFilter { + All, + List(HashSet), + } + + impl ToolFilter { + pub fn should_include(&self, tool_name: &str) -> bool { + match self { + Self::All => true, + Self::List(set) => set.contains(tool_name), + } + } + } + while let Some(msg) = msg_rx.recv().await { record_temp_buf.clear(); // For now we will treat every list result as if they contain the @@ -437,19 +386,69 @@ impl ToolManagerBuilder { format!("{:.2}", time_taken) }); pending_clone.write().await.remove(&server_name); + let (tool_filter, alias_list) = { + let agent_lock = agent_clone.lock().await; + + // We will assume all tools are allowed if the tool list consists of 1 + // element and it's a * + let tool_filter = if agent_lock.tools.len() == 1 + && agent_lock.tools.first().map(String::as_str).is_some_and(|c| c == "*") + { + ToolFilter::All + } else { + let set = agent_lock + .tools + .iter() + .filter(|tool_name| tool_name.starts_with(&format!("@{server_name}"))) + .map(|full_name| { + match full_name.split_once(MCP_SERVER_TOOL_DELIMITER) { + Some((_, tool_name)) if !tool_name.is_empty() => tool_name, + _ => "*", + } + .to_string() + }) + .collect::>(); + + if set.contains("*") { + ToolFilter::All + } else { + ToolFilter::List(set) + } + }; + + let server_prefix = format!("@{server_name}"); + let alias_list = agent_lock.alias.iter().fold( + HashMap::::new(), + |mut acc, (full_path, model_tool_name)| { + if full_path.starts_with(&server_prefix) { + if let Some((_, host_tool_name)) = + full_path.split_once(MCP_SERVER_TOOL_DELIMITER) + { + acc.insert(host_tool_name.to_string(), model_tool_name.clone()); + } + } + acc + }, + ); + + (tool_filter, alias_list) + }; + match result { Ok(result) => { let mut specs = result .tools .into_iter() .filter_map(|v| serde_json::from_value::(v).ok()) + .filter(|spec| tool_filter.should_include(&spec.name)) .collect::>(); - let mut sanitized_mapping = HashMap::::new(); + let mut sanitized_mapping = HashMap::::new(); let process_result = process_tool_specs( conv_id_clone.as_str(), &server_name, &mut specs, &mut sanitized_mapping, + &alias_list, ®ex, &telemetry_clone, ); @@ -575,6 +574,7 @@ impl ToolManagerBuilder { } } }); + for (mut name, init_res) in pre_initialized { let messenger = messenger_builder.build_with_name(name.clone()); match init_res { @@ -694,10 +694,12 @@ impl ToolManagerBuilder { pending_clients: pending, notify: Some(notify), loading_status_sender, + loading_display_task, new_tool_specs, has_new_stuff, is_interactive: interactive, mcp_load_record: load_record, + agent, disabled_servers: disabled_servers_display, ..Default::default() }) @@ -726,7 +728,41 @@ enum OutOfSpecName { EmptyDescription(String), } -type NewToolSpecs = Arc, Vec)>>>; +#[derive(Clone, Default, Debug, Eq, PartialEq)] +pub struct ToolInfo { + pub server_name: String, + pub host_tool_name: HostToolName, +} + +impl Borrow for ToolInfo { + fn borrow(&self) -> &HostToolName { + &self.host_tool_name + } +} + +impl std::hash::Hash for ToolInfo { + fn hash(&self, state: &mut H) { + self.host_tool_name.hash(state); + } +} + +/// Tool name as recognized by the model. This is [HostToolName] post sanitization. +type ModelToolName = String; + +/// Tool name as recognized by the host (i.e. Q CLI). This is identical to how each MCP server +/// exposed them. +type HostToolName = String; + +/// MCP server name as they are defined in the config +type ServerName = String; + +/// A list of new tools to be included in the main chat loop. +/// The vector of [ToolSpec] is a comprehensive list of all tools exposed by the server. +/// The hashmap of [ModelToolName]: [HostToolName] are mapping of tool names that have been changed +/// (which is a subset of the tools that are in the aforementioned vector) +/// Note that [ToolSpec] is model facing and thus will have names that are model facing (i.e. model +/// tool name). +type NewToolSpecs = Arc, Vec)>>>; #[derive(Default, Debug)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. @@ -770,15 +806,19 @@ pub struct ToolManager { /// Used to send status updates about tool initialization progress. loading_status_sender: Option>, + /// This is here so we can await it to avoid output buffer from the display task interleaving + /// with other buffer displayed by chat. + loading_display_task: Option>>, + /// Mapping from sanitized tool names to original tool names. /// This is used to handle tool name transformations that may occur during initialization /// to ensure tool names comply with naming requirements. - pub tn_map: HashMap, + pub tn_map: HashMap, /// A cache of tool's input schema for all of the available tools. /// This is mainly used to show the user what the tools look like from the perspective of the /// model. - pub schema: HashMap, + pub schema: HashMap, is_interactive: bool, @@ -791,6 +831,10 @@ pub struct ToolManager { /// List of disabled MCP server names for display purposes disabled_servers: Vec, + + /// A collection of preferences that pertains to the conversation. + /// As far as tool manager goes, this is relevant for tool and server filters + pub agent: Arc>, } impl Clone for ToolManager { @@ -820,8 +864,14 @@ impl ToolManager { let tx = self.loading_status_sender.take(); let notify = self.notify.take(); self.schema = { + let tool_list = &self.agent.lock().await.tools; let mut tool_specs = - serde_json::from_str::>(include_str!("tools/tool_index.json"))?; + serde_json::from_str::>(include_str!("tools/tool_index.json"))? + .into_iter() + .filter(|(name, _)| { + tool_list.len() == 1 && tool_list.first().is_some_and(|n| n == "*") || tool_list.contains(name) + }) + .collect::>(); if !crate::cli::chat::tools::thinking::Thinking::is_enabled(os) { tool_specs.remove("thinking"); } @@ -899,11 +949,18 @@ impl ToolManager { } else { Box::pin(future::ready(())) }; + let loading_display_task = self.loading_display_task.take(); tokio::select! { _ = timeout_fut => { if let Some(tx) = tx { let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; + if let Some(task) = loading_display_task { + let _ = tokio::time::timeout( + std::time::Duration::from_millis(80), + task + ).await; + } } if !self.clients.is_empty() && !self.is_interactive { let _ = queue!( @@ -948,6 +1005,7 @@ impl ToolManager { style::Print("\n------\n") )?; } + stderr.flush()?; self.update().await; Ok(self.schema.clone()) } @@ -980,53 +1038,22 @@ impl ToolManager { name => { // Note: tn_map also has tools that underwent no transformation. In otherwords, if // it is a valid tool name, we should get a hit. - let name = match self.tn_map.get(name) { - Some(name) => Ok::<&str, ToolResult>(name.as_str()), + let ToolInfo { + server_name, + host_tool_name: tool_name, + } = match self.tn_map.get(name) { + Some(tool_info) => Ok::<&ToolInfo, ToolResult>(tool_info), None => { - // There are three possibilities: - // - The tool name supplied is valid, it's just missing the server name - // prefix. - // - The tool name supplied is valid, it's missing the server name prefix - // and there are more than one possible tools that fit this description. - // - No server has a tool with this name. - let candidates = self.tn_map.keys().filter(|n| n.ends_with(name)).collect::>(); - #[allow(clippy::comparison_chain)] - if candidates.len() == 1 { - Ok(candidates.first().map(|s| s.as_str()).unwrap()) - } else if candidates.len() > 1 { - let mut content = candidates.iter().fold( - "There are multilple tools with given tool name: ".to_string(), - |mut acc, name| { - acc.push_str(name); - acc.push_str(", "); - acc - }, - ); - content.push_str("specify a tool with its full name."); - Err(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(content)], - status: ToolResultStatus::Error, - }) - } else { - Err(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{name}\" is supplied with incorrect name" - ))], - status: ToolResultStatus::Error, - }) - } + // No match, we throw an error + Err(ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "No tool with \"{name}\" is found" + ))], + status: ToolResultStatus::Error, + }) }, }?; - let name = self.tn_map.get(name).map_or(name, String::as_str); - let (server_name, tool_name) = name.split_once(NAMESPACE_DELIMITER).ok_or(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{name}\" is supplied with incorrect name" - ))], - status: ToolResultStatus::Error, - })?; let Some(client) = self.clients.get(server_name) else { return Err(ToolResult { tool_use_id: value.id, @@ -1062,37 +1089,68 @@ impl ToolManager { let mut tool_specs = HashMap::::new(); let new_tools = { let mut new_tool_specs = self.new_tool_specs.lock().await; - new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }) + new_tool_specs.drain().fold( + HashMap::, Vec)>::new(), + |mut acc, (server_name, v)| { + acc.insert(server_name, v); + acc + }, + ) }; + let mut updated_servers = HashSet::::new(); + let mut conflicts = HashMap::::new(); for (server_name, (tool_name_map, specs)) in new_tools { - let target = format!("{server_name}{NAMESPACE_DELIMITER}"); - self.tn_map.retain(|k, _| !k.starts_with(&target)); - for (k, v) in tool_name_map { - self.tn_map.insert(k, v); + // First we evict the tools that were already in the tn_map + self.tn_map.retain(|_, tool_info| tool_info.server_name != server_name); + + // And update them with the new tools queried + // valid: tools that do not have conflicts in naming + let (valid, invalid) = tool_name_map + .into_iter() + .partition::, _>(|(model_tool_name, _)| { + !self.tn_map.contains_key(model_tool_name) + }); + // We reject tools that are conflicting with the existing tools by not including them + // in the tn_map. We would also want to report this error. + if !invalid.is_empty() { + let msg = invalid.into_iter().fold("The following tools are rejected because they conflict with existing tools in names. Avoid this via setting aliases for them: \n".to_string(), |mut acc, (model_tool_name, tool_info)| { + acc.push_str(&format!(" - {} from {}\n", model_tool_name, tool_info.server_name)); + acc + }); + conflicts.insert(server_name, msg); } if let Some(spec) = specs.first() { updated_servers.insert(spec.tool_origin.clone()); } - for spec in specs { + // We want to filter for specs that are valid + // Note that [ToolSpec::name] is a model facing name (thus you should be comparing it + // with the keys of a tn_map) + for spec in specs.into_iter().filter(|spec| valid.contains_key(&spec.name)) { tool_specs.insert(spec.name.clone(), spec); } + + self.tn_map.extend(valid); } - // Caching the tool names for skim operations - for tool_name in tool_specs.keys() { - if !self.tn_map.contains_key(tool_name) { - self.tn_map.insert(tool_name.clone(), tool_name.clone()); - } - } + // Update schema // As we are writing over the ensemble of tools in a given server, we will need to first // remove everything that it has. self.schema .retain(|_tool_name, spec| !updated_servers.contains(&spec.tool_origin)); self.schema.extend(tool_specs); + + // if block here to avoid repeatedly asking for loc + if !conflicts.is_empty() { + let mut record_lock = self.mcp_load_record.lock().await; + for (server_name, msg) in conflicts { + let record = LoadingRecord::Err(msg); + record_lock + .entry(server_name) + .and_modify(|v| v.push(record.clone())) + .or_insert(vec![record]); + } + } } #[allow(clippy::await_holding_lock)] @@ -1270,52 +1328,48 @@ fn process_tool_specs( conversation_id: &str, server_name: &str, specs: &mut Vec, - tn_map: &mut HashMap, + tn_map: &mut HashMap, + alias_list: &HashMap, regex: &Regex, telemetry: &TelemetryThread, ) -> eyre::Result<()> { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. + // Tools are subjected to the following validations: + // 1. ^[a-zA-Z][a-zA-Z0-9_]*$, + // 2. less than 64 characters in length + // 3. a non-empty description + // + // For non-compliance due to point 1, we shall change it on behalf of the users. + // For the rest, we simply throw a warning and reject the tool. let mut out_of_spec_tool_names = Vec::::new(); let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use - // it as is - // 2. Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error + let mut number_of_tools = 0_usize; + for spec in specs.iter_mut() { - let sn = if !regex.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); - while tn_map.contains_key(&sn) { - sn.push('1'); + let model_tool_name = alias_list.get(&spec.name).cloned().unwrap_or({ + if !regex.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); + while tn_map.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { + }); + if model_tool_name.len() > 64 { out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); continue; } else if spec.description.is_empty() { out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); continue; } - if sn != spec.name { - tn_map.insert( - full_name.clone(), - format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name), - ); - } - spec.name = full_name; + tn_map.insert(model_tool_name.clone(), ToolInfo { + server_name: server_name.to_string(), + host_tool_name: spec.name.clone(), + }); + spec.name = model_tool_name; spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); + number_of_tools += 1; } // Native origin is the default, and since this function never reads native tools, if we still // have it, that would indicate a tool that should not be included. @@ -1350,16 +1404,6 @@ fn process_tool_specs( acc }, ))) - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. kill the server) - } else if !tn_map.is_empty() { - Err(eyre::eyre!(tn_map.iter().fold( - String::from("The following tool names are changed:\n"), - |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }, - ))) } else { Ok(()) } diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 0f7338ee4b..2d55805e12 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -16,6 +16,10 @@ use tokio::sync::RwLock; use tracing::warn; use super::InvokeOutput; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; use crate::mcp_client::{ @@ -33,7 +37,7 @@ use crate::mcp_client::{ use crate::os::Os; // TODO: support http transport type -#[derive(Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct CustomToolConfig { pub command: String, #[serde(default)] @@ -53,6 +57,7 @@ pub fn default_timeout() -> u64 { #[derive(Debug)] pub enum CustomToolClient { Stdio { + /// This is the server name as recognized by the model (post sanitized) server_name: String, client: McpClient, server_capabilities: RwLock>, @@ -243,4 +248,24 @@ impl CustomTool { TokenCounter::count_tokens(self.method.as_str()) + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) } + + pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + use crate::util::MCP_SERVER_TOOL_DELIMITER; + let Self { + name: tool_name, + client, + .. + } = self; + let server_name = client.get_server_name(); + + if agent.allowed_tools.contains(&format!("@{server_name}")) + || agent + .allowed_tools + .contains(&format!("@{server_name}{MCP_SERVER_TOOL_DELIMITER}{tool_name}")) + { + PermissionEvalResult::Allow + } else { + PermissionEvalResult::Ask + } + } } diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index b8008b43f6..9b7cd6e76b 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -7,7 +7,12 @@ use crossterm::style::{ }; use eyre::Result; use serde::Deserialize; +use tracing::error; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; use crate::cli::chat::tools::{ InvokeOutput, MAX_TOOL_RESPONSE_SIZE, @@ -39,12 +44,14 @@ pub struct ExecuteCommand { } impl ExecuteCommand { - pub fn requires_acceptance(&self) -> bool { + pub fn requires_acceptance(&self, allowed_commands: Option<&Vec>, allow_read_only: bool) -> bool { + let default_arr = vec![]; + let allowed_commands = allowed_commands.unwrap_or(&default_arr); let Some(args) = shlex::split(&self.command) else { return true; }; - const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";"]; + if args .iter() .any(|arg| DANGEROUS_PATTERNS.iter().any(|p| arg.contains(p))) @@ -87,9 +94,16 @@ impl ExecuteCommand { { return true; }, - Some(cmd) if !READONLY_COMMANDS.contains(&cmd.as_str()) => return true, + Some(cmd) => { + if allowed_commands.contains(cmd) { + continue; + } + let is_cmd_read_only = READONLY_COMMANDS.contains(&cmd.as_str()); + if !allow_read_only || !is_cmd_read_only { + return true; + } + }, None => return true, - _ => (), } } @@ -139,6 +153,60 @@ impl ExecuteCommand { // TODO: probably some small amount of PATH checking Ok(()) } + + pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + struct Settings { + #[serde(default)] + allowed_commands: Vec, + #[serde(default)] + denied_commands: Vec, + #[serde(default = "default_allow_read_only")] + allow_read_only: bool, + } + + fn default_allow_read_only() -> bool { + true + } + + let Self { command, .. } = self; + let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + let is_in_allowlist = agent.allowed_tools.contains("execute_bash"); + match agent.tools_settings.get(tool_name) { + Some(settings) if is_in_allowlist => { + let Settings { + allowed_commands, + denied_commands, + allow_read_only, + } = match serde_json::from_value::(settings.clone()) { + Ok(settings) => settings, + Err(e) => { + error!("Failed to deserialize tool settings for execute_bash: {:?}", e); + return PermissionEvalResult::Ask; + }, + }; + + if denied_commands.iter().any(|dc| command.contains(dc)) { + return PermissionEvalResult::Deny; + } + + if self.requires_acceptance(Some(&allowed_commands), allow_read_only) { + PermissionEvalResult::Ask + } else { + PermissionEvalResult::Allow + } + }, + None if is_in_allowlist => PermissionEvalResult::Allow, + _ => { + if self.requires_acceptance(None, default_allow_read_only()) { + PermissionEvalResult::Ask + } else { + PermissionEvalResult::Allow + } + }, + } + } } pub struct CommandResult { @@ -191,7 +259,7 @@ mod tests { })) .unwrap(); assert_eq!( - tool.requires_acceptance(), + tool.requires_acceptance(None, true), *expected, "expected command: `{}` to have requires_acceptance: `{}`", cmd, diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index 00ad936b83..9504916038 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -12,6 +12,10 @@ use eyre::{ Result, bail, }; +use globset::{ + Glob, + GlobSetBuilder, +}; use serde::{ Deserialize, Serialize, @@ -19,6 +23,7 @@ use serde::{ use syntect::util::LinesWithEndings; use tracing::{ debug, + error, warn, }; @@ -29,6 +34,10 @@ use super::{ format_path, sanitize_path_tool_arg, }; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::util::images::{ handle_images_from_paths, @@ -76,6 +85,106 @@ impl FsRead { FsRead::Image(fs_image) => fs_image.invoke(updates).await, } } + + pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + struct Settings { + #[serde(default)] + allowed_paths: Vec, + #[serde(default)] + denied_paths: Vec, + #[serde(default = "default_allow_read_only")] + allow_read_only: bool, + } + + fn default_allow_read_only() -> bool { + true + } + + let is_in_allowlist = agent.allowed_tools.contains("fs_read"); + match agent.tools_settings.get("fs_read") { + Some(settings) if is_in_allowlist => { + let Settings { + allowed_paths, + denied_paths, + allow_read_only, + } = match serde_json::from_value::(settings.clone()) { + Ok(settings) => settings, + Err(e) => { + error!("Failed to deserialize tool settings for fs_read: {:?}", e); + return PermissionEvalResult::Ask; + }, + }; + let allow_set = { + let mut builder = GlobSetBuilder::new(); + for path in &allowed_paths { + if let Ok(glob) = Glob::new(path) { + builder.add(glob); + } else { + warn!("Failed to create glob from path given: {path}. Ignoring."); + } + } + builder.build() + }; + + let deny_set = { + let mut builder = GlobSetBuilder::new(); + for path in &denied_paths { + if let Ok(glob) = Glob::new(path) { + builder.add(glob); + } else { + warn!("Failed to create glob from path given: {path}. Ignoring."); + } + } + builder.build() + }; + + match (allow_set, deny_set) { + (Ok(allow_set), Ok(deny_set)) => { + match self { + Self::Line(FsLine { path, .. }) + | Self::Directory(FsDirectory { path, .. }) + | Self::Search(FsSearch { path, .. }) => { + if deny_set.is_match(path) { + return PermissionEvalResult::Deny; + } + if allow_set.is_match(path) { + return PermissionEvalResult::Allow; + } + }, + Self::Image(fs_image) => { + let paths = &fs_image.image_paths; + if paths.iter().any(|path| deny_set.is_match(path)) { + return PermissionEvalResult::Deny; + } + if paths.iter().all(|path| allow_set.is_match(path)) { + return PermissionEvalResult::Allow; + } + }, + } + if allow_read_only { + PermissionEvalResult::Allow + } else { + PermissionEvalResult::Ask + } + }, + (allow_res, deny_res) => { + if let Err(e) = allow_res { + warn!("fs_read failed to build allow set: {:?}", e); + } + if let Err(e) = deny_res { + warn!("fs_read failed to build deny set: {:?}", e); + } + warn!("One or more detailed args failed to parse, falling back to ask"); + PermissionEvalResult::Ask + }, + } + }, + None if is_in_allowlist => PermissionEvalResult::Allow, + _ => PermissionEvalResult::Ask, + } + } } /// Read images from given paths. diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index dbd4fe71f0..3f7c5d41c4 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -13,6 +13,10 @@ use eyre::{ bail, eyre, }; +use globset::{ + Glob, + GlobSetBuilder, +}; use serde::Deserialize; use similar::DiffableStr; use syntect::easy::HighlightLines; @@ -33,6 +37,10 @@ use super::{ sanitize_path_tool_arg, supports_truecolor, }; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; use crate::os::Os; static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); @@ -337,6 +345,87 @@ impl FsWrite { FsWrite::Append { summary, .. } => summary.as_ref(), } } + + pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + struct Settings { + #[serde(default)] + allowed_paths: Vec, + #[serde(default)] + denied_paths: Vec, + } + + let is_in_allowlist = agent.allowed_tools.contains("fs_write"); + match agent.tools_settings.get("fs_write") { + Some(settings) if is_in_allowlist => { + let Settings { + allowed_paths, + denied_paths, + } = match serde_json::from_value::(settings.clone()) { + Ok(settings) => settings, + Err(e) => { + error!("Failed to deserialize tool settings for fs_write: {:?}", e); + return PermissionEvalResult::Ask; + }, + }; + let allow_set = { + let mut builder = GlobSetBuilder::new(); + for path in &allowed_paths { + if let Ok(glob) = Glob::new(path) { + builder.add(glob); + } else { + warn!("Failed to create glob from path given: {path}. Ignoring."); + } + } + builder.build() + }; + + let deny_set = { + let mut builder = GlobSetBuilder::new(); + for path in &denied_paths { + if let Ok(glob) = Glob::new(path) { + builder.add(glob); + } else { + warn!("Failed to create glob from path given: {path}. Ignoring."); + } + } + builder.build() + }; + + match (allow_set, deny_set) { + (Ok(allow_set), Ok(deny_set)) => { + match self { + Self::Create { path, .. } + | Self::Insert { path, .. } + | Self::Append { path, .. } + | Self::StrReplace { path, .. } => { + if deny_set.is_match(path) { + return PermissionEvalResult::Deny; + } + if allow_set.is_match(path) { + return PermissionEvalResult::Allow; + } + }, + } + PermissionEvalResult::Ask + }, + (allow_res, deny_res) => { + if let Err(e) = allow_res { + warn!("fs_write failed to build allow set: {:?}", e); + } + if let Err(e) = deny_res { + warn!("fs_write failed to build deny set: {:?}", e); + } + warn!("One or more detailed args failed to parse, falling back to ask"); + PermissionEvalResult::Ask + }, + } + }, + None if is_in_allowlist => PermissionEvalResult::Allow, + _ => PermissionEvalResult::Ask, + } + } } /// Writes `content` to `path`, adding a newline if necessary. diff --git a/crates/chat-cli/src/cli/chat/tools/gh_issue.rs b/crates/chat-cli/src/cli/chat/tools/gh_issue.rs index 787bf9d280..2dc9aa7f7b 100644 --- a/crates/chat-cli/src/cli/chat/tools/gh_issue.rs +++ b/crates/chat-cli/src/cli/chat/tools/gh_issue.rs @@ -1,7 +1,4 @@ -use std::collections::{ - HashMap, - VecDeque, -}; +use std::collections::VecDeque; use std::io::Write; use crossterm::style::Color; @@ -18,10 +15,7 @@ use serde::Deserialize; use super::super::context::ContextManager; use super::super::util::issue::IssueCreator; -use super::{ - InvokeOutput, - ToolPermission, -}; +use super::InvokeOutput; use crate::cli::chat::token_counter::TokenCounter; use crate::os::Os; @@ -41,7 +35,7 @@ pub struct GhIssueContext { pub context_manager: Option, pub transcript: VecDeque, pub failed_request_ids: Vec, - pub tool_permissions: HashMap, + pub tool_permissions: Vec, } /// Max amount of characters to include in the transcript. @@ -147,22 +141,6 @@ impl GhIssue { }; os_str.push_str(&format!("current_profile={}\n", os_manager.current_profile)); - match os_manager.list_profiles(os).await { - Ok(profiles) if !profiles.is_empty() => { - os_str.push_str(&format!("profiles=\n{}\n\n", profiles.join("\n"))); - }, - _ => os_str.push_str("profiles=none\n\n"), - } - - // Context file categories - if os_manager.global_config.paths.is_empty() { - os_str.push_str("global_context=none\n\n"); - } else { - os_str.push_str(&format!( - "global_context=\n{}\n\n", - &os_manager.global_config.paths.join("\n") - )); - } if os_manager.profile_config.paths.is_empty() { os_str.push_str("profile_context=none\n\n"); @@ -196,8 +174,8 @@ impl GhIssue { fn get_chat_settings(context: &GhIssueContext) -> String { let mut result_str = "[chat-settings]\n".to_string(); result_str.push_str("\n\n[chat-trusted_tools]"); - for (tool, permission) in context.tool_permissions.iter() { - result_str.push_str(&format!("\n{tool}={}", permission.trusted)); + for tool in context.tool_permissions.iter() { + result_str.push_str(&format!("\n{tool}=trusted")); } result_str diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index f45bcaba47..f34c0e16d4 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -7,10 +7,7 @@ pub mod knowledge; pub mod thinking; pub mod use_aws; -use std::collections::{ - HashMap, - HashSet, -}; +use std::borrow::Borrow; use std::io::Write; use std::path::{ Path, @@ -21,7 +18,6 @@ use crossterm::queue; use crossterm::style::{ self, Color, - Stylize, }; use custom_tool::CustomTool; use execute::ExecuteCommand; @@ -39,8 +35,26 @@ use use_aws::UseAws; use super::consts::MAX_TOOL_RESPONSE_SIZE; use super::util::images::RichImageBlocks; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; use crate::os::Os; +pub const DEFAULT_APPROVE: [&str; 1] = ["fs_read"]; +pub const NATIVE_TOOLS: [&str; 7] = [ + "fs_read", + "fs_write", + #[cfg(windows)] + "execute_cmd", + #[cfg(not(windows))] + "execute_bash", + "use_aws", + "gh_issue", + "knowledge", + "thinking", +]; + /// Represents an executable tool use. #[allow(clippy::large_enum_variant)] #[derive(Debug, Clone)] @@ -75,16 +89,16 @@ impl Tool { } /// Whether or not the tool should prompt the user to accept before [Self::invoke] is called. - pub fn requires_acceptance(&self, _os: &Os) -> bool { + pub fn requires_acceptance(&self, agent: &Agent) -> PermissionEvalResult { match self { - Tool::FsRead(_) => false, - Tool::FsWrite(_) => true, - Tool::ExecuteCommand(execute_command) => execute_command.requires_acceptance(), - Tool::UseAws(use_aws) => use_aws.requires_acceptance(), - Tool::Custom(_) => true, - Tool::GhIssue(_) => false, - Tool::Knowledge(_) => false, - Tool::Thinking(_) => false, + Tool::FsRead(fs_read) => fs_read.eval_perm(agent), + Tool::FsWrite(fs_write) => fs_write.eval_perm(agent), + Tool::ExecuteCommand(execute_command) => execute_command.eval_perm(agent), + Tool::UseAws(use_aws) => use_aws.eval_perm(agent), + Tool::Custom(custom_tool) => custom_tool.eval_perm(agent), + Tool::GhIssue(_) => PermissionEvalResult::Allow, + Tool::Thinking(_) => PermissionEvalResult::Allow, + Tool::Knowledge(_) => PermissionEvalResult::Ask, } } @@ -131,121 +145,6 @@ impl Tool { } } -#[derive(Debug, Clone)] -pub struct ToolPermission { - pub trusted: bool, -} - -#[derive(Debug, Clone)] -/// Holds overrides for tool permissions. -/// Tools that do not have an associated ToolPermission should use -/// their default logic to determine to permission. -pub struct ToolPermissions { - // We need this field for any stragglers - pub trust_all: bool, - pub permissions: HashMap, - // Store pending trust-tool patterns for MCP tools that may be loaded later - pub pending_trusted_tools: HashSet, -} - -impl ToolPermissions { - pub fn new(capacity: usize) -> Self { - Self { - trust_all: false, - permissions: HashMap::with_capacity(capacity), - pending_trusted_tools: HashSet::new(), - } - } - - pub fn is_trusted(&mut self, tool_name: &str) -> bool { - // Check if we should trust from pending patterns first - if self.should_trust_from_pending(tool_name) { - self.trust_tool(tool_name); - self.pending_trusted_tools.remove(tool_name); - } - - self.trust_all || self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) - } - - /// Returns a label to describe the permission status for a given tool. - pub fn display_label(&mut self, tool_name: &str) -> String { - let is_trusted = self.is_trusted(tool_name); - let has_setting = self.has(tool_name) || self.trust_all; - - match (has_setting, is_trusted) { - (true, true) => format!(" {}", "trusted".dark_green().bold()), - (true, false) => format!(" {}", "not trusted".dark_grey()), - _ => self.default_permission_label(tool_name), - } - } - - pub fn trust_tool(&mut self, tool_name: &str) { - self.permissions - .insert(tool_name.to_string(), ToolPermission { trusted: true }); - } - - pub fn untrust_tool(&mut self, tool_name: &str) { - self.trust_all = false; - self.pending_trusted_tools.remove(tool_name); - self.permissions - .insert(tool_name.to_string(), ToolPermission { trusted: false }); - } - - pub fn reset(&mut self) { - self.trust_all = false; - self.permissions.clear(); - self.pending_trusted_tools.clear(); - } - - pub fn reset_tool(&mut self, tool_name: &str) { - self.trust_all = false; - self.permissions.remove(tool_name); - self.pending_trusted_tools.remove(tool_name); - } - - /// Add a pending trust pattern for tools that may be loaded later - pub fn add_pending_trust_tool(&mut self, pattern: String) { - self.pending_trusted_tools.insert(pattern); - } - - /// Check if a tool should be trusted based on preceding trust declarations - pub fn should_trust_from_pending(&self, tool_name: &str) -> bool { - // Check for exact match - self.pending_trusted_tools.contains(tool_name) - } - - pub fn has(&mut self, tool_name: &str) -> bool { - // Check if we should trust from pending tools first - if self.should_trust_from_pending(tool_name) { - self.trust_tool(tool_name); - self.pending_trusted_tools.remove(tool_name); - } - - self.permissions.contains_key(tool_name) - } - - /// Provide default permission labels for the built-in set of tools. - // This "static" way avoids needing to construct a tool instance. - fn default_permission_label(&self, tool_name: &str) -> String { - let label = match tool_name { - "fs_read" => "trusted".dark_green().bold(), - "fs_write" => "not trusted".dark_grey(), - #[cfg(not(windows))] - "execute_bash" => "trust read-only commands".dark_grey(), - #[cfg(windows)] - "execute_cmd" => "trust read-only commands".dark_grey(), - "use_aws" => "trust read-only commands".dark_grey(), - "report_issue" => "trusted".dark_green().bold(), - "knowledge" => "trusted".dark_green().bold(), - "thinking" => "trusted (prerelease)".dark_green().bold(), - _ if self.trust_all => "trusted".dark_grey().bold(), - _ => "not trusted".dark_grey(), - }; - - format!("{} {label}", "*".reset()) - } -} - /// A tool specification to be sent to the model as part of a conversation. Maps to /// [BedrockToolSpecification]. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -258,12 +157,30 @@ pub struct ToolSpec { pub tool_origin: ToolOrigin, } -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq)] pub enum ToolOrigin { Native, McpServer(String), } +impl std::hash::Hash for ToolOrigin { + fn hash(&self, state: &mut H) { + match self { + Self::Native => "native".hash(state), + Self::McpServer(name) => name.hash(state), + } + } +} + +impl Borrow for ToolOrigin { + fn borrow(&self) -> &str { + match self { + Self::McpServer(name) => name.as_str(), + Self::Native => "native", + } + } +} + impl<'de> Deserialize<'de> for ToolOrigin { fn deserialize(deserializer: D) -> Result where diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index f9ec3dd1e9..59b41e8b0d 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -16,12 +16,17 @@ use eyre::{ WrapErr, }; use serde::Deserialize; +use tracing::error; use super::{ InvokeOutput, MAX_TOOL_RESPONSE_SIZE, OutputKind, }; +use crate::cli::agent::{ + Agent, + PermissionEvalResult, +}; use crate::os::Os; const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; @@ -188,6 +193,44 @@ impl UseAws { None } } + + pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + #[derive(Debug, Deserialize)] + #[serde(rename_all = "camelCase")] + struct Settings { + allowed_services: Vec, + denied_services: Vec, + } + + let Self { service_name, .. } = self; + let is_in_allowlist = agent.allowed_tools.contains("use_aws"); + match agent.tools_settings.get("use_aws") { + Some(settings) if is_in_allowlist => { + let settings = match serde_json::from_value::(settings.clone()) { + Ok(settings) => settings, + Err(e) => { + error!("Failed to deserialize tool settings for use_aws: {:?}", e); + return PermissionEvalResult::Ask; + }, + }; + if settings.denied_services.contains(service_name) { + return PermissionEvalResult::Deny; + } + if settings.allowed_services.contains(service_name) { + return PermissionEvalResult::Allow; + } + PermissionEvalResult::Ask + }, + None if is_in_allowlist => PermissionEvalResult::Allow, + _ => { + if self.requires_acceptance() { + PermissionEvalResult::Ask + } else { + PermissionEvalResult::Allow + } + }, + } + } } #[cfg(test)] diff --git a/crates/chat-cli/src/cli/chat/util/test.rs b/crates/chat-cli/src/cli/chat/util/test.rs index 1106e02a78..a43a67902c 100644 --- a/crates/chat-cli/src/cli/chat/util/test.rs +++ b/crates/chat-cli/src/cli/chat/util/test.rs @@ -1,5 +1,6 @@ use eyre::Result; +use crate::cli::agent::Agent; use crate::cli::chat::consts::CONTEXT_FILES_MAX_SIZE; use crate::cli::chat::context::ContextManager; use crate::os::Os; @@ -15,11 +16,10 @@ pub const TEST_FILE_PATH: &str = "/test_file.txt"; pub const TEST_HIDDEN_FILE_PATH: &str = "/aaaa2/.hidden"; // Helper function to create a test ContextManager with Context -pub async fn create_test_context_manager(context_file_size: Option) -> Result { +pub fn create_test_context_manager(context_file_size: Option) -> Result { let context_file_size = context_file_size.unwrap_or(CONTEXT_FILES_MAX_SIZE); - let os = Os::new().await.unwrap(); - let manager = ContextManager::new(&os, Some(context_file_size)).await?; - Ok(manager) + let agent = Agent::default(); + ContextManager::from_agent(&agent, Some(context_file_size)) } /// Sets up the following filesystem structure: diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs index 451d44b87f..d1cf785159 100644 --- a/crates/chat-cli/src/cli/mcp.rs +++ b/crates/chat-cli/src/cli/mcp.rs @@ -16,10 +16,13 @@ use eyre::{ Result, bail, }; -use tracing::warn; -use crate::cli::chat::tool_manager::{ +use super::agent::{ + Agent, + Agents, McpServerConfig, +}; +use crate::cli::chat::tool_manager::{ global_mcp_config_path, workspace_mcp_config_path, }; @@ -28,6 +31,7 @@ use crate::cli::chat::tools::custom_tool::{ default_timeout, }; use crate::os::Os; +use crate::util::directories; #[derive(Debug, Copy, Clone, PartialEq, Eq, ValueEnum)] pub enum Scope { @@ -86,8 +90,8 @@ pub struct AddArgs { #[arg(long, action = ArgAction::Append, allow_hyphen_values = true, value_delimiter = ',')] pub args: Vec, /// Where to add the server to. - #[arg(long, value_enum)] - pub scope: Option, + #[arg(long)] + pub agent: Option, /// Environment variables to use when launching the server #[arg(long, value_parser = parse_env_vars)] pub env: Vec>, @@ -104,17 +108,16 @@ pub struct AddArgs { impl AddArgs { pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let scope = self.scope.unwrap_or(Scope::Workspace); - let config_path = resolve_scope_profile(os, self.scope)?; - - let mut config: McpServerConfig = ensure_config_file(os, &config_path, output).await?; + let agent_name = self.agent.as_deref().unwrap_or("default"); + let (mut agent, config_path) = Agent::get_agent_by_name(os, agent_name).await?; - if config.mcp_servers.contains_key(&self.name) && !self.force { + let mcp_servers = &mut agent.mcp_servers.mcp_servers; + if mcp_servers.contains_key(&self.name) && !self.force { bail!( - "\nMCP server '{}' already exists in {} (scope {}). Use --force to overwrite.", + "\nMCP server '{}' already exists in agent {} (path {}). Use --force to overwrite.", self.name, + agent_name, config_path.display(), - scope ); } @@ -132,14 +135,10 @@ impl AddArgs { "\nTo learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n\n" )?; - config.mcp_servers.insert(self.name.clone(), tool); - config.save_to_file(os, &config_path).await?; - writeln!( - output, - "✓ Added MCP server '{}' to {}\n", - self.name, - scope_display(&scope) - )?; + mcp_servers.insert(self.name.clone(), tool); + let json = serde_json::to_string_pretty(&agent)?; + os.fs.write(config_path, json).await?; + writeln!(output, "✓ Added MCP server '{}' to agent {}\n", self.name, agent_name)?; Ok(()) } } @@ -149,36 +148,35 @@ pub struct RemoveArgs { #[arg(long)] pub name: String, #[arg(long, value_enum)] - pub scope: Option, + pub agent: Option, } impl RemoveArgs { pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { - let scope = self.scope.unwrap_or(Scope::Workspace); - let config_path = resolve_scope_profile(os, self.scope)?; + let agent_name = self.agent.as_deref().unwrap_or("default"); + let (mut agent, config_path) = Agent::get_agent_by_name(os, agent_name).await?; if !os.fs.exists(&config_path) { writeln!(output, "\nNo MCP server configurations found.\n")?; return Ok(()); } - let mut config = McpServerConfig::load_from_file(os, &config_path).await?; - match config.mcp_servers.remove(&self.name) { + let config = &mut agent.mcp_servers.mcp_servers; + match config.remove(&self.name) { Some(_) => { - config.save_to_file(os, &config_path).await?; + let json = serde_json::to_string_pretty(&agent)?; + os.fs.write(config_path, json).await?; writeln!( output, - "\n✓ Removed MCP server '{}' from {}\n", - self.name, - scope_display(&scope) + "\n✓ Removed MCP server '{}' from agent {}\n", + self.name, agent_name, )?; }, None => { writeln!( output, - "\nNo MCP server named '{}' found in {}\n", - self.name, - scope_display(&scope) + "\nNo MCP server named '{}' found in agent {}\n", + self.name, agent_name, )?; }, } @@ -195,7 +193,7 @@ pub struct ListArgs { } impl ListArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { + pub async fn execute(self, os: &mut Os, output: &mut impl Write) -> Result<()> { let configs = get_mcp_server_configs(os, self.scope).await?; if configs.is_empty() { writeln!(output, "No MCP server configurations found.\n")?; @@ -279,7 +277,7 @@ pub struct StatusArgs { } impl StatusArgs { - pub async fn execute(self, os: &Os, output: &mut impl Write) -> Result<()> { + pub async fn execute(self, os: &mut Os, output: &mut impl Write) -> Result<()> { let configs = get_mcp_server_configs(os, None).await?; let mut found = false; @@ -314,7 +312,7 @@ impl StatusArgs { } async fn get_mcp_server_configs( - os: &Os, + os: &mut Os, scope: Option, ) -> Result)>> { let mut targets = Vec::new(); @@ -324,20 +322,24 @@ async fn get_mcp_server_configs( } let mut results = Vec::new(); - for sc in targets { - let path = resolve_scope_profile(os, Some(sc))?; - let cfg_opt = if os.fs.exists(&path) { - match McpServerConfig::load_from_file(os, &path).await { - Ok(cfg) => Some(cfg), - Err(e) => { - warn!(?path, error = %e, "Invalid MCP config file—ignored, treated as null"); - None - }, - } + let mut stderr = std::io::stderr(); + let agents = Agents::load(os, None, true, &mut stderr).await; + let global_path = directories::chat_global_agent_path(os)?; + for (_, agent) in agents.agents { + let scope = if agent + .path + .as_ref() + .is_some_and(|p| p.parent().is_some_and(|p| p == global_path)) + { + Scope::Global } else { - None + Scope::Workspace }; - results.push((sc, path, cfg_opt)); + results.push(( + scope, + agent.path.ok_or(eyre::eyre!("Agent missing path info"))?, + Some(agent.mcp_servers), + )); } Ok(results) } @@ -445,6 +447,7 @@ mod tests { assert!(cfg.mcp_servers.is_empty()); } + #[ignore = "TODO: fix in CI"] #[tokio::test] async fn add_then_remove_cycle() { let os = Os::new().await.unwrap(); @@ -460,7 +463,7 @@ mod tests { ], env: vec![], timeout: None, - scope: None, + agent: None, disabled: false, force: false, } @@ -476,7 +479,7 @@ mod tests { // 2. remove RemoveArgs { name: "local".into(), - scope: None, + agent: None, } .execute(&os, &mut vec![]) .await @@ -509,7 +512,7 @@ mod tests { "--allow-write".to_string(), "--allow-sensitive-data-access".to_string(), ], - scope: None, + agent: None, env: vec![ [ ("key1".to_string(), "value1".to_string()), @@ -531,7 +534,7 @@ mod tests { ["mcp", "remove", "--name", "old"], RootSubcommand::Mcp(McpSubcommand::Remove(RemoveArgs { name: "old".into(), - scope: None, + agent: None, })) ); } diff --git a/crates/chat-cli/src/cli/mod.rs b/crates/chat-cli/src/cli/mod.rs index a04bebf6a5..f3de01c87d 100644 --- a/crates/chat-cli/src/cli/mod.rs +++ b/crates/chat-cli/src/cli/mod.rs @@ -1,3 +1,4 @@ +mod agent; mod chat; mod debug; mod diagnostics; @@ -351,11 +352,12 @@ mod test { subcommand: Some(RootSubcommand::Chat(ChatArgs { resume: false, input: None, - profile: None, + agent: None, model: None, trust_all_tools: false, trust_tools: None, - no_interactive: false + no_interactive: false, + migrate: false, })), verbose: 2, help_all: false, @@ -390,11 +392,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: false, input: None, - profile: Some("my-profile".to_string()), + agent: Some("my-profile".to_string()), model: None, trust_all_tools: false, trust_tools: None, - no_interactive: false + no_interactive: false, + migrate: false, }) ); } @@ -406,11 +409,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: false, input: Some("Hello".to_string()), - profile: Some("my-profile".to_string()), + agent: Some("my-profile".to_string()), model: None, trust_all_tools: false, trust_tools: None, - no_interactive: false + no_interactive: false, + migrate: false, }) ); } @@ -422,11 +426,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: false, input: None, - profile: Some("my-profile".to_string()), + agent: Some("my-profile".to_string()), model: None, trust_all_tools: true, trust_tools: None, - no_interactive: false + no_interactive: false, + migrate: false, }) ); } @@ -438,11 +443,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: true, input: None, - profile: None, + agent: None, model: None, trust_all_tools: false, trust_tools: None, - no_interactive: true + no_interactive: true, + migrate: false, }) ); assert_parse!( @@ -450,11 +456,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: true, input: None, - profile: None, + agent: None, model: None, trust_all_tools: false, trust_tools: None, - no_interactive: true + no_interactive: true, + migrate: false, }) ); } @@ -466,11 +473,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: false, input: None, - profile: None, + agent: None, model: None, trust_all_tools: true, trust_tools: None, - no_interactive: false + no_interactive: false, + migrate: false, }) ); } @@ -482,11 +490,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: false, input: None, - profile: None, + agent: None, model: None, trust_all_tools: false, trust_tools: Some(vec!["".to_string()]), - no_interactive: false + no_interactive: false, + migrate: false, }) ); } @@ -498,11 +507,12 @@ mod test { RootSubcommand::Chat(ChatArgs { resume: false, input: None, - profile: None, + agent: None, model: None, trust_all_tools: false, trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), - no_interactive: false + no_interactive: false, + migrate: false, }) ); } diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index 85e56c21a5..06f3302a0e 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -33,6 +33,7 @@ pub enum Setting { McpNoInteractiveTimeout, McpLoadedBefore, ChatDefaultModel, + ChatDefaultAgent, } impl AsRef for Setting { @@ -54,6 +55,7 @@ impl AsRef for Setting { Self::McpNoInteractiveTimeout => "mcp.noInteractiveTimeout", Self::McpLoadedBefore => "mcp.loadedBefore", Self::ChatDefaultModel => "chat.defaultModel", + Self::ChatDefaultAgent => "chat.defaultAgent", } } } @@ -85,6 +87,7 @@ impl TryFrom<&str> for Setting { "mcp.noInteractiveTimeout" => Ok(Self::McpNoInteractiveTimeout), "mcp.loadedBefore" => Ok(Self::McpLoadedBefore), "chat.defaultModel" => Ok(Self::ChatDefaultModel), + "chat.defaultAgent" => Ok(Self::ChatDefaultAgent), _ => Err(DatabaseError::InvalidSetting(value.to_string())), } } diff --git a/crates/chat-cli/src/util/consts.rs b/crates/chat-cli/src/util/consts.rs index a5dfe5e02e..bfab0a9521 100644 --- a/crates/chat-cli/src/util/consts.rs +++ b/crates/chat-cli/src/util/consts.rs @@ -6,6 +6,8 @@ pub const PRODUCT_NAME: &str = "Amazon Q"; pub const GITHUB_REPO_NAME: &str = "aws/amazon-q-developer-cli"; +pub const MCP_SERVER_TOOL_DELIMITER: &str = "/"; + pub const GOV_REGIONS: &[&str] = &["us-gov-east-1", "us-gov-west-1"]; /// Build time env vars diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index 70dd7d9f69..90d2596e61 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -129,12 +129,35 @@ pub fn logs_dir() -> Result { } } +/// Example agent config path +pub fn example_agent_config(os: &Os) -> Result { + let global_path = chat_global_agent_path(os)?; + Ok(global_path.join("agent_config.json.example")) +} + +/// Legacy global MCP server config path +pub fn chat_legacy_mcp_config(os: &Os) -> Result { + Ok(home_dir(os)?.join(".aws").join("amazonq").join("mcp.json")) +} + +/// The directory to the directory containing global agents +pub fn chat_global_agent_path(os: &Os) -> Result { + Ok(home_dir(os)?.join(".aws").join("amazonq").join("agents")) +} + +/// The directory to the directory containing config for the `/context` feature in `q chat`. +pub fn chat_local_agent_dir() -> Result { + let cwd = std::env::current_dir()?; + Ok(cwd.join(".aws").join("amazonq").join("agents")) +} + /// The directory to the directory containing config for the `/context` feature in `q chat`. pub fn chat_global_context_path(os: &Os) -> Result { Ok(home_dir(os)?.join(".aws").join("amazonq").join("global_context.json")) } /// The directory to the directory containing config for the `/context` feature in `q chat`. +#[allow(dead_code)] pub fn chat_profiles_dir(os: &Os) -> Result { Ok(home_dir(os)?.join(".aws").join("amazonq").join("profiles")) }