Skip to content

Commit 26b3672

Browse files
committed
fix: tool use names in history being wrong
1 parent 7044cdc commit 26b3672

File tree

4 files changed

+50
-103
lines changed

4 files changed

+50
-103
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 45 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::collections::{
22
HashMap,
3+
HashSet,
34
VecDeque,
45
};
56
use std::sync::Arc;
@@ -32,7 +33,6 @@ use super::hooks::{
3233
};
3334
use super::message::{
3435
AssistantMessage,
35-
AssistantToolUse,
3636
ToolUseResult,
3737
ToolUseResultBlock,
3838
UserMessage,
@@ -60,7 +60,6 @@ use crate::api_client::model::{
6060
ToolInputSchema,
6161
ToolResult,
6262
ToolResultContentBlock,
63-
ToolResultStatus,
6463
ToolSpecification,
6564
ToolUse,
6665
UserInputMessage,
@@ -347,7 +346,7 @@ impl ConversationState {
347346
}
348347
}
349348

350-
self.enforce_tool_use_history_invariants(true);
349+
self.enforce_tool_use_history_invariants();
351350
}
352351

353352
/// Here we also need to make sure that the tool result corresponds to one of the tools
@@ -362,105 +361,51 @@ impl ConversationState {
362361
/// intervention here is to substitute the ambiguous, partial name with a dummy.
363362
/// 3. The model had decided to call a tool that does not exist. The intervention here is to
364363
/// substitute the non-existent tool name with a dummy.
365-
pub fn enforce_tool_use_history_invariants(&mut self, last_only: bool) {
366-
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
367-
// We need to first determine what the range of interest is. There are two places where we
368-
// would call this function:
369-
// 1. When there are changes to the list of available tools, in which case we comb through the
370-
// entire conversation
371-
// 2. When we send a message, in which case we only examine the most recent entry
372-
let (tool_use_results, mut tool_uses) = if last_only {
373-
if let (Some((_, AssistantMessage::ToolUse { ref mut tool_uses, .. })), Some(user_msg)) = (
374-
self.history
375-
.range_mut(self.valid_history_range.0..self.valid_history_range.1)
376-
.last(),
377-
&mut self.next_message,
378-
) {
379-
let tool_use_results = user_msg
380-
.tool_use_results()
381-
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>());
382-
let tool_uses = tool_uses.iter_mut().collect::<Vec<_>>();
383-
(tool_use_results, tool_uses)
384-
} else {
385-
(Vec::new(), Vec::new())
386-
}
387-
} else {
388-
let tool_use_results = self.next_message.as_ref().map_or(Vec::new(), |user_msg| {
389-
user_msg
390-
.tool_use_results()
391-
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>())
392-
});
393-
self.history
394-
.iter_mut()
395-
.filter_map(|(user_msg, asst_msg)| {
396-
if let (Some(tool_use_results), AssistantMessage::ToolUse { ref mut tool_uses, .. }) =
397-
(user_msg.tool_use_results(), asst_msg)
398-
{
399-
Some((tool_use_results, tool_uses))
400-
} else {
401-
None
402-
}
364+
pub fn enforce_tool_use_history_invariants(&mut self) {
365+
let tool_names: HashSet<_> = self
366+
.tools
367+
.values()
368+
.flat_map(|tools| {
369+
tools.iter().map(|tool| match tool {
370+
Tool::ToolSpecification(tool_specification) => tool_specification.name.as_str(),
403371
})
404-
.fold(
405-
(tool_use_results, Vec::<&mut AssistantToolUse>::new()),
406-
|(mut tool_use_results, mut tool_uses), (results, uses)| {
407-
let mut results = results.iter().collect::<Vec<_>>();
408-
let mut uses = uses.iter_mut().collect::<Vec<_>>();
409-
tool_use_results.append(&mut results);
410-
tool_uses.append(&mut uses);
411-
(tool_use_results, tool_uses)
412-
},
413-
)
414-
};
372+
})
373+
.collect();
415374

416-
// Replace tool uses associated with tools that does not exist / no longer exists with
417-
// dummy (i.e. put them to sleep / dormant)
418-
for result in tool_use_results {
419-
let tool_use_id = result.tool_use_id.as_str();
420-
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
421-
if let Some(tool_use) = corresponding_tool_use {
422-
if tool_name_list.contains(&tool_use.name.as_str()) {
423-
// If this tool matches of the tools in our list, this is not our
424-
// concern, error or not.
425-
continue;
426-
}
427-
if let ToolResultStatus::Error = result.status {
428-
// case 2 and 3
429-
tool_use.name = DUMMY_TOOL_NAME.to_string();
430-
tool_use.args = serde_json::json!({});
431-
} else {
432-
// case 1
433-
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
434-
// We should be able to find a match but if not we'll just treat it as
435-
// a dummy and move on
436-
if let Some(full_name) = full_name {
437-
tool_use.name = (*full_name).to_string();
438-
} else {
439-
tool_use.name = DUMMY_TOOL_NAME.to_string();
440-
tool_use.args = serde_json::json!({});
375+
for (_, assistant) in &mut self.history {
376+
if let AssistantMessage::ToolUse { ref mut tool_uses, .. } = assistant {
377+
for tool_use in tool_uses {
378+
if tool_names.contains(tool_use.name.as_str()) {
379+
continue;
441380
}
442-
}
443-
}
444-
}
445381

446-
// Revive tools that were previously dormant if they now corresponds to one of the tools in
447-
// our list of available tools. Note that this check only works because tn_map does NOT
448-
// contain names of native tools.
449-
for tool_use in tool_uses {
450-
if tool_use.name == DUMMY_TOOL_NAME
451-
&& tool_use
452-
.orig_name
453-
.as_ref()
454-
.is_some_and(|name| tool_name_list.contains(&(*name).as_str()))
455-
{
456-
tool_use.name = tool_use
457-
.orig_name
458-
.as_ref()
459-
.map_or(DUMMY_TOOL_NAME.to_string(), |name| name.clone());
460-
tool_use.args = tool_use
461-
.orig_args
462-
.as_ref()
463-
.map_or(serde_json::json!({}), |args| args.clone());
382+
if tool_names.contains(tool_use.orig_name.as_str()) {
383+
tool_use.name = tool_use.orig_name.clone();
384+
tool_use.args = tool_use.orig_args.clone();
385+
continue;
386+
}
387+
388+
let names: Vec<&str> = tool_names
389+
.iter()
390+
.filter_map(|name| {
391+
if name.ends_with(&tool_use.name) {
392+
Some(*name)
393+
} else {
394+
None
395+
}
396+
})
397+
.collect();
398+
399+
// There's only one tool use matching, so we can just replace it with the
400+
// found name.
401+
if names.len() == 1 {
402+
tool_use.name = (*names.first().unwrap()).to_string();
403+
continue;
404+
}
405+
406+
// Otherwise, we have to replace it with a dummy.
407+
tool_use.name = DUMMY_TOOL_NAME.to_string();
408+
}
464409
}
465410
}
466411
}
@@ -540,12 +485,13 @@ impl ConversationState {
540485
// We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
541486
// here as well because when it's being called in [Self::enforce_conversation_invariants]
542487
// it is only checking the last entry.
543-
self.enforce_tool_use_history_invariants(false);
488+
self.enforce_tool_use_history_invariants();
544489
}
545490

546491
/// Returns a conversation state representation which reflects the exact conversation to send
547492
/// back to the model.
548493
pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> {
494+
self.update_state(false).await;
549495
self.enforce_conversation_invariants();
550496

551497
// Run hooks and add to conversation start and next user message.

crates/chat-cli/src/cli/chat/message.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,11 @@ pub struct AssistantToolUse {
349349
/// The name for the tool as exposed to the model
350350
pub name: String,
351351
/// Original name of the tool
352-
pub orig_name: Option<String>,
352+
pub orig_name: String,
353353
/// The input to pass to the tool as exposed to the model
354354
pub args: serde_json::Value,
355355
/// Original input passed to the tool
356-
pub orig_args: Option<serde_json::Value>,
356+
pub orig_args: serde_json::Value,
357357
}
358358

359359
impl From<AssistantToolUse> for ToolUse {

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ impl ChatContext {
559559
input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned()));
560560
cs.tool_manager = tool_manager;
561561
cs.update_state(true).await;
562+
cs.enforce_tool_use_history_invariants();
562563
cs
563564
} else {
564565
ConversationState::new(

crates/chat-cli/src/cli/chat/parser.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,9 @@ impl ResponseParser {
222222
self.tool_uses.push(AssistantToolUse {
223223
id: id.clone(),
224224
name: name.clone(),
225-
orig_name: Some(name.clone()),
225+
orig_name: name.clone(),
226226
args: args.clone(),
227-
orig_args: Some(args.clone()),
227+
orig_args: args.clone(),
228228
});
229229
let message = Box::new(AssistantMessage::new_tool_use(
230230
Some(self.message_id.clone()),

0 commit comments

Comments
 (0)