diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 3428af29c7..23d5fdb434 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -32,6 +32,7 @@ use super::hooks::{ }; use super::message::{ AssistantMessage, + AssistantToolUse, ToolUseResult, ToolUseResultBlock, UserMessage, @@ -344,53 +345,122 @@ impl ConversationState { tool_uses.iter().map(|t| t.id.as_str()), ); } + } - // Here we also need to make sure that the tool result corresponds to one of the tools - // in the list. Otherwise we will see validation error from the backend. There are three - // such circumstances where intervention would be needed: - // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in - // which case we would automatically resolve this tool call to its correct name. This will NOT - // result in an error in its tool result. The intervention here is to substitute the partial name - // with its full name. - // 2. The model had decided to call a tool with its partial name AND there are multiple tools it - // could be referring to, in which case we WILL return an error in the tool result. The - // intervention here is to substitute the ambiguous, partial name with a dummy. - // 3. The model had decided to call a tool that does not exist. The intervention here is to - // substitute the non-existent tool name with a dummy. - let tool_use_results = user_msg.tool_use_results(); - if let Some(tool_use_results) = tool_use_results { - // Note that we need to use the keys in tool manager's tn_map as the keys are the - // actual tool names as exposed to the model and the backend. If we use the actual - // names as they are recognized by their respective servers, we risk concluding - // with false positives. 
- let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::>(); - for result in tool_use_results { - let tool_use_id = result.tool_use_id.as_str(); - let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id); - if let Some(tool_use) = corresponding_tool_use { - if tool_name_list.contains(&tool_use.name.as_str()) { - // If this tool matches of the tools in our list, this is not our - // concern, error or not. - continue; - } - if let ToolResultStatus::Error = result.status { - // case 2 and 3 - tool_use.name = DUMMY_TOOL_NAME.to_string(); - tool_use.args = serde_json::json!({}); - } else { - // case 1 - let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name)); - // We should be able to find a match but if not we'll just treat it as - // a dummy and move on - if let Some(full_name) = full_name { - tool_use.name = (*full_name).to_string(); - } else { - tool_use.name = DUMMY_TOOL_NAME.to_string(); - tool_use.args = serde_json::json!({}); - } - } + self.enforce_tool_use_history_invariants(true); + } + + /// Makes sure that every tool use with a recorded tool result names one of the tools + /// in `tn_map`. Otherwise we will see a validation error from the backend. There are three + /// such circumstances where intervention would be needed: + /// 1. The model had decided to call a tool with its partial name AND there is only one such + /// tool, in which case we would automatically resolve this tool call to its correct name. + /// This will NOT result in an error in its tool result. The intervention here is to + /// substitute the partial name with its full name. + /// 2. The model had decided to call a tool with its partial name AND there are multiple tools + /// it could be referring to, in which case we WILL return an error in the tool result. The + /// intervention here is to substitute the ambiguous, partial name with a dummy. + /// 3. 
The model had decided to call a tool that does not exist. The intervention here is to + /// substitute the non-existent tool name with a dummy. + pub fn enforce_tool_use_history_invariants(&mut self, last_only: bool) { + let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::>(); + // We need to first determine what the range of interest is. There are two places where we + // would call this function: + // 1. When there are changes to the list of available tools, in which case we comb through the + // entire conversation + // 2. When we send a message, in which case we only examine the most recent entry + let (tool_use_results, mut tool_uses) = if last_only { + if let (Some((_, AssistantMessage::ToolUse { ref mut tool_uses, .. })), Some(user_msg)) = ( + self.history + .range_mut(self.valid_history_range.0..self.valid_history_range.1) + .last(), + &mut self.next_message, + ) { + let tool_use_results = user_msg + .tool_use_results() + .map_or(Vec::new(), |results| results.iter().collect::>()); + let tool_uses = tool_uses.iter_mut().collect::>(); + (tool_use_results, tool_uses) + } else { + (Vec::new(), Vec::new()) + } + } else { + let tool_use_results = self.next_message.as_ref().map_or(Vec::new(), |user_msg| { + user_msg + .tool_use_results() + .map_or(Vec::new(), |results| results.iter().collect::>()) + }); + self.history + .iter_mut() + .filter_map(|(user_msg, asst_msg)| { + if let (Some(tool_use_results), AssistantMessage::ToolUse { ref mut tool_uses, .. 
}) = + (user_msg.tool_use_results(), asst_msg) + { + Some((tool_use_results, tool_uses)) + } else { + None } + }) + .fold( + (tool_use_results, Vec::<&mut AssistantToolUse>::new()), + |(mut tool_use_results, mut tool_uses), (results, uses)| { + let mut results = results.iter().collect::>(); + let mut uses = uses.iter_mut().collect::>(); + tool_use_results.append(&mut results); + tool_uses.append(&mut uses); + (tool_use_results, tool_uses) + }, + ) + }; + + // Replace tool uses associated with tools that do not exist / no longer exist with + // dummy (i.e. put them to sleep / dormant) + for result in tool_use_results { + let tool_use_id = result.tool_use_id.as_str(); + let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id); + if let Some(tool_use) = corresponding_tool_use { + if tool_name_list.contains(&tool_use.name.as_str()) { + // If this tool matches one of the tools in our list, this is not our + // concern, error or not. + continue; } + if let ToolResultStatus::Error = result.status { + // case 2 and 3 + tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); + } else { + // case 1 + let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name)); + // We should be able to find a match but if not we'll just treat it as + // a dummy and move on + if let Some(full_name) = full_name { + tool_use.name = (*full_name).to_string(); + } else { + tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); + } + } + } + } + + // Revive tools that were previously dormant if they now correspond to one of the tools in + // our list of available tools. Note that this check only works because tn_map does NOT + // contain names of native tools. NOTE(review): confirm orig_name is populated before a tool use is dummied. 
+ for tool_use in tool_uses { + if tool_use.name == DUMMY_TOOL_NAME + && tool_use + .orig_name + .as_ref() + .is_some_and(|name| tool_name_list.contains(&(*name).as_str())) + { + tool_use.name = tool_use + .orig_name + .as_ref() + .map_or(DUMMY_TOOL_NAME.to_string(), |name| name.clone()); + tool_use.args = tool_use + .orig_args + .as_ref() + .map_or(serde_json::json!({}), |args| args.clone()); } } } @@ -419,7 +489,6 @@ impl ConversationState { /// - `run_hooks` - whether hooks should be executed and included as context pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState { debug_assert!(self.next_message.is_some()); - self.update_state().await; self.enforce_conversation_invariants(); self.history.drain(self.valid_history_range.1..); self.history.drain(..self.valid_history_range.0); @@ -451,6 +520,7 @@ impl ConversationState { return; } self.tool_manager.update().await; + // TODO: make this more targeted so we don't have to clone the entire list of tools self.tools = self .tool_manager .schema @@ -467,6 +537,10 @@ impl ConversationState { acc }); self.tool_manager.has_new_stuff.store(false, Ordering::Release); + // We call this in [Self::enforce_conversation_invariants] as well. But we need to call it + // here as well because when it's being called in [Self::enforce_conversation_invariants] + // it is only checking the last entry. 
+ self.enforce_tool_use_history_invariants(false); } /// Returns a conversation state representation which reflects the exact conversation to send @@ -1066,6 +1140,7 @@ mod tests { id: "tool_id".to_string(), name: "tool name".to_string(), args: serde_json::Value::Null, + ..Default::default() }]), &mut database, ); @@ -1096,6 +1171,7 @@ mod tests { id: "tool_id".to_string(), name: "tool name".to_string(), args: serde_json::Value::Null, + ..Default::default() }]), &mut database, ); diff --git a/crates/chat-cli/src/cli/chat/message.rs b/crates/chat-cli/src/cli/chat/message.rs index d6361c3c73..733a18f581 100644 --- a/crates/chat-cli/src/cli/chat/message.rs +++ b/crates/chat-cli/src/cli/chat/message.rs @@ -342,14 +342,18 @@ impl From for AssistantResponseMessage { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct AssistantToolUse { /// The ID for the tool request. pub id: String, - /// The name for the tool. + /// The name for the tool as exposed to the model pub name: String, - /// The input to pass to the tool. 
+ /// Original name of the tool, restored when a dormant (dummy) tool use is revived + pub orig_name: Option, + /// The input to pass to the tool as exposed to the model pub args: serde_json::Value, + /// Original input passed to the tool, restored when a dormant (dummy) tool use is revived + pub orig_args: Option, } impl From for ToolUse { @@ -368,6 +372,7 @@ impl From for AssistantToolUse { id: value.tool_use_id, name: value.name, args: document_to_serde_value(value.input.into()), + ..Default::default() } } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index d32ca0a37c..2fd1e5dbe7 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -572,6 +572,7 @@ impl ChatContext { cs.reload_serialized_state(Arc::clone(&ctx), Some(output.clone())).await; input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned())); cs.tool_manager = tool_manager; + cs.enforce_tool_use_history_invariants(false); cs } else { ConversationState::new( diff --git a/crates/chat-cli/src/cli/chat/parser.rs b/crates/chat-cli/src/cli/chat/parser.rs index e382065f8e..ab3bcd208b 100644 --- a/crates/chat-cli/src/cli/chat/parser.rs +++ b/crates/chat-cli/src/cli/chat/parser.rs @@ -204,6 +204,16 @@ impl ResponseParser { // including the tool contents. Essentially, the tool was too large. // Timeouts have been seen as short as ~1 minute, so setting the time to 30. 
let time_elapsed = start.elapsed(); + let args = serde_json::Value::Object( + [( + "key".to_string(), + serde_json::Value::String( + "WARNING: the actual tool use arguments were too complicated to be generated".to_string(), + ), + )] + .into_iter() + .collect(), + ); if self.peek().await?.is_none() && time_elapsed > Duration::from_secs(30) { error!( "Received an unexpected end of stream after spending ~{}s receiving tool events", @@ -212,17 +222,9 @@ impl ResponseParser { self.tool_uses.push(AssistantToolUse { id: id.clone(), name: name.clone(), - args: serde_json::Value::Object( - [( - "key".to_string(), - serde_json::Value::String( - "WARNING: the actual tool use arguments were too complicated to be generated" - .to_string(), - ), - )] - .into_iter() - .collect(), - ), + orig_name: Some(name.clone()), + args: args.clone(), + orig_args: Some(args.clone()), }); let message = Box::new(AssistantMessage::new_tool_use( Some(self.message_id.clone()), @@ -242,7 +244,12 @@ impl ResponseParser { // if the tool just does not need any input _ => serde_json::json!({}), }; - Ok(AssistantToolUse { id, name, args }) + Ok(AssistantToolUse { + id, + name, + args, + ..Default::default() + }) } /// Returns the next event in the [SendMessageOutput] without consuming it. diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index ea853100de..7ff79b3a03 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -909,13 +909,9 @@ impl ToolManager { }) }; let mut updated_servers = HashSet::::new(); - for (_server_name, (tool_name_map, specs)) in new_tools { - // In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there - // will be incoming tools with names that are already in the tn map, we will be writing - // over them (perhaps with the same information that they already had), and that's okay. 
- // In an event where a server has removed tools, the tools that are no longer available - // will linger in this map. This is also okay to not clean up as it does not affect the - // look up of tool names that are still active. + for (server_name, (tool_name_map, specs)) in new_tools { + let target = format!("{server_name}{NAMESPACE_DELIMITER}"); + self.tn_map.retain(|k, _| !k.starts_with(&target)); for (k, v) in tool_name_map { self.tn_map.insert(k, v); }