Skip to content

Commit 8f638b0

Browse files
authored
refactor(tools): add context to report_issue tool (#980)
- Add the ability to add context to `Tool`s, without passing in awkward args to the generic invoke function. - Graceful tool error if context fails to be set (should be impossible).
1 parent 65d7f50 commit 8f638b0

File tree

3 files changed

+58
-32
lines changed

3 files changed

+58
-32
lines changed

crates/q_cli/src/cli/chat/mod.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,13 +1013,7 @@ where
10131013
style::Print(format!("{}\n", "▔".repeat(terminal_width))),
10141014
style::SetForegroundColor(Color::Reset),
10151015
)?;
1016-
let invoke_result = tool
1017-
.1
1018-
.invoke(&self.ctx, &mut self.output, GhIssueContext {
1019-
conversation_state: &self.conversation_state,
1020-
failed_request_ids: &self.failed_request_ids,
1021-
})
1022-
.await;
1016+
let invoke_result = tool.1.invoke(&self.ctx, &mut self.output).await;
10231017

10241018
if self.interactive && self.spinner.is_some() {
10251019
queue!(
@@ -1317,6 +1311,9 @@ where
13171311
.utterance_id(self.conversation_state.message_id().map(|s| s.to_string()));
13181312
match Tool::try_from(tool_use) {
13191313
Ok(mut tool) => {
1314+
// Apply non-Q-generated context to tools
1315+
self.contextualize_tool(&mut tool);
1316+
13201317
match tool.validate(&self.ctx).await {
13211318
Ok(()) => {
13221319
tool_telemetry.is_valid = Some(true);
@@ -1400,6 +1397,27 @@ where
14001397
}
14011398
}
14021399

1400+
/// Apply program context to tools that Q may not have.
1401+
// We cannot attach this any other way because Tools are constructed by deserializing
1402+
// output from Amazon Q.
1403+
// TODO: Is there a better way?
1404+
fn contextualize_tool(&self, tool: &mut Tool) {
1405+
#[allow(clippy::single_match)]
1406+
match tool {
1407+
Tool::GhIssue(gh_issue) => {
1408+
gh_issue.set_context(GhIssueContext {
1409+
// Ideally we avoid cloning, but this function is not called very often.
1410+
// Using references with lifetimes requires a large refactor, and Arc<Mutex<T>>
1411+
// seems like overkill and may incur some performance cost anyway.
1412+
context_manager: self.conversation_state.context_manager.clone(),
1413+
transcript: self.conversation_state.transcript.clone(),
1414+
failed_request_ids: self.failed_request_ids.clone(),
1415+
});
1416+
},
1417+
_ => (),
1418+
};
1419+
}
1420+
14031421
async fn print_tool_descriptions(&mut self, tool_uses: &[QueuedTool]) -> Result<(), ChatError> {
14041422
let terminal_width = self.terminal_width();
14051423
for (_, tool) in tool_uses.iter() {

crates/q_cli/src/cli/chat/tools/gh_issue.rs

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::VecDeque;
12
use std::io::Write;
23

34
use crossterm::style::Color;
@@ -8,12 +9,13 @@ use crossterm::{
89
use eyre::{
910
Result,
1011
WrapErr,
12+
eyre,
1113
};
1214
use fig_os_shim::Context;
1315
use serde::Deserialize;
1416

1517
use super::InvokeOutput;
16-
use crate::cli::chat::conversation_state::ConversationState;
18+
use crate::cli::chat::context::ContextManager;
1719
use crate::cli::issue::IssueCreator;
1820

1921
#[derive(Debug, Clone, Deserialize)]
@@ -22,25 +24,36 @@ pub struct GhIssue {
2224
pub expected_behavior: Option<String>,
2325
pub actual_behavior: Option<String>,
2426
pub steps_to_reproduce: Option<String>,
27+
28+
#[serde(skip_deserializing)]
29+
pub context: Option<GhIssueContext>,
2530
}
2631

27-
pub struct GhIssueContext<'a> {
28-
pub conversation_state: &'a ConversationState,
29-
pub failed_request_ids: &'a Vec<String>,
32+
#[derive(Debug, Clone)]
33+
pub struct GhIssueContext {
34+
pub context_manager: Option<ContextManager>,
35+
pub transcript: VecDeque<String>,
36+
pub failed_request_ids: Vec<String>,
3037
}
3138

3239
/// Max amount of user chat + assistant recent chat messages to include in the issue.
3340
const MAX_TRANSCRIPT_LEN: usize = 10;
3441

3542
impl GhIssue {
36-
pub async fn invoke(&self, _updates: impl Write, context: GhIssueContext<'_>) -> Result<InvokeOutput> {
43+
pub async fn invoke(&self, _updates: impl Write) -> Result<InvokeOutput> {
44+
let Some(context) = self.context.as_ref() else {
45+
return Err(eyre!(
46+
"report_issue: Required tool context (GhIssueContext) not set by the program."
47+
));
48+
};
49+
3750
// Prepare additional details from the chat session
38-
let additional_environment = [Self::get_request_ids(&context), Self::get_context(&context).await].join("\n\n");
51+
let additional_environment = [Self::get_request_ids(context), Self::get_context(context).await].join("\n\n");
3952

4053
// Add chat history to the actual behavior text.
4154
let actual_behavior = self.actual_behavior.as_ref().map_or_else(
42-
|| Self::get_transcript(&context),
43-
|behavior| format!("{behavior}\n\n{}\n", Self::get_transcript(&context)),
55+
|| Self::get_transcript(context),
56+
|behavior| format!("{behavior}\n\n{}\n", Self::get_transcript(context)),
4457
);
4558

4659
let _ = IssueCreator {
@@ -57,9 +70,13 @@ impl GhIssue {
5770
Ok(Default::default())
5871
}
5972

60-
fn get_transcript(context: &GhIssueContext<'_>) -> String {
73+
pub fn set_context(&mut self, context: GhIssueContext) {
74+
self.context = Some(context);
75+
}
76+
77+
fn get_transcript(context: &GhIssueContext) -> String {
6178
let mut transcript_str = String::from("```\n[chat-transcript]\n");
62-
let transcript: Vec<String> = context.conversation_state.transcript
79+
let transcript: Vec<String> = context.transcript
6380
.iter()
6481
.rev() // To take last N items
6582
.scan(0, |user_msg_count, line| {
@@ -89,7 +106,7 @@ impl GhIssue {
89106
transcript_str
90107
}
91108

92-
fn get_request_ids(context: &GhIssueContext<'_>) -> String {
109+
fn get_request_ids(context: &GhIssueContext) -> String {
93110
format!(
94111
"[chat-failed_request_ids]\n{}",
95112
if context.failed_request_ids.is_empty() {
@@ -100,9 +117,9 @@ impl GhIssue {
100117
)
101118
}
102119

103-
async fn get_context(context: &GhIssueContext<'_>) -> String {
120+
async fn get_context(context: &GhIssueContext) -> String {
104121
let mut ctx_str = "[chat-context]\n".to_string();
105-
let Some(ctx_manager) = &context.conversation_state.context_manager else {
122+
let Some(ctx_manager) = &context.context_manager else {
106123
ctx_str.push_str("No context available.");
107124
return ctx_str;
108125
};

crates/q_cli/src/cli/chat/tools/mod.rs

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@ use fig_api_client::model::{
2424
use fig_os_shim::Context;
2525
use fs_read::FsRead;
2626
use fs_write::FsWrite;
27-
use gh_issue::{
28-
GhIssue,
29-
GhIssueContext,
30-
};
27+
use gh_issue::GhIssue;
3128
use serde::Deserialize;
3229
use use_aws::UseAws;
3330

@@ -81,19 +78,13 @@ impl Tool {
8178
}
8279

8380
/// Invokes the tool asynchronously
84-
// TODO: Need to rework this to avoid passing in args meant for a single tool.
85-
pub async fn invoke(
86-
&self,
87-
context: &Context,
88-
updates: &mut impl Write,
89-
gh_issue_context: GhIssueContext<'_>,
90-
) -> Result<InvokeOutput> {
81+
pub async fn invoke(&self, context: &Context, updates: &mut impl Write) -> Result<InvokeOutput> {
9182
match self {
9283
Tool::FsRead(fs_read) => fs_read.invoke(context, updates).await,
9384
Tool::FsWrite(fs_write) => fs_write.invoke(context, updates).await,
9485
Tool::ExecuteBash(execute_bash) => execute_bash.invoke(updates).await,
9586
Tool::UseAws(use_aws) => use_aws.invoke(context, updates).await,
96-
Tool::GhIssue(gh_issue) => gh_issue.invoke(updates, gh_issue_context).await,
87+
Tool::GhIssue(gh_issue) => gh_issue.invoke(updates).await,
9788
}
9889
}
9990

0 commit comments

Comments
 (0)