Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/q_cli/src/cli/chat/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ pub enum HooksSubcommand {
name: String,

#[arg(long, value_parser = ["per_prompt", "conversation_start"])]
r#type: String,
trigger: String,

#[arg(long, value_parser = clap::value_parser!(String))]
command: String,
Expand Down Expand Up @@ -191,7 +191,7 @@ impl ContextSubcommand {

<em>hooks add [--global] <<name>></em> <black!>Add a new command context hook</black!>
<black!>--global: Add to global hooks</black!>
<em>--type <<type>></em> <black!>Type of hook, valid options: `per_prompt` or `conversation_start`</black!>
<em>--trigger <<trigger>></em> <black!>When to trigger the hook, valid options: `per_prompt` or `conversation_start`</black!>
<em>--command <<command>></em> <black!>Shell command to execute</black!>

<em>hooks rm [--global] <<name>></em> <black!>Remove an existing context hook</black!>
Expand Down Expand Up @@ -853,12 +853,12 @@ mod tests {
context!(ContextSubcommand::Hooks { subcommand: None }),
),
(
"/context hooks add test --type per_prompt --command 'echo 1' --global",
"/context hooks add test --trigger per_prompt --command 'echo 1' --global",
context!(ContextSubcommand::Hooks {
subcommand: Some(HooksSubcommand::Add {
name: "test".to_string(),
global: true,
r#type: "per_prompt".to_string(),
trigger: "per_prompt".to_string(),
command: "echo 1".to_string()
})
}),
Expand Down
101 changes: 37 additions & 64 deletions crates/q_cli/src/cli/chat/context.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::io::Write;
use std::path::{
Path,
Expand All @@ -22,7 +23,6 @@ use super::hooks::{
Hook,
HookExecutor,
};
use crate::cli::chat::hooks::HookConfig;

pub const AMAZONQ_FILENAME: &str = "AmazonQ.md";

Expand All @@ -32,7 +32,9 @@ pub const AMAZONQ_FILENAME: &str = "AmazonQ.md";
pub struct ContextConfig {
/// List of file paths or glob patterns to include in the context.
pub paths: Vec<String>,
pub hooks: HookConfig,

/// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID.
pub hooks: HashMap<String, Hook>,
}

#[allow(dead_code)]
Expand Down Expand Up @@ -353,6 +355,7 @@ impl ContextManager {
/// A Result indicating success or an error
pub async fn switch_profile(&mut self, name: &str) -> Result<()> {
validate_profile_name(name)?;
self.hook_executor.profile_cache.clear();

// Special handling for default profile - it always exists
if name == "default" {
Expand Down Expand Up @@ -459,37 +462,17 @@ impl ContextManager {
/// config.
/// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be
/// added to per_prompt.
pub async fn add_hook(&mut self, hook: Hook, global: bool, conversation_start: bool) -> Result<()> {
if self.num_hooks_with_name(&hook.name) > 0 {
return Err(eyre!(
"Cannot add hook, another hook with this name already exists in global or profile context."
));
}

pub async fn add_hook(&mut self, name: String, hook: Hook, global: bool) -> Result<()> {
let config = self.get_config_mut(global);

let hook_vec = if conversation_start {
&mut config.hooks.conversation_start
} else {
&mut config.hooks.per_prompt
};
if config.hooks.contains_key(&name) {
return Err(eyre!("name already exists."));
}

hook_vec.push(hook);
config.hooks.insert(name, hook);
self.save_config(global).await
}

fn num_hooks_with_name(&self, name: &str) -> usize {
self.global_config
.hooks
.conversation_start
.iter()
.chain(self.global_config.hooks.per_prompt.iter())
.chain(self.profile_config.hooks.conversation_start.iter())
.chain(self.profile_config.hooks.per_prompt.iter())
.filter(|h| h.name == name)
.count()
}

/// Delete hook(s) by name
/// # Arguments
/// * `name` - name of the hook to delete
Expand All @@ -498,8 +481,11 @@ impl ContextManager {
pub async fn remove_hook(&mut self, name: &str, global: bool) -> Result<()> {
let config = self.get_config_mut(global);

config.hooks.conversation_start.retain(|h| h.name != name);
config.hooks.per_prompt.retain(|h| h.name != name);
if !config.hooks.contains_key(name) {
return Err(eyre!("does not exist."));
}

config.hooks.remove(name);

self.save_config(global).await
}
Expand All @@ -510,13 +496,13 @@ impl ContextManager {
pub async fn set_hook_disabled(&mut self, name: &str, global: bool, disable: bool) -> Result<()> {
let config = self.get_config_mut(global);

config
.hooks
.conversation_start
.iter_mut()
.chain(config.hooks.per_prompt.iter_mut())
.filter(|h| h.name == name)
.for_each(|h| h.disabled = disable);
if !config.hooks.contains_key(name) {
return Err(eyre!("does not exist."));
}

if let Some(hook) = config.hooks.get_mut(name) {
hook.disabled = disable;
}

self.save_config(global).await
}
Expand All @@ -527,12 +513,7 @@ impl ContextManager {
pub async fn set_all_hooks_disabled(&mut self, global: bool, disable: bool) -> Result<()> {
let config = self.get_config_mut(global);

config
.hooks
.conversation_start
.iter_mut()
.chain(config.hooks.per_prompt.iter_mut())
.for_each(|h| h.disabled = disable);
config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable);

self.save_config(global).await
}
Expand All @@ -545,27 +526,19 @@ impl ContextManager {
pub async fn run_hooks(&mut self, updates: &mut impl Write) -> Vec<(Hook, String)> {
let mut hooks: Vec<&Hook> = Vec::new();

// Collect all conversation start hooks
hooks.extend(
self.global_config
.hooks
.conversation_start
.iter_mut()
.chain(self.profile_config.hooks.conversation_start.iter_mut())
.map(|h| {
h.is_conversation_start = true;
&*h
}),
);

// Collect all per-prompt hooks
hooks.extend(
self.global_config
.hooks
.per_prompt
.iter()
.chain(self.profile_config.hooks.per_prompt.iter()),
);
// Set internal hook states
let configs = [
(&mut self.global_config.hooks, true),
(&mut self.profile_config.hooks, false),
];

for (hook_list, is_global) in configs {
hooks.extend(hook_list.iter_mut().map(|(name, h)| {
h.name = name.to_string();
h.is_global = is_global;
&*h
}));
}

self.hook_executor.run_hooks(hooks, updates).await
}
Expand Down Expand Up @@ -600,7 +573,7 @@ async fn load_global_config(ctx: &Context) -> Result<ContextConfig> {
"README.md".to_string(),
AMAZONQ_FILENAME.to_string(),
],
hooks: HookConfig::default(),
hooks: HashMap::new(),
})
}
}
Expand Down
78 changes: 47 additions & 31 deletions crates/q_cli/src/cli/chat/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ const DEFAULT_CACHE_TTL_SECONDS: u64 = 0;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hook {
/// Unique name of the hook
pub name: String,
pub trigger: HookTrigger,

pub r#type: HookType,

Expand All @@ -47,30 +46,32 @@ pub struct Hook {
pub max_output_size: usize,

/// How long the hook output is cached before it will be executed again
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_ttl_seconds: Option<u64>,
#[serde(default = "Hook::default_cache_ttl_seconds")]
pub cache_ttl_seconds: u64,

// Type-specific fields
/// The bash command to execute
pub command: Option<String>, // For inline hooks

// Internal data
#[serde(skip)]
pub is_conversation_start: bool,
pub name: String,
#[serde(skip)]
pub is_global: bool,
}

impl Hook {
#[allow(dead_code)] // TODO: Remove
pub fn new_inline_hook(name: &str, command: String, is_conversation_start: bool) -> Self {
pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self {
Self {
name: name.to_string(),
trigger,
r#type: HookType::Inline,
disabled: Self::default_disabled(),
timeout_ms: Self::default_timeout_ms(),
max_output_size: Self::default_max_output_size(),
cache_ttl_seconds: None,
cache_ttl_seconds: Self::default_cache_ttl_seconds(),
command: Some(command),
is_conversation_start,
is_global: false,
name: "new hook".to_string(),
}
}

Expand All @@ -85,20 +86,24 @@ impl Hook {
fn default_max_output_size() -> usize {
DEFAULT_MAX_OUTPUT_SIZE
}

fn default_cache_ttl_seconds() -> u64 {
DEFAULT_CACHE_TTL_SECONDS
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum HookType {
// Execute an inline shell command
Inline,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct HookConfig {
pub conversation_start: Vec<Hook>,
pub per_prompt: Vec<Hook>,
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum HookTrigger {
ConversationStart,
PerPrompt,
}

#[derive(Debug, Clone)]
Expand All @@ -107,15 +112,18 @@ pub struct CachedHook {
expiry: Option<Instant>,
}

/// Maps a hook name to a [`CachedHook`]
#[derive(Debug, Clone)]
pub struct HookExecutor {
pub execution_cache: HashMap<String, CachedHook>,
pub global_cache: HashMap<String, CachedHook>,
pub profile_cache: HashMap<String, CachedHook>,
}

impl HookExecutor {
pub fn new() -> Self {
Self {
execution_cache: HashMap::new(),
global_cache: HashMap::new(),
profile_cache: HashMap::new(),
}
}

Expand All @@ -134,7 +142,7 @@ impl HookExecutor {
}

// Check if the hook is cached. If so, push a completed future.
if let Some(cached) = self.get_cache(&hook.name) {
if let Some(cached) = self.get_cache(hook) {
futures.push(Either::Left(future::ready((
hook,
Ok(cached.clone()),
Expand Down Expand Up @@ -186,16 +194,12 @@ impl HookExecutor {
for (hook, result, _) in results {
if result.is_ok() {
// Conversation start hooks are always cached as they are expected to run once per session.
let expiry = if hook.is_conversation_start {
None
} else {
Some(
Instant::now()
+ Duration::from_secs(hook.cache_ttl_seconds.unwrap_or(DEFAULT_CACHE_TTL_SECONDS)),
)
let expiry = match hook.trigger {
HookTrigger::ConversationStart => None,
HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)),
};

self.insert_cache(&hook.name, CachedHook {
self.insert_cache(hook, CachedHook {
output: result.as_ref().cloned().unwrap(),
expiry,
});
Expand Down Expand Up @@ -247,8 +251,14 @@ impl HookExecutor {
}

/// Will return a cached hook's output if it exists and isn't expired.
fn get_cache(&self, name: &str) -> Option<String> {
self.execution_cache.get(name).and_then(|o| {
fn get_cache(&self, hook: &Hook) -> Option<String> {
let cache = if hook.is_global {
&self.global_cache
} else {
&self.profile_cache
};

cache.get(&hook.name).and_then(|o| {
if let Some(expiry) = o.expiry {
if Instant::now() < expiry {
Some(o.output.clone())
Expand All @@ -261,7 +271,13 @@ impl HookExecutor {
})
}

fn insert_cache(&mut self, name: &str, hook_output: CachedHook) {
self.execution_cache.insert(name.to_string(), hook_output);
fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) {
let cache = if hook.is_global {
&mut self.global_cache
} else {
&mut self.profile_cache
};

cache.insert(hook.name.clone(), hook_output);
}
}
Loading
Loading