Skip to content

Commit 05ea3f2

Browse files
committed
refines conversation invariant logic with regard to tool calls with wrong names
1 parent d1d2005 commit 05ea3f2

File tree

3 files changed

+44
-54
lines changed

3 files changed

+44
-54
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,17 @@ impl ConversationState {
302302
}
303303

304304
// Here we also need to make sure that the tool result corresponds to one of the tools
305-
// in the list. Otherwise we will see validation error from the backend. We would only
306-
// do this if the last message is a tool call that has failed.
305+
// in the list. Otherwise we will see a validation error from the backend. There are three
306+
// such circumstances where intervention would be needed:
307+
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
308+
// which case we would automatically resolve this tool call to its correct name. This will NOT
309+
// result in an error in its tool result. The intervention here is to substitute the partial name
310+
// with its full name.
311+
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
312+
// could be referring to, in which case we WILL return an error in the tool result. The
313+
// intervention here is to substitute the ambiguous, partial name with a dummy.
314+
// 3. The model had decided to call a tool that does not exist. The intervention here is to
315+
// substitute the non-existent tool name with a dummy.
307316
let tool_use_results = user_msg.tool_use_results();
308317
if let Some(tool_use_results) = tool_use_results {
309318
// Note that we need to use the keys in tool manager's tn_map as the keys are the
@@ -312,19 +321,30 @@ impl ConversationState {
312321
// with false positives.
313322
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
314323
for result in tool_use_results {
315-
if let ToolResultStatus::Error = result.status {
316-
let tool_use_id = result.tool_use_id.as_str();
317-
let _ = tool_uses
318-
.iter_mut()
319-
.filter(|tool_use| tool_use.id == tool_use_id)
320-
.map(|tool_use| {
321-
let tool_name = tool_use.name.as_str();
322-
if !tool_name_list.contains(&tool_name) {
323-
tool_use.name = DUMMY_TOOL_NAME.to_string();
324-
tool_use.args = serde_json::json!({});
325-
}
326-
})
327-
.collect::<Vec<_>>();
324+
let tool_use_id = result.tool_use_id.as_str();
325+
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
326+
if let Some(tool_use) = corresponding_tool_use {
327+
if tool_name_list.contains(&tool_use.name.as_str()) {
328+
// If this tool matches one of the tools in our list, this is not our
329+
// concern, error or not.
330+
continue;
331+
}
332+
if let ToolResultStatus::Error = result.status {
333+
// cases 2 and 3
334+
tool_use.name = DUMMY_TOOL_NAME.to_string();
335+
tool_use.args = serde_json::json!({});
336+
} else {
337+
// case 1
338+
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
339+
// We should be able to find a match but if not we'll just treat it as
340+
// a dummy and move on
341+
if let Some(full_name) = full_name {
342+
tool_use.name = (*full_name).to_string();
343+
} else {
344+
tool_use.name = DUMMY_TOOL_NAME.to_string();
345+
tool_use.args = serde_json::json!({});
346+
}
347+
}
328348
}
329349
}
330350
}

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ use std::process::{
3333
ExitCode,
3434
};
3535
use std::sync::Arc;
36-
use std::sync::atomic::Ordering;
3736
use std::time::Duration;
3837
use std::{
3938
env,
@@ -409,6 +408,7 @@ pub async fn chat(
409408
})
410409
.await?;
411410
let tool_config = tool_manager.load_tools().await?;
411+
error!("## tool config: {:#?}", tool_config);
412412
let mut tool_permissions = ToolPermissions::new(tool_config.len());
413413
if accept_all || trust_all_tools {
414414
tool_permissions.trust_all = true;
@@ -797,14 +797,7 @@ impl ChatContext {
797797
debug!(?chat_state, "changing to state");
798798

799799
// Update conversation state with new tool information
800-
if self
801-
.conversation_state
802-
.tool_manager
803-
.has_new_stuff
804-
.load(Ordering::Relaxed)
805-
{
806-
self.conversation_state.update_state().await;
807-
}
800+
self.conversation_state.update_state().await;
808801

809802
let result = match chat_state {
810803
ChatState::PromptUser {

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ impl ToolManagerBuilder {
280280
loading_servers.insert(name.clone(), status_line);
281281
let total = loading_servers.len();
282282
execute!(output, terminal::Clear(terminal::ClearType::CurrentLine))?;
283-
queue_init_message(spinner_logo_idx, complete, failed, total, is_interactive, &mut output)?;
283+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
284284
output.flush()?;
285285
},
286286
LoadingMsg::Done(name) => {
@@ -298,14 +298,7 @@ impl ToolManagerBuilder {
298298
)?;
299299
queue_success_message(&name, &time_taken, &mut output)?;
300300
let total = loading_servers.len();
301-
queue_init_message(
302-
spinner_logo_idx,
303-
complete,
304-
failed,
305-
total,
306-
is_interactive,
307-
&mut output,
308-
)?;
301+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
309302
output.flush()?;
310303
}
311304
if loading_servers.iter().all(|(_, status)| status.is_done) {
@@ -324,14 +317,7 @@ impl ToolManagerBuilder {
324317
)?;
325318
queue_failure_message(&name, &msg, &mut output)?;
326319
let total = loading_servers.len();
327-
queue_init_message(
328-
spinner_logo_idx,
329-
complete,
330-
failed,
331-
total,
332-
is_interactive,
333-
&mut output,
334-
)?;
320+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
335321
}
336322
if loading_servers.iter().all(|(_, status)| status.is_done) {
337323
break;
@@ -350,14 +336,7 @@ impl ToolManagerBuilder {
350336
let msg = eyre::eyre!(msg.to_string());
351337
queue_warn_message(&name, &msg, &mut output)?;
352338
let total = loading_servers.len();
353-
queue_init_message(
354-
spinner_logo_idx,
355-
complete,
356-
failed,
357-
total,
358-
is_interactive,
359-
&mut output,
360-
)?;
339+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
361340
output.flush()?;
362341
}
363342
if loading_servers.iter().all(|(_, status)| status.is_done) {
@@ -714,7 +693,7 @@ impl ToolManager {
714693
pub async fn load_tools(&mut self) -> eyre::Result<HashMap<String, ToolSpec>> {
715694
let tx = self.loading_status_sender.take();
716695
let display_task = self.loading_display_task.take();
717-
let mut tool_specs = {
696+
self.schema = {
718697
let mut tool_specs =
719698
serde_json::from_str::<HashMap<String, ToolSpec>>(include_str!("tools/tool_index.json"))?;
720699
if !crate::cli::chat::tools::thinking::Thinking::is_enabled() {
@@ -777,8 +756,7 @@ impl ToolManager {
777756
}
778757
}
779758
self.update().await;
780-
tool_specs.extend(self.schema.clone());
781-
Ok(tool_specs)
759+
Ok(self.schema.clone())
782760
}
783761

784762
pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result<Tool, ToolResult> {
@@ -1253,7 +1231,6 @@ fn queue_init_message(
12531231
complete: usize,
12541232
failed: usize,
12551233
total: usize,
1256-
is_interactive: bool,
12571234
output: &mut impl Write,
12581235
) -> eyre::Result<()> {
12591236
if total == complete {
@@ -1284,7 +1261,7 @@ fn queue_init_message(
12841261
style::ResetColor,
12851262
style::Print("mcp servers initialized."),
12861263
)?;
1287-
if is_interactive {
1264+
if total > complete + failed {
12881265
queue!(
12891266
output,
12901267
style::SetForegroundColor(style::Color::Blue),

0 commit comments

Comments
 (0)