diff --git a/crates/q_cli/src/cli/chat/command.rs b/crates/q_cli/src/cli/chat/command.rs
index fcde98d46e..9f81659d3c 100644
--- a/crates/q_cli/src/cli/chat/command.rs
+++ b/crates/q_cli/src/cli/chat/command.rs
@@ -104,7 +104,7 @@ pub enum HooksSubcommand {
name: String,
#[arg(long, value_parser = ["per_prompt", "conversation_start"])]
- r#type: String,
+ trigger: String,
#[arg(long, value_parser = clap::value_parser!(String))]
command: String,
@@ -191,7 +191,7 @@ impl ContextSubcommand {
hooks add [--global] <> Add a new command context hook
--global: Add to global hooks
- --type <> Type of hook, valid options: `per_prompt` or `conversation_start`
+ --trigger <> When to trigger the hook, valid options: `per_prompt` or `conversation_start`
--command <> Shell command to execute
hooks rm [--global] <> Remove an existing context hook
@@ -853,12 +853,12 @@ mod tests {
context!(ContextSubcommand::Hooks { subcommand: None }),
),
(
- "/context hooks add test --type per_prompt --command 'echo 1' --global",
+ "/context hooks add test --trigger per_prompt --command 'echo 1' --global",
context!(ContextSubcommand::Hooks {
subcommand: Some(HooksSubcommand::Add {
name: "test".to_string(),
global: true,
- r#type: "per_prompt".to_string(),
+ trigger: "per_prompt".to_string(),
command: "echo 1".to_string()
})
}),
diff --git a/crates/q_cli/src/cli/chat/context.rs b/crates/q_cli/src/cli/chat/context.rs
index 57c039bd25..86bdaa84cb 100644
--- a/crates/q_cli/src/cli/chat/context.rs
+++ b/crates/q_cli/src/cli/chat/context.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
use std::io::Write;
use std::path::{
Path,
@@ -22,7 +23,6 @@ use super::hooks::{
Hook,
HookExecutor,
};
-use crate::cli::chat::hooks::HookConfig;
pub const AMAZONQ_FILENAME: &str = "AmazonQ.md";
@@ -32,7 +32,9 @@ pub const AMAZONQ_FILENAME: &str = "AmazonQ.md";
pub struct ContextConfig {
/// List of file paths or glob patterns to include in the context.
pub paths: Vec,
- pub hooks: HookConfig,
+
+ /// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID.
+ pub hooks: HashMap,
}
#[allow(dead_code)]
@@ -353,6 +355,7 @@ impl ContextManager {
/// A Result indicating success or an error
pub async fn switch_profile(&mut self, name: &str) -> Result<()> {
validate_profile_name(name)?;
+ self.hook_executor.profile_cache.clear();
// Special handling for default profile - it always exists
if name == "default" {
@@ -459,37 +462,17 @@ impl ContextManager {
/// config.
/// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be
/// added to per_prompt.
- pub async fn add_hook(&mut self, hook: Hook, global: bool, conversation_start: bool) -> Result<()> {
- if self.num_hooks_with_name(&hook.name) > 0 {
- return Err(eyre!(
- "Cannot add hook, another hook with this name already exists in global or profile context."
- ));
- }
-
+ pub async fn add_hook(&mut self, name: String, hook: Hook, global: bool) -> Result<()> {
let config = self.get_config_mut(global);
- let hook_vec = if conversation_start {
- &mut config.hooks.conversation_start
- } else {
- &mut config.hooks.per_prompt
- };
+ if config.hooks.contains_key(&name) {
+ return Err(eyre!("name already exists."));
+ }
- hook_vec.push(hook);
+ config.hooks.insert(name, hook);
self.save_config(global).await
}
- fn num_hooks_with_name(&self, name: &str) -> usize {
- self.global_config
- .hooks
- .conversation_start
- .iter()
- .chain(self.global_config.hooks.per_prompt.iter())
- .chain(self.profile_config.hooks.conversation_start.iter())
- .chain(self.profile_config.hooks.per_prompt.iter())
- .filter(|h| h.name == name)
- .count()
- }
-
/// Delete hook(s) by name
/// # Arguments
/// * `name` - name of the hook to delete
@@ -498,8 +481,11 @@ impl ContextManager {
pub async fn remove_hook(&mut self, name: &str, global: bool) -> Result<()> {
let config = self.get_config_mut(global);
- config.hooks.conversation_start.retain(|h| h.name != name);
- config.hooks.per_prompt.retain(|h| h.name != name);
+ if !config.hooks.contains_key(name) {
+ return Err(eyre!("does not exist."));
+ }
+
+ config.hooks.remove(name);
self.save_config(global).await
}
@@ -510,13 +496,13 @@ impl ContextManager {
pub async fn set_hook_disabled(&mut self, name: &str, global: bool, disable: bool) -> Result<()> {
let config = self.get_config_mut(global);
- config
- .hooks
- .conversation_start
- .iter_mut()
- .chain(config.hooks.per_prompt.iter_mut())
- .filter(|h| h.name == name)
- .for_each(|h| h.disabled = disable);
+ if !config.hooks.contains_key(name) {
+ return Err(eyre!("does not exist."));
+ }
+
+ if let Some(hook) = config.hooks.get_mut(name) {
+ hook.disabled = disable;
+ }
self.save_config(global).await
}
@@ -527,12 +513,7 @@ impl ContextManager {
pub async fn set_all_hooks_disabled(&mut self, global: bool, disable: bool) -> Result<()> {
let config = self.get_config_mut(global);
- config
- .hooks
- .conversation_start
- .iter_mut()
- .chain(config.hooks.per_prompt.iter_mut())
- .for_each(|h| h.disabled = disable);
+ config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable);
self.save_config(global).await
}
@@ -545,27 +526,19 @@ impl ContextManager {
pub async fn run_hooks(&mut self, updates: &mut impl Write) -> Vec<(Hook, String)> {
let mut hooks: Vec<&Hook> = Vec::new();
- // Collect all conversation start hooks
- hooks.extend(
- self.global_config
- .hooks
- .conversation_start
- .iter_mut()
- .chain(self.profile_config.hooks.conversation_start.iter_mut())
- .map(|h| {
- h.is_conversation_start = true;
- &*h
- }),
- );
-
- // Collect all per-prompt hooks
- hooks.extend(
- self.global_config
- .hooks
- .per_prompt
- .iter()
- .chain(self.profile_config.hooks.per_prompt.iter()),
- );
+ // Set internal hook states
+ let configs = [
+ (&mut self.global_config.hooks, true),
+ (&mut self.profile_config.hooks, false),
+ ];
+
+ for (hook_list, is_global) in configs {
+ hooks.extend(hook_list.iter_mut().map(|(name, h)| {
+ h.name = name.to_string();
+ h.is_global = is_global;
+ &*h
+ }));
+ }
self.hook_executor.run_hooks(hooks, updates).await
}
@@ -600,7 +573,7 @@ async fn load_global_config(ctx: &Context) -> Result {
"README.md".to_string(),
AMAZONQ_FILENAME.to_string(),
],
- hooks: HookConfig::default(),
+ hooks: HashMap::new(),
})
}
}
diff --git a/crates/q_cli/src/cli/chat/hooks.rs b/crates/q_cli/src/cli/chat/hooks.rs
index c08176e313..5dd4407850 100644
--- a/crates/q_cli/src/cli/chat/hooks.rs
+++ b/crates/q_cli/src/cli/chat/hooks.rs
@@ -30,8 +30,7 @@ const DEFAULT_CACHE_TTL_SECONDS: u64 = 0;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
- /// Unique name of the hook
- pub name: String,
+ pub trigger: HookTrigger,
pub r#type: HookType,
@@ -47,8 +46,8 @@ pub struct Hook {
pub max_output_size: usize,
/// How long the hook output is cached before it will be executed again
- #[serde(skip_serializing_if = "Option::is_none")]
- pub cache_ttl_seconds: Option,
+ #[serde(default = "Hook::default_cache_ttl_seconds")]
+ pub cache_ttl_seconds: u64,
// Type-specific fields
/// The bash command to execute
@@ -56,21 +55,23 @@ pub struct Hook {
// Internal data
#[serde(skip)]
- pub is_conversation_start: bool,
+ pub name: String,
+ #[serde(skip)]
+ pub is_global: bool,
}
impl Hook {
- #[allow(dead_code)] // TODO: Remove
- pub fn new_inline_hook(name: &str, command: String, is_conversation_start: bool) -> Self {
+ pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self {
Self {
- name: name.to_string(),
+ trigger,
r#type: HookType::Inline,
disabled: Self::default_disabled(),
timeout_ms: Self::default_timeout_ms(),
max_output_size: Self::default_max_output_size(),
- cache_ttl_seconds: None,
+ cache_ttl_seconds: Self::default_cache_ttl_seconds(),
command: Some(command),
- is_conversation_start,
+ is_global: false,
+ name: "new hook".to_string(),
}
}
@@ -85,20 +86,24 @@ impl Hook {
fn default_max_output_size() -> usize {
DEFAULT_MAX_OUTPUT_SIZE
}
+
+ fn default_cache_ttl_seconds() -> u64 {
+ DEFAULT_CACHE_TTL_SECONDS
+ }
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum HookType {
// Execute an inline shell command
Inline,
}
-#[derive(Debug, Clone, Serialize, Deserialize, Default)]
-#[serde(default)]
-pub struct HookConfig {
- pub conversation_start: Vec,
- pub per_prompt: Vec,
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
+#[serde(rename_all = "snake_case")]
+pub enum HookTrigger {
+ ConversationStart,
+ PerPrompt,
}
#[derive(Debug, Clone)]
@@ -107,15 +112,18 @@ pub struct CachedHook {
expiry: Option,
}
+/// Maps a hook name to a [`CachedHook`]
#[derive(Debug, Clone)]
pub struct HookExecutor {
- pub execution_cache: HashMap,
+ pub global_cache: HashMap,
+ pub profile_cache: HashMap,
}
impl HookExecutor {
pub fn new() -> Self {
Self {
- execution_cache: HashMap::new(),
+ global_cache: HashMap::new(),
+ profile_cache: HashMap::new(),
}
}
@@ -134,7 +142,7 @@ impl HookExecutor {
}
// Check if the hook is cached. If so, push a completed future.
- if let Some(cached) = self.get_cache(&hook.name) {
+ if let Some(cached) = self.get_cache(hook) {
futures.push(Either::Left(future::ready((
hook,
Ok(cached.clone()),
@@ -186,16 +194,12 @@ impl HookExecutor {
for (hook, result, _) in results {
if result.is_ok() {
// Conversation start hooks are always cached as they are expected to run once per session.
- let expiry = if hook.is_conversation_start {
- None
- } else {
- Some(
- Instant::now()
- + Duration::from_secs(hook.cache_ttl_seconds.unwrap_or(DEFAULT_CACHE_TTL_SECONDS)),
- )
+ let expiry = match hook.trigger {
+ HookTrigger::ConversationStart => None,
+ HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)),
};
- self.insert_cache(&hook.name, CachedHook {
+ self.insert_cache(hook, CachedHook {
output: result.as_ref().cloned().unwrap(),
expiry,
});
@@ -247,8 +251,14 @@ impl HookExecutor {
}
/// Will return a cached hook's output if it exists and isn't expired.
- fn get_cache(&self, name: &str) -> Option {
- self.execution_cache.get(name).and_then(|o| {
+ fn get_cache(&self, hook: &Hook) -> Option {
+ let cache = if hook.is_global {
+ &self.global_cache
+ } else {
+ &self.profile_cache
+ };
+
+ cache.get(&hook.name).and_then(|o| {
if let Some(expiry) = o.expiry {
if Instant::now() < expiry {
Some(o.output.clone())
@@ -261,7 +271,13 @@ impl HookExecutor {
})
}
- fn insert_cache(&mut self, name: &str, hook_output: CachedHook) {
- self.execution_cache.insert(name.to_string(), hook_output);
+ fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) {
+ let cache = if hook.is_global {
+ &mut self.global_cache
+ } else {
+ &mut self.profile_cache
+ };
+
+ cache.insert(hook.name.clone(), hook_output);
}
}
diff --git a/crates/q_cli/src/cli/chat/mod.rs b/crates/q_cli/src/cli/chat/mod.rs
index 80f09fbbc6..aa6f063410 100644
--- a/crates/q_cli/src/cli/chat/mod.rs
+++ b/crates/q_cli/src/cli/chat/mod.rs
@@ -68,7 +68,10 @@ use fig_api_client::model::{
use fig_os_shim::Context;
use fig_settings::Settings;
use fig_util::CLI_BINARY_NAME;
-use hooks::Hook;
+use hooks::{
+ Hook,
+ HookTrigger,
+};
use summarization_state::{
SummarizationState,
TokenWarningLevel,
@@ -829,8 +832,9 @@ where
let hook_results = cm.run_hooks(&mut self.output).await;
- let (start_hooks, prompt_hooks): (Vec<_>, Vec<_>) =
- hook_results.iter().partition(|(hook, _)| hook.is_conversation_start);
+ let (start_hooks, prompt_hooks): (Vec<_>, Vec<_>) = hook_results
+ .iter()
+ .partition(|(hook, _)| hook.trigger == HookTrigger::ConversationStart);
(
(!start_hooks.is_empty()).then(|| format_context(&start_hooks, true)),
@@ -907,7 +911,8 @@ where
if ["y", "Y"].contains(&user_input.as_str()) {
self.conversation_state.clear(true);
if let Some(cm) = self.conversation_state.context_manager.as_mut() {
- cm.hook_executor.execution_cache.clear();
+ cm.hook_executor.global_cache.clear();
+ cm.hook_executor.profile_cache.clear();
}
execute!(
self.output,
@@ -1481,60 +1486,128 @@ where
match subcommand {
command::HooksSubcommand::Add {
name,
- r#type,
+ trigger,
command,
global,
} => {
- context_manager
- .add_hook(
- Hook::new_inline_hook(&name, command, false),
- global,
- r#type == "conversation_start",
- )
- .await
- .map_err(map_chat_error)?;
- execute!(
- self.output,
- style::SetForegroundColor(Color::Green),
- style::Print(format!("\nAdded {} hook '{name}'.\n\n", scope(global))),
- style::SetForegroundColor(Color::Reset)
- )?;
+ let trigger = if trigger == "conversation_start" {
+ HookTrigger::ConversationStart
+ } else {
+ HookTrigger::PerPrompt
+ };
+
+ let result = context_manager
+ .add_hook(name.clone(), Hook::new_inline_hook(trigger, command), global)
+ .await;
+ match result {
+ Ok(_) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Green),
+ style::Print(format!(
+ "\nAdded {} hook '{name}'.\n\n",
+ scope(global)
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ Err(e) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Red),
+ style::Print(format!(
+ "\nCannot add {} hook '{name}': {}\n\n",
+ scope(global),
+ e
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ }
},
command::HooksSubcommand::Remove { name, global } => {
- context_manager
- .remove_hook(&name, global)
- .await
- .map_err(map_chat_error)?;
- execute!(
- self.output,
- style::SetForegroundColor(Color::Green),
- style::Print(format!("\nRemoved {} hook '{name}'.\n\n", scope(global))),
- style::SetForegroundColor(Color::Reset)
- )?;
+ let result = context_manager.remove_hook(&name, global).await;
+ match result {
+ Ok(_) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Green),
+ style::Print(format!(
+ "\nRemoved {} hook '{name}'.\n\n",
+ scope(global)
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ Err(e) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Red),
+ style::Print(format!(
+ "\nCannot remove {} hook '{name}': {}\n\n",
+ scope(global),
+ e
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ }
},
command::HooksSubcommand::Enable { name, global } => {
- context_manager
- .set_hook_disabled(&name, global, false)
- .await
- .map_err(map_chat_error)?;
- execute!(
- self.output,
- style::SetForegroundColor(Color::Green),
- style::Print(format!("\nEnabled {} hook '{name}'.\n\n", scope(global))),
- style::SetForegroundColor(Color::Reset)
- )?;
+ let result = context_manager.set_hook_disabled(&name, global, false).await;
+ match result {
+ Ok(_) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Green),
+ style::Print(format!(
+ "\nEnabled {} hook '{name}'.\n\n",
+ scope(global)
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ Err(e) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Red),
+ style::Print(format!(
+ "\nCannot enable {} hook '{name}': {}\n\n",
+ scope(global),
+ e
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ }
},
command::HooksSubcommand::Disable { name, global } => {
- context_manager
- .set_hook_disabled(&name, global, true)
- .await
- .map_err(map_chat_error)?;
- execute!(
- self.output,
- style::SetForegroundColor(Color::Green),
- style::Print(format!("\nDisabled {} hook '{name}'.\n\n", scope(global))),
- style::SetForegroundColor(Color::Reset)
- )?;
+ let result = context_manager.set_hook_disabled(&name, global, true).await;
+ match result {
+ Ok(_) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Green),
+ style::Print(format!(
+ "\nDisabled {} hook '{name}'.\n\n",
+ scope(global)
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ Err(e) => {
+ execute!(
+ self.output,
+ style::SetForegroundColor(Color::Red),
+ style::Print(format!(
+ "\nCannot disable {} hook '{name}': {}\n\n",
+ scope(global),
+ e
+ )),
+ style::SetForegroundColor(Color::Reset)
+ )?;
+ },
+ }
},
command::HooksSubcommand::EnableAll { global } => {
context_manager
@@ -1572,14 +1645,16 @@ where
} else {
fn print_hook_section(
output: &mut impl Write,
- hooks: &Vec,
- conversation_start: bool,
+ hooks: &HashMap,
+ trigger: HookTrigger,
) -> Result<()> {
- let section = if conversation_start {
- "Conversation start"
- } else {
- "Per prompt"
+ let section = match trigger {
+ HookTrigger::ConversationStart => "Conversation Start",
+ HookTrigger::PerPrompt => "Per Prompt",
};
+ let hooks: Vec<(&String, &Hook)> =
+ hooks.iter().filter(|(_, h)| h.trigger == trigger).collect();
+
queue!(
output,
style::SetForegroundColor(Color::Cyan),
@@ -1595,16 +1670,16 @@ where
style::SetForegroundColor(Color::Reset)
)?;
} else {
- for hook in hooks {
+ for (name, hook) in hooks {
if hook.disabled {
queue!(
output,
style::SetForegroundColor(Color::DarkGrey),
- style::Print(format!(" {} (disabled)\n", hook.name)),
+ style::Print(format!(" {} (disabled)\n", name)),
style::SetForegroundColor(Color::Reset)
)?;
} else {
- queue!(output, style::Print(format!(" {}\n", hook.name)),)?;
+ queue!(output, style::Print(format!(" {}\n", name)),)?;
}
}
}
@@ -1620,14 +1695,14 @@ where
print_hook_section(
&mut self.output,
- &context_manager.global_config.hooks.conversation_start,
- true,
+ &context_manager.global_config.hooks,
+ HookTrigger::ConversationStart,
)
.map_err(map_chat_error)?;
print_hook_section(
&mut self.output,
- &context_manager.global_config.hooks.per_prompt,
- false,
+ &context_manager.global_config.hooks,
+ HookTrigger::PerPrompt,
)
.map_err(map_chat_error)?;
@@ -1641,14 +1716,14 @@ where
print_hook_section(
&mut self.output,
- &context_manager.profile_config.hooks.conversation_start,
- true,
+ &context_manager.profile_config.hooks,
+ HookTrigger::ConversationStart,
)
.map_err(map_chat_error)?;
print_hook_section(
&mut self.output,
- &context_manager.profile_config.hooks.per_prompt,
- false,
+ &context_manager.profile_config.hooks,
+ HookTrigger::PerPrompt,
)
.map_err(map_chat_error)?;