Skip to content

Commit 8aa696d

Browse files
committed
adds logic to check tool use history invariants
1 parent aab54df commit 8aa696d

File tree

2 files changed

+105
-112
lines changed

2 files changed

+105
-112
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 104 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::collections::vec_deque::IterMut;
21
use std::collections::{
32
HashMap,
43
VecDeque,
@@ -33,6 +32,7 @@ use super::hooks::{
3332
};
3433
use super::message::{
3534
AssistantMessage,
35+
AssistantToolUse,
3636
ToolUseResult,
3737
ToolUseResultBlock,
3838
UserMessage,
@@ -314,127 +314,113 @@ impl ConversationState {
314314
tool_uses.iter().map(|t| t.id.as_str()),
315315
);
316316
}
317-
318-
// Here we also need to make sure that the tool result corresponds to one of the tools
319-
// in the list. Otherwise we will see validation error from the backend. There are three
320-
// such circumstances where intervention would be needed:
321-
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
322-
// which case we would automatically resolve this tool call to its correct name. This will NOT
323-
// result in an error in its tool result. The intervention here is to substitute the partial name
324-
// with its full name.
325-
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
326-
// could be referring to, in which case we WILL return an error in the tool result. The
327-
// intervention here is to substitute the ambiguous, partial name with a dummy.
328-
// 3. The model had decided to call a tool that does not exist. The intervention here is to
329-
// substitute the non-existent tool name with a dummy.
330-
let tool_use_results = user_msg.tool_use_results();
331-
if let Some(tool_use_results) = tool_use_results {
332-
// Note that we need to use the keys in tool manager's tn_map as the keys are the
333-
// actual tool names as exposed to the model and the backend. If we use the actual
334-
// names as they are recognized by their respective servers, we risk concluding
335-
// with false positives.
336-
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
337-
for result in tool_use_results {
338-
let tool_use_id = result.tool_use_id.as_str();
339-
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
340-
if let Some(tool_use) = corresponding_tool_use {
341-
if tool_name_list.contains(&tool_use.name.as_str()) {
342-
// If this tool matches of the tools in our list, this is not our
343-
// concern, error or not.
344-
continue;
345-
}
346-
if let ToolResultStatus::Error = result.status {
347-
// case 2 and 3
348-
tool_use.name = DUMMY_TOOL_NAME.to_string();
349-
tool_use.args = serde_json::json!({});
350-
} else {
351-
// case 1
352-
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
353-
// We should be able to find a match but if not we'll just treat it as
354-
// a dummy and move on
355-
if let Some(full_name) = full_name {
356-
tool_use.name = (*full_name).to_string();
357-
} else {
358-
tool_use.name = DUMMY_TOOL_NAME.to_string();
359-
tool_use.args = serde_json::json!({});
360-
}
361-
}
362-
}
363-
}
364-
}
317+
self.enforce_tool_use_history_invariants();
365318
}
366319
}
367320

368-
// Here we also need to make sure that the tool result corresponds to one of the tools
369-
// in the list. Otherwise we will see validation error from the backend. There are three
370-
// such circumstances where intervention would be needed:
371-
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
372-
// which case we would automatically resolve this tool call to its correct name. This will NOT
373-
// result in an error in its tool result. The intervention here is to substitute the partial name
374-
// with its full name.
375-
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
376-
// could be referring to, in which case we WILL return an error in the tool result. The
377-
// intervention here is to substitute the ambiguous, partial name with a dummy.
378-
// 3. The model had decided to call a tool that does not exist. The intervention here is to
379-
// substitute the non-existent tool name with a dummy.
380-
fn enforce_tool_use_invariants(&mut self, history_of_interest: &mut Vec<(UserMessage, AssistantMessage)>) {
321+
/// Here we also need to make sure that the tool result corresponds to one of the tools
322+
/// in the list. Otherwise we will see validation error from the backend. There are three
323+
/// such circumstances where intervention would be needed:
324+
/// 1. The model had decided to call a tool with its partial name AND there is only one such
325+
/// tool, in which case we would automatically resolve this tool call to its correct name.
326+
/// This will NOT result in an error in its tool result. The intervention here is to
327+
/// substitute the partial name with its full name.
328+
/// 2. The model had decided to call a tool with its partial name AND there are multiple tools
329+
/// it could be referring to, in which case we WILL return an error in the tool result. The
330+
/// intervention here is to substitute the ambiguous, partial name with a dummy.
331+
/// 3. The model had decided to call a tool that does not exist. The intervention here is to
332+
/// substitute the non-existent tool name with a dummy.
333+
pub fn enforce_tool_use_history_invariants(&mut self) {
381334
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
382-
let mut tool_uses = history_of_interest
383-
.iter_mut()
384-
.filter_map(|(_user_msg, asst_msg)| {
385-
if let AssistantMessage::ToolUse { ref mut tool_uses, .. } = asst_msg {
386-
Some(tool_uses)
387-
} else {
388-
None
335+
// We need to first determine what the range of interest is. There are two places where we
336+
// would call this function:
337+
// 1. When there are changes to the list of available tools, in which case we comb through the
338+
// entire conversation
339+
// 2. When we send a message, in which case we only examine the most recent entry
340+
let (tool_use_results, mut tool_uses) =
341+
if let (Some((_, AssistantMessage::ToolUse { ref mut tool_uses, .. })), Some(user_msg)) = (
342+
self.history
343+
.range_mut(self.valid_history_range.0..self.valid_history_range.1)
344+
.last(),
345+
&mut self.next_message,
346+
) {
347+
let tool_use_results = user_msg
348+
.tool_use_results()
349+
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>());
350+
let tool_uses = tool_uses.iter_mut().collect::<Vec<_>>();
351+
(tool_use_results, tool_uses)
352+
} else {
353+
self.history
354+
.iter_mut()
355+
.filter_map(|(user_msg, asst_msg)| {
356+
if let (Some(tool_use_results), AssistantMessage::ToolUse { ref mut tool_uses, .. }) =
357+
(user_msg.tool_use_results(), asst_msg)
358+
{
359+
Some((tool_use_results, tool_uses))
360+
} else {
361+
None
362+
}
363+
})
364+
.fold(
365+
(Vec::<&ToolUseResult>::new(), Vec::<&mut AssistantToolUse>::new()),
366+
|(mut tool_use_results, mut tool_uses), (results, uses)| {
367+
let mut results = results.iter().collect::<Vec<_>>();
368+
let mut uses = uses.iter_mut().collect::<Vec<_>>();
369+
tool_use_results.append(&mut results);
370+
tool_uses.append(&mut uses);
371+
(tool_use_results, tool_uses)
372+
},
373+
)
374+
};
375+
// Replace tool uses associated with tools that do not exist / no longer exist with
376+
// dummy (i.e. put them to sleep / dormant)
377+
for result in tool_use_results {
378+
let tool_use_id = result.tool_use_id.as_str();
379+
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
380+
if let Some(tool_use) = corresponding_tool_use {
381+
if tool_name_list.contains(&tool_use.name.as_str()) {
382+
// If this tool matches one of the tools in our list, this is not our
383+
// concern, error or not.
384+
continue;
389385
}
390-
})
391-
.flatten();
392-
let tool_use_results = if let Some(user_msg) = &self.next_message {
393-
// We only check to verify the last message if [Self::next_message] is set
394-
user_msg.tool_use_results().map(|arr| arr.iter().collect::<Vec<_>>())
395-
} else {
396-
// Otherwise, we check the entire conversation
397-
Some(
398-
history_of_interest
399-
.iter()
400-
.filter_map(|(user_msg, _)| user_msg.tool_use_results())
401-
.flatten()
402-
.collect::<Vec<_>>(),
403-
)
404-
};
405-
if let Some(tool_use_results) = tool_use_results {
406-
// Note that we need to use the keys in tool manager's tn_map as the keys are the
407-
// actual tool names as exposed to the model and the backend. If we use the actual
408-
// names as they are recognized by their respective servers, we risk concluding
409-
// with false positives.
410-
for result in tool_use_results {
411-
let tool_use_id = result.tool_use_id.as_str();
412-
let corresponding_tool_use = tool_uses.find(|tool_use| tool_use_id == tool_use.id);
413-
if let Some(tool_use) = corresponding_tool_use {
414-
if tool_name_list.contains(&tool_use.name.as_str()) {
415-
// If this tool matches of the tools in our list, this is not our
416-
// concern, error or not.
417-
continue;
418-
}
419-
if let ToolResultStatus::Error = result.status {
420-
// case 2 and 3
386+
if let ToolResultStatus::Error = result.status {
387+
// case 2 and 3
388+
tool_use.name = DUMMY_TOOL_NAME.to_string();
389+
tool_use.args = serde_json::json!({});
390+
} else {
391+
// case 1
392+
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
393+
// We should be able to find a match but if not we'll just treat it as
394+
// a dummy and move on
395+
if let Some(full_name) = full_name {
396+
tool_use.name = (*full_name).to_string();
397+
} else {
421398
tool_use.name = DUMMY_TOOL_NAME.to_string();
422399
tool_use.args = serde_json::json!({});
423-
} else {
424-
// case 1
425-
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
426-
// We should be able to find a match but if not we'll just treat it as
427-
// a dummy and move on
428-
if let Some(full_name) = full_name {
429-
tool_use.name = (*full_name).to_string();
430-
} else {
431-
tool_use.name = DUMMY_TOOL_NAME.to_string();
432-
tool_use.args = serde_json::json!({});
433-
}
434400
}
435401
}
436402
}
437403
}
404+
// Revive tools that were previously dormant if they now correspond to one of the tools in
405+
// our list of available tools. Note that this check only works because tn_map does NOT
406+
// contain names of native tools.
407+
for tool_use in tool_uses {
408+
if tool_use.name == DUMMY_TOOL_NAME
409+
&& tool_use
410+
.orig_name
411+
.as_ref()
412+
.is_some_and(|name| tool_name_list.contains(&(*name).as_str()))
413+
{
414+
tool_use.name = tool_use
415+
.orig_name
416+
.as_ref()
417+
.map_or(DUMMY_TOOL_NAME.to_string(), |name| name.clone());
418+
tool_use.args = tool_use
419+
.orig_args
420+
.as_ref()
421+
.map_or(serde_json::json!({}), |args| args.clone());
422+
}
423+
}
438424
}
439425

440426
pub fn add_tool_results(&mut self, tool_results: Vec<ToolUseResult>) {
@@ -492,7 +478,7 @@ impl ConversationState {
492478
return;
493479
}
494480
self.tool_manager.update().await;
495-
// TODO: make this more targetted so we don't have to clone the entire list of tools
481+
// TODO: make this more targeted so we don't have to clone the entire list of tools
496482
self.tools = self
497483
.tool_manager
498484
.schema
@@ -509,6 +495,10 @@ impl ConversationState {
509495
acc
510496
});
511497
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
498+
// We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
499+
// here as well because when it's being called in [Self::enforce_conversation_invariants]
500+
// it is only checking the last entry.
501+
self.enforce_tool_use_history_invariants();
512502
}
513503

514504
/// Returns a conversation state representation which reflects the exact conversation to send
@@ -1108,6 +1098,7 @@ mod tests {
11081098
id: "tool_id".to_string(),
11091099
name: "tool name".to_string(),
11101100
args: serde_json::Value::Null,
1101+
..Default::default()
11111102
}]),
11121103
&mut database,
11131104
);
@@ -1138,6 +1129,7 @@ mod tests {
11381129
id: "tool_id".to_string(),
11391130
name: "tool name".to_string(),
11401131
args: serde_json::Value::Null,
1132+
..Default::default()
11411133
}]),
11421134
&mut database,
11431135
);

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ impl ChatContext {
545545
existing_conversation = true;
546546
input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned()));
547547
prior.tool_manager = tool_manager;
548+
prior.enforce_tool_use_history_invariants();
548549
prior
549550
},
550551
None => {

0 commit comments

Comments
 (0)