Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 121 additions & 45 deletions crates/chat-cli/src/cli/chat/conversation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use super::hooks::{
};
use super::message::{
AssistantMessage,
AssistantToolUse,
ToolUseResult,
ToolUseResultBlock,
UserMessage,
Expand Down Expand Up @@ -344,53 +345,122 @@ impl ConversationState {
tool_uses.iter().map(|t| t.id.as_str()),
);
}
}

// Here we also need to make sure that the tool result corresponds to one of the tools
// in the list. Otherwise we will see validation error from the backend. There are three
// such circumstances where intervention would be needed:
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
// which case we would automatically resolve this tool call to its correct name. This will NOT
// result in an error in its tool result. The intervention here is to substitute the partial name
// with its full name.
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
// could be referring to, in which case we WILL return an error in the tool result. The
// intervention here is to substitute the ambiguous, partial name with a dummy.
// 3. The model had decided to call a tool that does not exist. The intervention here is to
// substitute the non-existent tool name with a dummy.
let tool_use_results = user_msg.tool_use_results();
if let Some(tool_use_results) = tool_use_results {
// Note that we need to use the keys in tool manager's tn_map as the keys are the
// actual tool names as exposed to the model and the backend. If we use the actual
// names as they are recognized by their respective servers, we risk concluding
// with false positives.
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
for result in tool_use_results {
let tool_use_id = result.tool_use_id.as_str();
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
if let Some(tool_use) = corresponding_tool_use {
if tool_name_list.contains(&tool_use.name.as_str()) {
// If this tool matches of the tools in our list, this is not our
// concern, error or not.
continue;
}
if let ToolResultStatus::Error = result.status {
// case 2 and 3
tool_use.name = DUMMY_TOOL_NAME.to_string();
tool_use.args = serde_json::json!({});
} else {
// case 1
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
// We should be able to find a match but if not we'll just treat it as
// a dummy and move on
if let Some(full_name) = full_name {
tool_use.name = (*full_name).to_string();
} else {
tool_use.name = DUMMY_TOOL_NAME.to_string();
tool_use.args = serde_json::json!({});
}
}
self.enforce_tool_use_history_invariants(true);
}

/// Here we also need to make sure that the tool result corresponds to one of the tools
/// in the list. Otherwise we will see validation error from the backend. There are three
/// such circumstances where intervention would be needed:
/// 1. The model had decided to call a tool with its partial name AND there is only one such
/// tool, in which case we would automatically resolve this tool call to its correct name.
/// This will NOT result in an error in its tool result. The intervention here is to
/// substitute the partial name with its full name.
/// 2. The model had decided to call a tool with its partial name AND there are multiple tools
/// it could be referring to, in which case we WILL return an error in the tool result. The
/// intervention here is to substitute the ambiguous, partial name with a dummy.
/// 3. The model had decided to call a tool that does not exist. The intervention here is to
/// substitute the non-existent tool name with a dummy.
pub fn enforce_tool_use_history_invariants(&mut self, last_only: bool) {
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
// We need to first determine what the range of interest is. There are two places where we
// would call this function:
// 1. When there are changes to the list of available tools, in which case we comb through the
// entire conversation
// 2. When we send a message, in which case we only examine the most recent entry
let (tool_use_results, mut tool_uses) = if last_only {
if let (Some((_, AssistantMessage::ToolUse { ref mut tool_uses, .. })), Some(user_msg)) = (
self.history
.range_mut(self.valid_history_range.0..self.valid_history_range.1)
.last(),
&mut self.next_message,
) {
let tool_use_results = user_msg
.tool_use_results()
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>());
let tool_uses = tool_uses.iter_mut().collect::<Vec<_>>();
(tool_use_results, tool_uses)
} else {
(Vec::new(), Vec::new())
}
} else {
let tool_use_results = self.next_message.as_ref().map_or(Vec::new(), |user_msg| {
user_msg
.tool_use_results()
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>())
});
self.history
.iter_mut()
.filter_map(|(user_msg, asst_msg)| {
if let (Some(tool_use_results), AssistantMessage::ToolUse { ref mut tool_uses, .. }) =
(user_msg.tool_use_results(), asst_msg)
{
Some((tool_use_results, tool_uses))
} else {
None
}
})
.fold(
(tool_use_results, Vec::<&mut AssistantToolUse>::new()),
|(mut tool_use_results, mut tool_uses), (results, uses)| {
let mut results = results.iter().collect::<Vec<_>>();
let mut uses = uses.iter_mut().collect::<Vec<_>>();
tool_use_results.append(&mut results);
tool_uses.append(&mut uses);
(tool_use_results, tool_uses)
},
)
};

// Replace tool uses associated with tools that does not exist / no longer exists with
// dummy (i.e. put them to sleep / dormant)
for result in tool_use_results {
let tool_use_id = result.tool_use_id.as_str();
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
if let Some(tool_use) = corresponding_tool_use {
if tool_name_list.contains(&tool_use.name.as_str()) {
// If this tool matches of the tools in our list, this is not our
// concern, error or not.
continue;
}
if let ToolResultStatus::Error = result.status {
// case 2 and 3
tool_use.name = DUMMY_TOOL_NAME.to_string();
tool_use.args = serde_json::json!({});
} else {
// case 1
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
// We should be able to find a match but if not we'll just treat it as
// a dummy and move on
if let Some(full_name) = full_name {
tool_use.name = (*full_name).to_string();
} else {
tool_use.name = DUMMY_TOOL_NAME.to_string();
tool_use.args = serde_json::json!({});
}
}
}
}

// Revive tools that were previously dormant if they now corresponds to one of the tools in
// our list of available tools. Note that this check only works because tn_map does NOT
// contain names of native tools.
for tool_use in tool_uses {
if tool_use.name == DUMMY_TOOL_NAME
&& tool_use
.orig_name
.as_ref()
.is_some_and(|name| tool_name_list.contains(&(*name).as_str()))
{
tool_use.name = tool_use
.orig_name
.as_ref()
.map_or(DUMMY_TOOL_NAME.to_string(), |name| name.clone());
tool_use.args = tool_use
.orig_args
.as_ref()
.map_or(serde_json::json!({}), |args| args.clone());
}
}
}
Expand Down Expand Up @@ -419,7 +489,6 @@ impl ConversationState {
/// - `run_hooks` - whether hooks should be executed and included as context
pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState {
debug_assert!(self.next_message.is_some());
self.update_state().await;
self.enforce_conversation_invariants();
self.history.drain(self.valid_history_range.1..);
self.history.drain(..self.valid_history_range.0);
Expand Down Expand Up @@ -451,6 +520,7 @@ impl ConversationState {
return;
}
self.tool_manager.update().await;
// TODO: make this more targeted so we don't have to clone the entire list of tools
self.tools = self
.tool_manager
.schema
Expand All @@ -467,6 +537,10 @@ impl ConversationState {
acc
});
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
// We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
// here as well because when it's being called in [Self::enforce_conversation_invariants]
// it is only checking the last entry.
self.enforce_tool_use_history_invariants(false);
}

/// Returns a conversation state representation which reflects the exact conversation to send
Expand Down Expand Up @@ -1066,6 +1140,7 @@ mod tests {
id: "tool_id".to_string(),
name: "tool name".to_string(),
args: serde_json::Value::Null,
..Default::default()
}]),
&mut database,
);
Expand Down Expand Up @@ -1096,6 +1171,7 @@ mod tests {
id: "tool_id".to_string(),
name: "tool name".to_string(),
args: serde_json::Value::Null,
..Default::default()
}]),
&mut database,
);
Expand Down
11 changes: 8 additions & 3 deletions crates/chat-cli/src/cli/chat/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,18 @@ impl From<AssistantMessage> for AssistantResponseMessage {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct AssistantToolUse {
/// The ID for the tool request.
pub id: String,
/// The name for the tool.
/// The name for the tool as exposed to the model
pub name: String,
/// The input to pass to the tool.
/// Original name of the tool
pub orig_name: Option<String>,
/// The input to pass to the tool as exposed to the model
pub args: serde_json::Value,
/// Original input passed to the tool
pub orig_args: Option<serde_json::Value>,
}

impl From<AssistantToolUse> for ToolUse {
Expand All @@ -368,6 +372,7 @@ impl From<ToolUse> for AssistantToolUse {
id: value.tool_use_id,
name: value.name,
args: document_to_serde_value(value.input.into()),
..Default::default()
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/chat-cli/src/cli/chat/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ impl ChatContext {
cs.reload_serialized_state(Arc::clone(&ctx), Some(output.clone())).await;
input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned()));
cs.tool_manager = tool_manager;
cs.enforce_tool_use_history_invariants(false);
cs
} else {
ConversationState::new(
Expand Down
31 changes: 19 additions & 12 deletions crates/chat-cli/src/cli/chat/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ impl ResponseParser {
// including the tool contents. Essentially, the tool was too large.
// Timeouts have been seen as short as ~1 minute, so setting the time to 30.
let time_elapsed = start.elapsed();
let args = serde_json::Value::Object(
[(
"key".to_string(),
serde_json::Value::String(
"WARNING: the actual tool use arguments were too complicated to be generated".to_string(),
),
)]
.into_iter()
.collect(),
);
if self.peek().await?.is_none() && time_elapsed > Duration::from_secs(30) {
error!(
"Received an unexpected end of stream after spending ~{}s receiving tool events",
Expand All @@ -212,17 +222,9 @@ impl ResponseParser {
self.tool_uses.push(AssistantToolUse {
id: id.clone(),
name: name.clone(),
args: serde_json::Value::Object(
[(
"key".to_string(),
serde_json::Value::String(
"WARNING: the actual tool use arguments were too complicated to be generated"
.to_string(),
),
)]
.into_iter()
.collect(),
),
orig_name: Some(name.clone()),
args: args.clone(),
orig_args: Some(args.clone()),
});
let message = Box::new(AssistantMessage::new_tool_use(
Some(self.message_id.clone()),
Expand All @@ -242,7 +244,12 @@ impl ResponseParser {
// if the tool just does not need any input
_ => serde_json::json!({}),
};
Ok(AssistantToolUse { id, name, args })
Ok(AssistantToolUse {
id,
name,
args,
..Default::default()
})
}

/// Returns the next event in the [SendMessageOutput] without consuming it.
Expand Down
10 changes: 3 additions & 7 deletions crates/chat-cli/src/cli/chat/tool_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,13 +909,9 @@ impl ToolManager {
})
};
let mut updated_servers = HashSet::<ToolOrigin>::new();
for (_server_name, (tool_name_map, specs)) in new_tools {
// In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there
// will be incoming tools with names that are already in the tn map, we will be writing
// over them (perhaps with the same information that they already had), and that's okay.
// In an event where a server has removed tools, the tools that are no longer available
// will linger in this map. This is also okay to not clean up as it does not affect the
// look up of tool names that are still active.
for (server_name, (tool_name_map, specs)) in new_tools {
let target = format!("{server_name}{NAMESPACE_DELIMITER}");
self.tn_map.retain(|k, _| !k.starts_with(&target));
for (k, v) in tool_name_map {
self.tn_map.insert(k, v);
}
Expand Down
Loading