Skip to content

Commit 67879aa

Browse files
fix: formatting hooks to enable correct caching with the API (#461)
1 parent 9c083cf commit 67879aa

File tree

4 files changed

+58
-25
lines changed

4 files changed

+58
-25
lines changed

crates/chat-cli/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ rustyline = { version = "15.0.0", features = [
117117
semantic_search_client = { path = "../semantic_search_client" }
118118
semver = { version = "1.0.26", features = ["serde"] }
119119
serde = { version = "1.0.219", features = ["derive", "rc"] }
120-
serde_json = "1.0.140"
120+
serde_json = { version = "1.0.140", features = ["preserve_order"] }
121121
sha2 = "0.10.9"
122122
shell-color = "1.0.0"
123123
shell-words = "1.1.0"

crates/chat-cli/src/cli/chat/cli/hooks.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -873,8 +873,14 @@ mod tests {
873873
manager.add_hook(&os, "hook2".to_string(), hook2, false).await?;
874874

875875
// Run the hooks
876-
let results = manager.run_hooks(&mut vec![]).await.unwrap();
877-
assert_eq!(results.len(), 2); // Should include both hooks
876+
let results = manager
877+
.run_hooks(HookTrigger::ConversationStart, &mut vec![])
878+
.await
879+
.unwrap();
880+
assert_eq!(results.len(), 2);
881+
882+
let results = manager.run_hooks(HookTrigger::PerPrompt, &mut vec![]).await.unwrap();
883+
assert_eq!(results.len(), 0);
878884

879885
Ok(())
880886
}
@@ -889,14 +895,20 @@ mod tests {
889895
manager.add_hook(&os, "profile_hook".to_string(), hook1, false).await?;
890896
manager.add_hook(&os, "global_hook".to_string(), hook2, true).await?;
891897

892-
let results = manager.run_hooks(&mut vec![]).await.unwrap();
898+
let results = manager
899+
.run_hooks(HookTrigger::ConversationStart, &mut vec![])
900+
.await
901+
.unwrap();
893902
assert_eq!(results.len(), 2); // Should include both hooks
894903

895904
// Create and switch to a new profile
896905
manager.create_profile(&os, "test_profile").await?;
897906
manager.switch_profile(&os, "test_profile").await?;
898907

899-
let results = manager.run_hooks(&mut vec![]).await.unwrap();
908+
let results = manager
909+
.run_hooks(HookTrigger::ConversationStart, &mut vec![])
910+
.await
911+
.unwrap();
900912
assert_eq!(results.len(), 1); // Should include global hook
901913
assert_eq!(results[0].0.name, "global_hook");
902914

crates/chat-cli/src/cli/chat/context.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use serde::{
1717
};
1818
use tracing::debug;
1919

20+
use super::cli::hooks::HookTrigger;
2021
use super::consts::CONTEXT_FILES_MAX_SIZE;
2122
use super::util::drop_matched_context_files;
2223
use crate::cli::chat::ChatError;
@@ -573,7 +574,11 @@ impl ContextManager {
573574
/// * `updates` - output stream to write hook run status to if Some, else do nothing if None
574575
/// # Returns
575576
/// A vector containing pairs of a [`Hook`] definition and its execution output
576-
pub async fn run_hooks(&mut self, output: &mut impl Write) -> Result<Vec<(Hook, String)>, ChatError> {
577+
pub async fn run_hooks(
578+
&mut self,
579+
trigger: HookTrigger,
580+
output: &mut impl Write,
581+
) -> Result<Vec<(Hook, String)>, ChatError> {
577582
let mut hooks: Vec<&Hook> = Vec::new();
578583

579584
// Set internal hook states
@@ -583,10 +588,14 @@ impl ContextManager {
583588
];
584589

585590
for (hook_list, is_global) in configs {
586-
hooks.extend(hook_list.iter_mut().map(|(name, h)| {
587-
h.name = name.clone();
588-
h.is_global = is_global;
589-
&*h
591+
hooks.extend(hook_list.iter_mut().filter_map(|(name, h)| {
592+
if h.trigger == trigger {
593+
h.name = name.clone();
594+
h.is_global = is_global;
595+
Some(&*h)
596+
} else {
597+
None
598+
}
590599
}));
591600
}
592601

crates/chat-cli/src/cli/chat/conversation.rs

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -324,19 +324,20 @@ impl ConversationState {
324324
/// Returns a [FigConversationState] capable of being sent by [api_client::StreamingClient].
325325
///
326326
/// Params:
327-
/// - `run_hooks` - whether hooks should be executed and included as context
327+
/// - `run_perprompt_hooks` - whether per-prompt hooks should be executed and included as
328+
/// context
328329
pub async fn as_sendable_conversation_state(
329330
&mut self,
330331
os: &Os,
331332
stderr: &mut impl Write,
332-
run_hooks: bool,
333+
run_perprompt_hooks: bool,
333334
) -> Result<FigConversationState, ChatError> {
334335
debug_assert!(self.next_message.is_some());
335336
self.enforce_conversation_invariants();
336337
self.history.drain(self.valid_history_range.1..);
337338
self.history.drain(..self.valid_history_range.0);
338339

339-
let context = self.backend_conversation_state(os, run_hooks, stderr).await?;
340+
let context = self.backend_conversation_state(os, run_perprompt_hooks, stderr).await?;
340341
if !context.dropped_context_files.is_empty() {
341342
execute!(
342343
stderr,
@@ -390,21 +391,22 @@ impl ConversationState {
390391
pub async fn backend_conversation_state(
391392
&mut self,
392393
os: &Os,
393-
run_hooks: bool,
394+
run_perprompt_hooks: bool,
394395
output: &mut impl Write,
395396
) -> Result<BackendConversationState<'_>, ChatError> {
396397
self.update_state(false).await;
397398
self.enforce_conversation_invariants();
398399

399-
// Run hooks and add to conversation start and next user message.
400400
let mut conversation_start_context = None;
401-
if let (true, Some(cm)) = (run_hooks, self.context_manager.as_mut()) {
402-
let hook_results = cm.run_hooks(output).await?;
403-
conversation_start_context = Some(format_hook_context(hook_results.iter(), HookTrigger::ConversationStart));
404-
405-
// add per prompt content to next_user_message if available
406-
if let Some(next_message) = self.next_message.as_mut() {
407-
next_message.additional_context = format_hook_context(hook_results.iter(), HookTrigger::PerPrompt);
401+
if let Some(cm) = self.context_manager.as_mut() {
402+
let conv_start = cm.run_hooks(HookTrigger::ConversationStart, output).await?;
403+
conversation_start_context = format_hook_context(&conv_start, HookTrigger::ConversationStart);
404+
405+
if let (true, Some(next_message)) = (run_perprompt_hooks, self.next_message.as_mut()) {
406+
let per_prompt = cm.run_hooks(HookTrigger::PerPrompt, output).await?;
407+
if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::PerPrompt) {
408+
next_message.additional_context = ctx;
409+
}
408410
}
409411
}
410412

@@ -756,7 +758,17 @@ impl From<InputSchema> for ToolInputSchema {
756758
}
757759
}
758760

759-
fn format_hook_context<'a>(hook_results: impl IntoIterator<Item = &'a (Hook, String)>, trigger: HookTrigger) -> String {
761+
/// Formats hook output to be used within context blocks (e.g., in context messages or in new user
762+
/// prompts).
763+
///
764+
/// # Returns
765+
/// [Option::Some] if `hook_results` is not empty and at least one hook has content. Otherwise,
766+
/// [Option::None]
767+
fn format_hook_context(hook_results: &[(Hook, String)], trigger: HookTrigger) -> Option<String> {
768+
if hook_results.iter().all(|(_, content)| content.is_empty()) {
769+
return None;
770+
}
771+
760772
let mut context_content = String::new();
761773

762774
context_content.push_str(CONTEXT_ENTRY_START_HEADER);
@@ -766,11 +778,11 @@ fn format_hook_context<'a>(hook_results: impl IntoIterator<Item = &'a (Hook, Str
766778
}
767779
context_content.push_str("\n\n");
768780

769-
for (hook, output) in hook_results.into_iter().filter(|(h, _)| h.trigger == trigger) {
781+
for (hook, output) in hook_results.iter().filter(|(h, _)| h.trigger == trigger) {
770782
context_content.push_str(&format!("'{}': {output}\n\n", &hook.name));
771783
}
772784
context_content.push_str(CONTEXT_ENTRY_END_HEADER);
773-
context_content
785+
Some(context_content)
774786
}
775787

776788
fn enforce_conversation_invariants(

0 commit comments

Comments
 (0)