Skip to content

Commit f9b0891

Browse files
committed
moves tool manager to conversation state
1 parent 1206a6e commit f9b0891

File tree

10 files changed

+57
-38
lines changed

10 files changed

+57
-38
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ use super::token_counter::{
3737
CharCount,
3838
CharCounter,
3939
};
40+
use super::tool_manager::ToolManager;
4041
use super::tools::{
4142
InputSchema,
4243
QueuedTool,
@@ -85,6 +86,8 @@ pub struct ConversationState {
8586
pub tools: HashMap<ToolOrigin, Vec<Tool>>,
8687
/// Context manager for handling sticky context files
8788
pub context_manager: Option<ContextManager>,
89+
/// Tool manager for handling tool and mcp related activities
90+
pub tool_manager: ToolManager,
8891
/// Cached value representing the length of the user context message.
8992
context_message_length: Option<usize>,
9093
/// Stores the latest conversation summary created by /compact
@@ -99,6 +102,7 @@ impl ConversationState {
99102
tool_config: HashMap<String, ToolSpec>,
100103
profile: Option<String>,
101104
updates: Option<SharedWriter>,
105+
tool_manager: ToolManager,
102106
) -> Self {
103107
// Initialize context manager
104108
let context_manager = match ContextManager::new(ctx, None).await {
@@ -137,6 +141,7 @@ impl ConversationState {
137141
acc
138142
}),
139143
context_manager,
144+
tool_manager,
140145
context_message_length: None,
141146
latest_summary: None,
142147
updates,
@@ -926,6 +931,7 @@ mod tests {
926931
tool_manager.load_tools().await.unwrap(),
927932
None,
928933
None,
934+
tool_manager,
929935
)
930936
.await;
931937

@@ -944,12 +950,14 @@ mod tests {
944950
async fn test_conversation_state_history_handling_with_tool_results() {
945951
// Build a long conversation history of tool use results.
946952
let mut tool_manager = ToolManager::default();
953+
let tool_config = tool_manager.load_tools().await.unwrap();
947954
let mut conversation_state = ConversationState::new(
948955
Context::new(),
949956
"fake_conv_id",
950-
tool_manager.load_tools().await.unwrap(),
957+
tool_config.clone(),
951958
None,
952959
None,
960+
tool_manager.clone(),
953961
)
954962
.await;
955963
conversation_state.set_next_user_message("start".to_string()).await;
@@ -975,9 +983,10 @@ mod tests {
975983
let mut conversation_state = ConversationState::new(
976984
Context::new(),
977985
"fake_conv_id",
978-
tool_manager.load_tools().await.unwrap(),
986+
tool_config.clone(),
979987
None,
980988
None,
989+
tool_manager.clone(),
981990
)
982991
.await;
983992
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1016,6 +1025,7 @@ mod tests {
10161025
tool_manager.load_tools().await.unwrap(),
10171026
None,
10181027
None,
1028+
tool_manager,
10191029
)
10201030
.await;
10211031

@@ -1081,6 +1091,7 @@ mod tests {
10811091
tool_manager.load_tools().await.unwrap(),
10821092
None,
10831093
Some(SharedWriter::stdout()),
1094+
tool_manager,
10841095
)
10851096
.await;
10861097

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,6 @@ pub struct ChatContext {
505505
tool_use_telemetry_events: HashMap<String, ToolUseEventBuilder>,
506506
/// State used to keep track of tool use relation
507507
tool_use_status: ToolUseStatus,
508-
/// Abstraction that consolidates custom tools with native ones
509-
tool_manager: ToolManager,
510508
/// Any failed requests that could be useful for error report/debugging
511509
failed_request_ids: Vec<String>,
512510
/// Pending prompts to be sent
@@ -533,8 +531,15 @@ impl ChatContext {
533531
) -> Result<Self> {
534532
let ctx_clone = Arc::clone(&ctx);
535533
let output_clone = output.clone();
536-
let conversation_state =
537-
ConversationState::new(ctx_clone, conversation_id, tool_config, profile, Some(output_clone)).await;
534+
let conversation_state = ConversationState::new(
535+
ctx_clone,
536+
conversation_id,
537+
tool_config,
538+
profile,
539+
Some(output_clone),
540+
tool_manager,
541+
)
542+
.await;
538543
Ok(Self {
539544
ctx,
540545
settings,
@@ -550,7 +555,6 @@ impl ChatContext {
550555
conversation_state,
551556
tool_use_telemetry_events: HashMap::new(),
552557
tool_use_status: ToolUseStatus::Idle,
553-
tool_manager,
554558
failed_request_ids: Vec::new(),
555559
pending_prompts: VecDeque::new(),
556560
})
@@ -1217,6 +1221,7 @@ impl ChatContext {
12171221
#[cfg(unix)]
12181222
if let Some(ref context_manager) = self.conversation_state.context_manager {
12191223
let tool_names = self
1224+
.conversation_state
12201225
.tool_manager
12211226
.tn_map
12221227
.keys()
@@ -2152,9 +2157,10 @@ impl ChatContext {
21522157

21532158
match subcommand {
21542159
Some(ToolsSubcommand::Schema) => {
2155-
let schema_json = serde_json::to_string_pretty(&self.tool_manager.schema).map_err(|e| {
2156-
ChatError::Custom(format!("Error converting tool schema to string: {e}").into())
2157-
})?;
2160+
let schema_json = serde_json::to_string_pretty(&self.conversation_state.tool_manager.schema)
2161+
.map_err(|e| {
2162+
ChatError::Custom(format!("Error converting tool schema to string: {e}").into())
2163+
})?;
21582164
queue!(self.output, style::Print(schema_json), style::Print("\n"))?;
21592165
},
21602166
Some(ToolsSubcommand::Trust { tool_names }) => {
@@ -2368,7 +2374,7 @@ impl ChatContext {
23682374
},
23692375
Some(PromptsSubcommand::Get { mut get_command }) => {
23702376
let orig_input = get_command.orig_input.take();
2371-
let prompts = match self.tool_manager.get_prompt(get_command).await {
2377+
let prompts = match self.conversation_state.tool_manager.get_prompt(get_command).await {
23722378
Ok(resp) => resp,
23732379
Err(e) => {
23742380
match e {
@@ -2455,12 +2461,12 @@ impl ChatContext {
24552461
_ => None,
24562462
};
24572463
let terminal_width = self.terminal_width();
2458-
let mut prompts_wl = self.tool_manager.prompts.write().map_err(|e| {
2464+
let mut prompts_wl = self.conversation_state.tool_manager.prompts.write().map_err(|e| {
24592465
ChatError::Custom(
24602466
format!("Poison error encountered while retrieving prompts: {}", e).into(),
24612467
)
24622468
})?;
2463-
self.tool_manager.refresh_prompts(&mut prompts_wl)?;
2469+
self.conversation_state.tool_manager.refresh_prompts(&mut prompts_wl)?;
24642470
let mut longest_name = "";
24652471
let arg_pos = {
24662472
let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4;
@@ -3126,7 +3132,7 @@ impl ChatContext {
31263132
.set_tool_use_id(tool_use_id.clone())
31273133
.set_tool_name(tool_use.name.clone())
31283134
.utterance_id(self.conversation_state.message_id().map(|s| s.to_string()));
3129-
match self.tool_manager.get_tool_from_tool_use(tool_use) {
3135+
match self.conversation_state.tool_manager.get_tool_from_tool_use(tool_use) {
31303136
Ok(mut tool) => {
31313137
// Apply non-Q-generated context to tools
31323138
self.contextualize_tool(&mut tool);

crates/chat-cli/src/cli/chat/server_messenger.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::mcp_client::{
1313
ToolsListResult,
1414
};
1515

16+
#[allow(dead_code)]
1617
#[derive(Clone, Debug)]
1718
pub enum UpdateEventMessage {
1819
ToolsListResult {

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ enum OutOfSpecName {
596596

597597
type NewToolSpecs = Arc<Mutex<HashMap<String, (HashMap<String, String>, Vec<ToolSpec>)>>>;
598598

599-
#[derive(Default)]
599+
#[derive(Default, Debug)]
600600
/// Manages the lifecycle and interactions with tools from various sources, including MCP servers.
601601
/// This struct is responsible for initializing tools, handling tool requests, and maintaining
602602
/// a cache of available prompts from connected servers.
@@ -639,6 +639,21 @@ pub struct ToolManager {
639639
pub schema: HashMap<String, ToolSpec>,
640640
}
641641

642+
impl Clone for ToolManager {
643+
fn clone(&self) -> Self {
644+
Self {
645+
conversation_id: self.conversation_id.clone(),
646+
clients: self.clients.clone(),
647+
has_new_stuff: self.has_new_stuff.clone(),
648+
new_tool_specs: self.new_tool_specs.clone(),
649+
prompts: self.prompts.clone(),
650+
tn_map: self.tn_map.clone(),
651+
schema: self.schema.clone(),
652+
..Default::default()
653+
}
654+
}
655+
}
656+
642657
impl ToolManager {
643658
pub async fn load_tools(&mut self) -> eyre::Result<HashMap<String, ToolSpec>> {
644659
let tx = self.loading_status_sender.take();

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ use serde::{
1515
use tokio::sync::RwLock;
1616
use tracing::warn;
1717

18-
use super::{
19-
InvokeOutput,
20-
ToolSpec,
21-
};
18+
use super::InvokeOutput;
2219
use crate::cli::chat::CONTINUATION_LINE;
2320
use crate::cli::chat::token_counter::TokenCounter;
2421
use crate::mcp_client::{
@@ -88,32 +85,20 @@ impl CustomToolClient {
8885
})
8986
}
9087

91-
pub async fn init(&self) -> Result<(String, Vec<ToolSpec>)> {
88+
pub async fn init(&self) -> Result<()> {
9289
match self {
9390
CustomToolClient::Stdio {
9491
client,
95-
server_name,
9692
server_capabilities,
93+
..
9794
} => {
9895
// We'll need to first initialize. This is the handshake every client and server
9996
// needs to do before proceeding to anything else
10097
let cap = client.init().await?;
10198
// We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466
10299
// So don't worry about the tidiness for now
103-
let is_tool_supported = cap.tools.is_some();
104100
server_capabilities.write().await.replace(cap);
105-
// Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools
106-
// let tools = if is_tool_supported {
107-
// // And now we make the server tell us what tools they have
108-
// let resp = client.request("tools/list", None).await?;
109-
// match resp.result.and_then(|r| r.get("tools").cloned()) {
110-
// Some(value) => serde_json::from_value::<Vec<ToolSpec>>(value)?,
111-
// None => Default::default(),
112-
// }
113-
// } else {
114-
// Default::default()
115-
// };
116-
Ok((server_name.clone(), vec![]))
101+
Ok(())
117102
},
118103
}
119104
}

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ pub struct ClientConfig {
8282
pub env: Option<HashMap<String, String>>,
8383
}
8484

85+
#[allow(dead_code)]
8586
#[derive(Debug, Error)]
8687
pub enum ClientError {
8788
#[error(transparent)]
@@ -548,10 +549,6 @@ where
548549
)
549550
}
550551

551-
pub async fn shutdown(&self) -> Result<(), ClientError> {
552-
Ok(self.transport.shutdown().await?)
553-
}
554-
555552
fn get_id(&self) -> u64 {
556553
self.current_id.fetch_add(1, Ordering::SeqCst)
557554
}

crates/chat-cli/src/mcp_client/facilitator_types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use serde::{
55
use thiserror::Error;
66

77
/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination
8+
#[allow(clippy::enum_variant_names)]
89
#[derive(Debug, Clone, PartialEq, Eq)]
910
pub enum PaginationSupportedOps {
1011
ResourcesList,

crates/chat-cli/src/mcp_client/messenger.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use super::{
1111
/// consumer. It is through this interface secondary information (i.e. information that are needed
1212
/// to make requests to mcp servers) are obtained passively. Consumers of client can of course
1313
/// choose to "actively" retrieve these information via explicitly making these requests.
14+
#[allow(dead_code)]
1415
#[async_trait::async_trait]
1516
pub trait Messenger: std::fmt::Debug + Send + Sync + 'static {
1617
/// Sends the result of a tools list operation to the consumer

crates/chat-cli/src/mcp_client/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#![allow(dead_code)]
12
use std::collections::HashMap;
23
use std::sync::atomic::{
34
AtomicBool,

crates/chat-cli/src/mcp_client/transport/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ impl From<std::io::Error> for TransportError {
3131
}
3232
}
3333

34+
#[allow(dead_code)]
3435
#[async_trait::async_trait]
3536
pub trait Transport: Send + Sync + Debug + 'static {
3637
/// Sends a message over the transport layer.

0 commit comments

Comments
 (0)