Skip to content

Commit 6363716

Browse files
authored
fix(context): prevent hook name duplicates (#1216)
Updated the schema to be: ``` { "hooks": { "my_hook_name": { ...hook data "trigger": "conversation_start" | "per_prompt", } } } ``` So that we do not have to worry about duplicate name hooks or generating an ID for a hook. Whether the user adds a new hook via commands or editing the JSON file, the invariant holds true. Additionally: - Separate execution caches for profile/global, so that we don't have to worry about name overlappying - also allows us to clear hook cache when switch profile - Rename CLI arg `--type` to `--trigger` for `/context hooks add ...`
1 parent 3522375 commit 6363716

File tree

4 files changed

+229
-165
lines changed

4 files changed

+229
-165
lines changed

crates/q_cli/src/cli/chat/command.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ pub enum HooksSubcommand {
104104
name: String,
105105

106106
#[arg(long, value_parser = ["per_prompt", "conversation_start"])]
107-
r#type: String,
107+
trigger: String,
108108

109109
#[arg(long, value_parser = clap::value_parser!(String))]
110110
command: String,
@@ -191,7 +191,7 @@ impl ContextSubcommand {
191191
192192
<em>hooks add [--global] <<name>></em> <black!>Add a new command context hook</black!>
193193
<black!>--global: Add to global hooks</black!>
194-
<em>--type <<type>></em> <black!>Type of hook, valid options: `per_prompt` or `conversation_start`</black!>
194+
<em>--trigger <<trigger>></em> <black!>When to trigger the hook, valid options: `per_prompt` or `conversation_start`</black!>
195195
<em>--command <<command>></em> <black!>Shell command to execute</black!>
196196
197197
<em>hooks rm [--global] <<name>></em> <black!>Remove an existing context hook</black!>
@@ -864,12 +864,12 @@ mod tests {
864864
context!(ContextSubcommand::Hooks { subcommand: None }),
865865
),
866866
(
867-
"/context hooks add test --type per_prompt --command 'echo 1' --global",
867+
"/context hooks add test --trigger per_prompt --command 'echo 1' --global",
868868
context!(ContextSubcommand::Hooks {
869869
subcommand: Some(HooksSubcommand::Add {
870870
name: "test".to_string(),
871871
global: true,
872-
r#type: "per_prompt".to_string(),
872+
trigger: "per_prompt".to_string(),
873873
command: "echo 1".to_string()
874874
})
875875
}),

crates/q_cli/src/cli/chat/context.rs

Lines changed: 37 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
use std::io::Write;
23
use std::path::{
34
Path,
@@ -22,7 +23,6 @@ use super::hooks::{
2223
Hook,
2324
HookExecutor,
2425
};
25-
use crate::cli::chat::hooks::HookConfig;
2626

2727
pub const AMAZONQ_FILENAME: &str = "AmazonQ.md";
2828

@@ -32,7 +32,9 @@ pub const AMAZONQ_FILENAME: &str = "AmazonQ.md";
3232
pub struct ContextConfig {
3333
/// List of file paths or glob patterns to include in the context.
3434
pub paths: Vec<String>,
35-
pub hooks: HookConfig,
35+
36+
/// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID.
37+
pub hooks: HashMap<String, Hook>,
3638
}
3739

3840
#[allow(dead_code)]
@@ -353,6 +355,7 @@ impl ContextManager {
353355
/// A Result indicating success or an error
354356
pub async fn switch_profile(&mut self, name: &str) -> Result<()> {
355357
validate_profile_name(name)?;
358+
self.hook_executor.profile_cache.clear();
356359

357360
// Special handling for default profile - it always exists
358361
if name == "default" {
@@ -459,37 +462,17 @@ impl ContextManager {
459462
/// config.
460463
/// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be
461464
/// added to per_prompt.
462-
pub async fn add_hook(&mut self, hook: Hook, global: bool, conversation_start: bool) -> Result<()> {
463-
if self.num_hooks_with_name(&hook.name) > 0 {
464-
return Err(eyre!(
465-
"Cannot add hook, another hook with this name already exists in global or profile context."
466-
));
467-
}
468-
465+
pub async fn add_hook(&mut self, name: String, hook: Hook, global: bool) -> Result<()> {
469466
let config = self.get_config_mut(global);
470467

471-
let hook_vec = if conversation_start {
472-
&mut config.hooks.conversation_start
473-
} else {
474-
&mut config.hooks.per_prompt
475-
};
468+
if config.hooks.contains_key(&name) {
469+
return Err(eyre!("name already exists."));
470+
}
476471

477-
hook_vec.push(hook);
472+
config.hooks.insert(name, hook);
478473
self.save_config(global).await
479474
}
480475

481-
fn num_hooks_with_name(&self, name: &str) -> usize {
482-
self.global_config
483-
.hooks
484-
.conversation_start
485-
.iter()
486-
.chain(self.global_config.hooks.per_prompt.iter())
487-
.chain(self.profile_config.hooks.conversation_start.iter())
488-
.chain(self.profile_config.hooks.per_prompt.iter())
489-
.filter(|h| h.name == name)
490-
.count()
491-
}
492-
493476
/// Delete hook(s) by name
494477
/// # Arguments
495478
/// * `name` - name of the hook to delete
@@ -498,8 +481,11 @@ impl ContextManager {
498481
pub async fn remove_hook(&mut self, name: &str, global: bool) -> Result<()> {
499482
let config = self.get_config_mut(global);
500483

501-
config.hooks.conversation_start.retain(|h| h.name != name);
502-
config.hooks.per_prompt.retain(|h| h.name != name);
484+
if !config.hooks.contains_key(name) {
485+
return Err(eyre!("does not exist."));
486+
}
487+
488+
config.hooks.remove(name);
503489

504490
self.save_config(global).await
505491
}
@@ -510,13 +496,13 @@ impl ContextManager {
510496
pub async fn set_hook_disabled(&mut self, name: &str, global: bool, disable: bool) -> Result<()> {
511497
let config = self.get_config_mut(global);
512498

513-
config
514-
.hooks
515-
.conversation_start
516-
.iter_mut()
517-
.chain(config.hooks.per_prompt.iter_mut())
518-
.filter(|h| h.name == name)
519-
.for_each(|h| h.disabled = disable);
499+
if !config.hooks.contains_key(name) {
500+
return Err(eyre!("does not exist."));
501+
}
502+
503+
if let Some(hook) = config.hooks.get_mut(name) {
504+
hook.disabled = disable;
505+
}
520506

521507
self.save_config(global).await
522508
}
@@ -527,12 +513,7 @@ impl ContextManager {
527513
pub async fn set_all_hooks_disabled(&mut self, global: bool, disable: bool) -> Result<()> {
528514
let config = self.get_config_mut(global);
529515

530-
config
531-
.hooks
532-
.conversation_start
533-
.iter_mut()
534-
.chain(config.hooks.per_prompt.iter_mut())
535-
.for_each(|h| h.disabled = disable);
516+
config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable);
536517

537518
self.save_config(global).await
538519
}
@@ -545,27 +526,19 @@ impl ContextManager {
545526
pub async fn run_hooks(&mut self, updates: &mut impl Write) -> Vec<(Hook, String)> {
546527
let mut hooks: Vec<&Hook> = Vec::new();
547528

548-
// Collect all conversation start hooks
549-
hooks.extend(
550-
self.global_config
551-
.hooks
552-
.conversation_start
553-
.iter_mut()
554-
.chain(self.profile_config.hooks.conversation_start.iter_mut())
555-
.map(|h| {
556-
h.is_conversation_start = true;
557-
&*h
558-
}),
559-
);
560-
561-
// Collect all per-prompt hooks
562-
hooks.extend(
563-
self.global_config
564-
.hooks
565-
.per_prompt
566-
.iter()
567-
.chain(self.profile_config.hooks.per_prompt.iter()),
568-
);
529+
// Set internal hook states
530+
let configs = [
531+
(&mut self.global_config.hooks, true),
532+
(&mut self.profile_config.hooks, false),
533+
];
534+
535+
for (hook_list, is_global) in configs {
536+
hooks.extend(hook_list.iter_mut().map(|(name, h)| {
537+
h.name = name.to_string();
538+
h.is_global = is_global;
539+
&*h
540+
}));
541+
}
569542

570543
self.hook_executor.run_hooks(hooks, updates).await
571544
}
@@ -600,7 +573,7 @@ async fn load_global_config(ctx: &Context) -> Result<ContextConfig> {
600573
"README.md".to_string(),
601574
AMAZONQ_FILENAME.to_string(),
602575
],
603-
hooks: HookConfig::default(),
576+
hooks: HashMap::new(),
604577
})
605578
}
606579
}

crates/q_cli/src/cli/chat/hooks.rs

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ const DEFAULT_CACHE_TTL_SECONDS: u64 = 0;
3030

3131
#[derive(Debug, Clone, Serialize, Deserialize)]
3232
pub struct Hook {
33-
/// Unique name of the hook
34-
pub name: String,
33+
pub trigger: HookTrigger,
3534

3635
pub r#type: HookType,
3736

@@ -47,30 +46,32 @@ pub struct Hook {
4746
pub max_output_size: usize,
4847

4948
/// How long the hook output is cached before it will be executed again
50-
#[serde(skip_serializing_if = "Option::is_none")]
51-
pub cache_ttl_seconds: Option<u64>,
49+
#[serde(default = "Hook::default_cache_ttl_seconds")]
50+
pub cache_ttl_seconds: u64,
5251

5352
// Type-specific fields
5453
/// The bash command to execute
5554
pub command: Option<String>, // For inline hooks
5655

5756
// Internal data
5857
#[serde(skip)]
59-
pub is_conversation_start: bool,
58+
pub name: String,
59+
#[serde(skip)]
60+
pub is_global: bool,
6061
}
6162

6263
impl Hook {
63-
#[allow(dead_code)] // TODO: Remove
64-
pub fn new_inline_hook(name: &str, command: String, is_conversation_start: bool) -> Self {
64+
pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self {
6565
Self {
66-
name: name.to_string(),
66+
trigger,
6767
r#type: HookType::Inline,
6868
disabled: Self::default_disabled(),
6969
timeout_ms: Self::default_timeout_ms(),
7070
max_output_size: Self::default_max_output_size(),
71-
cache_ttl_seconds: None,
71+
cache_ttl_seconds: Self::default_cache_ttl_seconds(),
7272
command: Some(command),
73-
is_conversation_start,
73+
is_global: false,
74+
name: "new hook".to_string(),
7475
}
7576
}
7677

@@ -85,20 +86,24 @@ impl Hook {
8586
fn default_max_output_size() -> usize {
8687
DEFAULT_MAX_OUTPUT_SIZE
8788
}
89+
90+
fn default_cache_ttl_seconds() -> u64 {
91+
DEFAULT_CACHE_TTL_SECONDS
92+
}
8893
}
8994

90-
#[derive(Debug, Clone, Serialize, Deserialize)]
95+
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
9196
#[serde(rename_all = "lowercase")]
9297
pub enum HookType {
9398
// Execute an inline shell command
9499
Inline,
95100
}
96101

97-
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
98-
#[serde(default)]
99-
pub struct HookConfig {
100-
pub conversation_start: Vec<Hook>,
101-
pub per_prompt: Vec<Hook>,
102+
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
103+
#[serde(rename_all = "snake_case")]
104+
pub enum HookTrigger {
105+
ConversationStart,
106+
PerPrompt,
102107
}
103108

104109
#[derive(Debug, Clone)]
@@ -107,15 +112,18 @@ pub struct CachedHook {
107112
expiry: Option<Instant>,
108113
}
109114

115+
/// Maps a hook name to a [`CachedHook`]
110116
#[derive(Debug, Clone)]
111117
pub struct HookExecutor {
112-
pub execution_cache: HashMap<String, CachedHook>,
118+
pub global_cache: HashMap<String, CachedHook>,
119+
pub profile_cache: HashMap<String, CachedHook>,
113120
}
114121

115122
impl HookExecutor {
116123
pub fn new() -> Self {
117124
Self {
118-
execution_cache: HashMap::new(),
125+
global_cache: HashMap::new(),
126+
profile_cache: HashMap::new(),
119127
}
120128
}
121129

@@ -134,7 +142,7 @@ impl HookExecutor {
134142
}
135143

136144
// Check if the hook is cached. If so, push a completed future.
137-
if let Some(cached) = self.get_cache(&hook.name) {
145+
if let Some(cached) = self.get_cache(hook) {
138146
futures.push(Either::Left(future::ready((
139147
hook,
140148
Ok(cached.clone()),
@@ -186,16 +194,12 @@ impl HookExecutor {
186194
for (hook, result, _) in results {
187195
if result.is_ok() {
188196
// Conversation start hooks are always cached as they are expected to run once per session.
189-
let expiry = if hook.is_conversation_start {
190-
None
191-
} else {
192-
Some(
193-
Instant::now()
194-
+ Duration::from_secs(hook.cache_ttl_seconds.unwrap_or(DEFAULT_CACHE_TTL_SECONDS)),
195-
)
197+
let expiry = match hook.trigger {
198+
HookTrigger::ConversationStart => None,
199+
HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)),
196200
};
197201

198-
self.insert_cache(&hook.name, CachedHook {
202+
self.insert_cache(hook, CachedHook {
199203
output: result.as_ref().cloned().unwrap(),
200204
expiry,
201205
});
@@ -247,8 +251,14 @@ impl HookExecutor {
247251
}
248252

249253
/// Will return a cached hook's output if it exists and isn't expired.
250-
fn get_cache(&self, name: &str) -> Option<String> {
251-
self.execution_cache.get(name).and_then(|o| {
254+
fn get_cache(&self, hook: &Hook) -> Option<String> {
255+
let cache = if hook.is_global {
256+
&self.global_cache
257+
} else {
258+
&self.profile_cache
259+
};
260+
261+
cache.get(&hook.name).and_then(|o| {
252262
if let Some(expiry) = o.expiry {
253263
if Instant::now() < expiry {
254264
Some(o.output.clone())
@@ -261,7 +271,13 @@ impl HookExecutor {
261271
})
262272
}
263273

264-
fn insert_cache(&mut self, name: &str, hook_output: CachedHook) {
265-
self.execution_cache.insert(name.to_string(), hook_output);
274+
fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) {
275+
let cache = if hook.is_global {
276+
&mut self.global_cache
277+
} else {
278+
&mut self.profile_cache
279+
};
280+
281+
cache.insert(hook.name.clone(), hook_output);
266282
}
267283
}

0 commit comments

Comments
 (0)