diff --git a/crates/q_cli/src/cli/chat/command.rs b/crates/q_cli/src/cli/chat/command.rs index fcde98d46e..9f81659d3c 100644 --- a/crates/q_cli/src/cli/chat/command.rs +++ b/crates/q_cli/src/cli/chat/command.rs @@ -104,7 +104,7 @@ pub enum HooksSubcommand { name: String, #[arg(long, value_parser = ["per_prompt", "conversation_start"])] - r#type: String, + trigger: String, #[arg(long, value_parser = clap::value_parser!(String))] command: String, @@ -191,7 +191,7 @@ impl ContextSubcommand { hooks add [--global] <> Add a new command context hook --global: Add to global hooks - --type <> Type of hook, valid options: `per_prompt` or `conversation_start` + --trigger <> When to trigger the hook, valid options: `per_prompt` or `conversation_start` --command <> Shell command to execute hooks rm [--global] <> Remove an existing context hook @@ -853,12 +853,12 @@ mod tests { context!(ContextSubcommand::Hooks { subcommand: None }), ), ( - "/context hooks add test --type per_prompt --command 'echo 1' --global", + "/context hooks add test --trigger per_prompt --command 'echo 1' --global", context!(ContextSubcommand::Hooks { subcommand: Some(HooksSubcommand::Add { name: "test".to_string(), global: true, - r#type: "per_prompt".to_string(), + trigger: "per_prompt".to_string(), command: "echo 1".to_string() }) }), diff --git a/crates/q_cli/src/cli/chat/context.rs b/crates/q_cli/src/cli/chat/context.rs index 57c039bd25..86bdaa84cb 100644 --- a/crates/q_cli/src/cli/chat/context.rs +++ b/crates/q_cli/src/cli/chat/context.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::io::Write; use std::path::{ Path, @@ -22,7 +23,6 @@ use super::hooks::{ Hook, HookExecutor, }; -use crate::cli::chat::hooks::HookConfig; pub const AMAZONQ_FILENAME: &str = "AmazonQ.md"; @@ -32,7 +32,9 @@ pub const AMAZONQ_FILENAME: &str = "AmazonQ.md"; pub struct ContextConfig { /// List of file paths or glob patterns to include in the context. pub paths: Vec, - pub hooks: HookConfig, + + /// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID. + pub hooks: HashMap, } #[allow(dead_code)] @@ -353,6 +355,7 @@ impl ContextManager { /// A Result indicating success or an error pub async fn switch_profile(&mut self, name: &str) -> Result<()> { validate_profile_name(name)?; + self.hook_executor.profile_cache.clear(); // Special handling for default profile - it always exists if name == "default" { @@ -459,37 +462,17 @@ impl ContextManager { /// config. /// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be /// added to per_prompt. - pub async fn add_hook(&mut self, hook: Hook, global: bool, conversation_start: bool) -> Result<()> { - if self.num_hooks_with_name(&hook.name) > 0 { - return Err(eyre!( - "Cannot add hook, another hook with this name already exists in global or profile context." - )); - } - + pub async fn add_hook(&mut self, name: String, hook: Hook, global: bool) -> Result<()> { let config = self.get_config_mut(global); - let hook_vec = if conversation_start { - &mut config.hooks.conversation_start - } else { - &mut config.hooks.per_prompt - }; + if config.hooks.contains_key(&name) { + return Err(eyre!("name already exists.")); + } - hook_vec.push(hook); + config.hooks.insert(name, hook); self.save_config(global).await } - fn num_hooks_with_name(&self, name: &str) -> usize { - self.global_config - .hooks - .conversation_start - .iter() - .chain(self.global_config.hooks.per_prompt.iter()) - .chain(self.profile_config.hooks.conversation_start.iter()) - .chain(self.profile_config.hooks.per_prompt.iter()) - .filter(|h| h.name == name) - .count() - } - /// Delete hook(s) by name /// # Arguments /// * `name` - name of the hook to delete @@ -498,8 +481,11 @@ impl ContextManager { pub async fn remove_hook(&mut self, name: &str, global: bool) -> Result<()> { let config = self.get_config_mut(global); - config.hooks.conversation_start.retain(|h| h.name != name); - config.hooks.per_prompt.retain(|h| h.name != name); + if !config.hooks.contains_key(name) { + return Err(eyre!("does not exist.")); + } + + config.hooks.remove(name); self.save_config(global).await } @@ -510,13 +496,13 @@ impl ContextManager { pub async fn set_hook_disabled(&mut self, name: &str, global: bool, disable: bool) -> Result<()> { let config = self.get_config_mut(global); - config - .hooks - .conversation_start - .iter_mut() - .chain(config.hooks.per_prompt.iter_mut()) - .filter(|h| h.name == name) - .for_each(|h| h.disabled = disable); + if !config.hooks.contains_key(name) { + return Err(eyre!("does not exist.")); + } + + if let Some(hook) = config.hooks.get_mut(name) { + hook.disabled = disable; + } self.save_config(global).await } @@ -527,12 +513,7 @@ impl ContextManager { pub async fn set_all_hooks_disabled(&mut self, global: bool, disable: bool) -> Result<()> { let config = self.get_config_mut(global); - config - .hooks - .conversation_start - .iter_mut() - .chain(config.hooks.per_prompt.iter_mut()) - .for_each(|h| h.disabled = disable); + config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable); self.save_config(global).await } @@ -545,27 +526,19 @@ impl ContextManager { pub async fn run_hooks(&mut self, updates: &mut impl Write) -> Vec<(Hook, String)> { let mut hooks: Vec<&Hook> = Vec::new(); - // Collect all conversation start hooks - hooks.extend( - self.global_config - .hooks - .conversation_start - .iter_mut() - .chain(self.profile_config.hooks.conversation_start.iter_mut()) - .map(|h| { - h.is_conversation_start = true; - &*h - }), - ); - - // Collect all per-prompt hooks - hooks.extend( - self.global_config - .hooks - .per_prompt - .iter() - .chain(self.profile_config.hooks.per_prompt.iter()), - ); + // Set internal hook states + let configs = [ + (&mut self.global_config.hooks, true), + (&mut self.profile_config.hooks, false), + ]; + + for (hook_list, is_global) in configs { + hooks.extend(hook_list.iter_mut().map(|(name, h)| { + h.name = name.to_string(); + h.is_global = is_global; + &*h + })); + } self.hook_executor.run_hooks(hooks, updates).await } @@ -600,7 +573,7 @@ async fn load_global_config(ctx: &Context) -> Result { "README.md".to_string(), AMAZONQ_FILENAME.to_string(), ], - hooks: HookConfig::default(), + hooks: HashMap::new(), }) } } diff --git a/crates/q_cli/src/cli/chat/hooks.rs b/crates/q_cli/src/cli/chat/hooks.rs index c08176e313..5dd4407850 100644 --- a/crates/q_cli/src/cli/chat/hooks.rs +++ b/crates/q_cli/src/cli/chat/hooks.rs @@ -30,8 +30,7 @@ const DEFAULT_CACHE_TTL_SECONDS: u64 = 0; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Hook { - /// Unique name of the hook - pub name: String, + pub trigger: HookTrigger, pub r#type: HookType, @@ -47,8 +46,8 @@ pub struct Hook { pub max_output_size: usize, /// How long the hook output is cached before it will be executed again - #[serde(skip_serializing_if = "Option::is_none")] - pub cache_ttl_seconds: Option, + #[serde(default = "Hook::default_cache_ttl_seconds")] + pub cache_ttl_seconds: u64, // Type-specific fields /// The bash command to execute @@ -56,21 +55,23 @@ pub struct Hook { // Internal data #[serde(skip)] - pub is_conversation_start: bool, + pub name: String, + #[serde(skip)] + pub is_global: bool, } impl Hook { - #[allow(dead_code)] // TODO: Remove - pub fn new_inline_hook(name: &str, command: String, is_conversation_start: bool) -> Self { + pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self { Self { - name: name.to_string(), + trigger, r#type: HookType::Inline, disabled: Self::default_disabled(), timeout_ms: Self::default_timeout_ms(), max_output_size: Self::default_max_output_size(), - cache_ttl_seconds: None, + cache_ttl_seconds: Self::default_cache_ttl_seconds(), command: Some(command), - is_conversation_start, + is_global: false, + name: "new hook".to_string(), } } @@ -85,20 +86,24 @@ impl Hook { fn default_max_output_size() -> usize { DEFAULT_MAX_OUTPUT_SIZE } + + fn default_cache_ttl_seconds() -> u64 { + DEFAULT_CACHE_TTL_SECONDS + } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum HookType { // Execute an inline shell command Inline, } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(default)] -pub struct HookConfig { - pub conversation_start: Vec, - pub per_prompt: Vec, +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum HookTrigger { + ConversationStart, + PerPrompt, } #[derive(Debug, Clone)] @@ -107,15 +112,18 @@ pub struct CachedHook { expiry: Option, } +/// Maps a hook name to a [`CachedHook`] #[derive(Debug, Clone)] pub struct HookExecutor { - pub execution_cache: HashMap, + pub global_cache: HashMap, + pub profile_cache: HashMap, } impl HookExecutor { pub fn new() -> Self { Self { - execution_cache: HashMap::new(), + global_cache: HashMap::new(), + profile_cache: HashMap::new(), } } @@ -134,7 +142,7 @@ impl HookExecutor { } // Check if the hook is cached. If so, push a completed future. - if let Some(cached) = self.get_cache(&hook.name) { + if let Some(cached) = self.get_cache(hook) { futures.push(Either::Left(future::ready(( hook, Ok(cached.clone()), @@ -186,16 +194,12 @@ impl HookExecutor { for (hook, result, _) in results { if result.is_ok() { // Conversation start hooks are always cached as they are expected to run once per session. - let expiry = if hook.is_conversation_start { - None - } else { - Some( - Instant::now() - + Duration::from_secs(hook.cache_ttl_seconds.unwrap_or(DEFAULT_CACHE_TTL_SECONDS)), - ) + let expiry = match hook.trigger { + HookTrigger::ConversationStart => None, + HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), }; - self.insert_cache(&hook.name, CachedHook { + self.insert_cache(hook, CachedHook { output: result.as_ref().cloned().unwrap(), expiry, }); @@ -247,8 +251,14 @@ impl HookExecutor { } /// Will return a cached hook's output if it exists and isn't expired. - fn get_cache(&self, name: &str) -> Option { - self.execution_cache.get(name).and_then(|o| { + fn get_cache(&self, hook: &Hook) -> Option { + let cache = if hook.is_global { + &self.global_cache + } else { + &self.profile_cache + }; + + cache.get(&hook.name).and_then(|o| { if let Some(expiry) = o.expiry { if Instant::now() < expiry { Some(o.output.clone()) @@ -261,7 +271,13 @@ impl HookExecutor { }) } - fn insert_cache(&mut self, name: &str, hook_output: CachedHook) { - self.execution_cache.insert(name.to_string(), hook_output); + fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) { + let cache = if hook.is_global { + &mut self.global_cache + } else { + &mut self.profile_cache + }; + + cache.insert(hook.name.clone(), hook_output); } } diff --git a/crates/q_cli/src/cli/chat/mod.rs b/crates/q_cli/src/cli/chat/mod.rs index 80f09fbbc6..aa6f063410 100644 --- a/crates/q_cli/src/cli/chat/mod.rs +++ b/crates/q_cli/src/cli/chat/mod.rs @@ -68,7 +68,10 @@ use fig_api_client::model::{ use fig_os_shim::Context; use fig_settings::Settings; use fig_util::CLI_BINARY_NAME; -use hooks::Hook; +use hooks::{ + Hook, + HookTrigger, +}; use summarization_state::{ SummarizationState, TokenWarningLevel, @@ -829,8 +832,9 @@ where let hook_results = cm.run_hooks(&mut self.output).await; - let (start_hooks, prompt_hooks): (Vec<_>, Vec<_>) = - hook_results.iter().partition(|(hook, _)| hook.is_conversation_start); + let (start_hooks, prompt_hooks): (Vec<_>, Vec<_>) = hook_results + .iter() + .partition(|(hook, _)| hook.trigger == HookTrigger::ConversationStart); ( (!start_hooks.is_empty()).then(|| format_context(&start_hooks, true)), @@ -907,7 +911,8 @@ where if ["y", "Y"].contains(&user_input.as_str()) { self.conversation_state.clear(true); if let Some(cm) = self.conversation_state.context_manager.as_mut() { - cm.hook_executor.execution_cache.clear(); + cm.hook_executor.global_cache.clear(); + cm.hook_executor.profile_cache.clear(); } execute!( self.output, @@ -1481,60 +1486,128 @@ where match subcommand { command::HooksSubcommand::Add { name, - r#type, + trigger, command, global, } => { - context_manager - .add_hook( - Hook::new_inline_hook(&name, command, false), - global, - r#type == "conversation_start", - ) - .await - .map_err(map_chat_error)?; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nAdded {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; + let trigger = if trigger == "conversation_start" { + HookTrigger::ConversationStart + } else { + HookTrigger::PerPrompt + }; + + let result = context_manager + .add_hook(name.clone(), Hook::new_inline_hook(trigger, command), global) + .await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nAdded {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot add {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } }, command::HooksSubcommand::Remove { name, global } => { - context_manager - .remove_hook(&name, global) - .await - .map_err(map_chat_error)?; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nRemoved {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; + let result = context_manager.remove_hook(&name, global).await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nRemoved {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot remove {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } }, command::HooksSubcommand::Enable { name, global } => { - context_manager - .set_hook_disabled(&name, global, false) - .await - .map_err(map_chat_error)?; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nEnabled {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; + let result = context_manager.set_hook_disabled(&name, global, false).await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nEnabled {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot enable {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } }, command::HooksSubcommand::Disable { name, global } => { - context_manager - .set_hook_disabled(&name, global, true) - .await - .map_err(map_chat_error)?; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDisabled {} hook '{name}'.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; + let result = context_manager.set_hook_disabled(&name, global, true).await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nDisabled {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot disable {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } }, command::HooksSubcommand::EnableAll { global } => { context_manager @@ -1572,14 +1645,16 @@ where } else { fn print_hook_section( output: &mut impl Write, - hooks: &Vec, - conversation_start: bool, + hooks: &HashMap, + trigger: HookTrigger, ) -> Result<()> { - let section = if conversation_start { - "Conversation start" - } else { - "Per prompt" + let section = match trigger { + HookTrigger::ConversationStart => "Conversation Start", + HookTrigger::PerPrompt => "Per Prompt", }; + let hooks: Vec<(&String, &Hook)> = + hooks.iter().filter(|(_, h)| h.trigger == trigger).collect(); + queue!( output, style::SetForegroundColor(Color::Cyan), @@ -1595,16 +1670,16 @@ where style::SetForegroundColor(Color::Reset) )?; } else { - for hook in hooks { + for (name, hook) in hooks { if hook.disabled { queue!( output, style::SetForegroundColor(Color::DarkGrey), - style::Print(format!(" {} (disabled)\n", hook.name)), + style::Print(format!(" {} (disabled)\n", name)), style::SetForegroundColor(Color::Reset) )?; } else { - queue!(output, style::Print(format!(" {}\n", hook.name)),)?; + queue!(output, style::Print(format!(" {}\n", name)),)?; } } } @@ -1620,14 +1695,14 @@ where print_hook_section( &mut self.output, - &context_manager.global_config.hooks.conversation_start, - true, + &context_manager.global_config.hooks, + HookTrigger::ConversationStart, ) .map_err(map_chat_error)?; print_hook_section( &mut self.output, - &context_manager.global_config.hooks.per_prompt, - false, + &context_manager.global_config.hooks, + HookTrigger::PerPrompt, ) .map_err(map_chat_error)?; @@ -1641,14 +1716,14 @@ where print_hook_section( &mut self.output, - &context_manager.profile_config.hooks.conversation_start, - true, + &context_manager.profile_config.hooks, + HookTrigger::ConversationStart, ) .map_err(map_chat_error)?; print_hook_section( &mut self.output, - &context_manager.profile_config.hooks.per_prompt, - false, + &context_manager.profile_config.hooks, + HookTrigger::PerPrompt, ) .map_err(map_chat_error)?;