Skip to content

Commit aab54df

Browse files
committed
verifies conversation invariants on conversation deserialization
1 parent da850bf commit aab54df

File tree

4 files changed

+104
-23
lines changed

4 files changed

+104
-23
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::vec_deque::IterMut;
12
use std::collections::{
23
HashMap,
34
VecDeque,
@@ -364,6 +365,78 @@ impl ConversationState {
364365
}
365366
}
366367

368+
// Here we also need to make sure that the tool result corresponds to one of the tools
369+
// in the list. Otherwise we will see a validation error from the backend. There are three
370+
// such circumstances where intervention would be needed:
371+
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
372+
// which case we would automatically resolve this tool call to its correct name. This will NOT
373+
// result in an error in its tool result. The intervention here is to substitute the partial name
374+
// with its full name.
375+
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
376+
// could be referring to, in which case we WILL return an error in the tool result. The
377+
// intervention here is to substitute the ambiguous, partial name with a dummy.
378+
// 3. The model had decided to call a tool that does not exist. The intervention here is to
379+
// substitute the non-existent tool name with a dummy.
380+
fn enforce_tool_use_invariants(&mut self, history_of_interest: &mut Vec<(UserMessage, AssistantMessage)>) {
381+
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
382+
let mut tool_uses = history_of_interest
383+
.iter_mut()
384+
.filter_map(|(_user_msg, asst_msg)| {
385+
if let AssistantMessage::ToolUse { ref mut tool_uses, .. } = asst_msg {
386+
Some(tool_uses)
387+
} else {
388+
None
389+
}
390+
})
391+
.flatten();
392+
let tool_use_results = if let Some(user_msg) = &self.next_message {
393+
// We only check to verify the last message if [Self::next_message] is set
394+
user_msg.tool_use_results().map(|arr| arr.iter().collect::<Vec<_>>())
395+
} else {
396+
// Otherwise, we check the entire conversation
397+
Some(
398+
history_of_interest
399+
.iter()
400+
.filter_map(|(user_msg, _)| user_msg.tool_use_results())
401+
.flatten()
402+
.collect::<Vec<_>>(),
403+
)
404+
};
405+
if let Some(tool_use_results) = tool_use_results {
406+
// Note that we need to use the keys in tool manager's tn_map as the keys are the
407+
// actual tool names as exposed to the model and the backend. If we use the actual
408+
// names as they are recognized by their respective servers, we risk concluding
409+
// with false positives.
410+
for result in tool_use_results {
411+
let tool_use_id = result.tool_use_id.as_str();
412+
let corresponding_tool_use = tool_uses.find(|tool_use| tool_use_id == tool_use.id);
413+
if let Some(tool_use) = corresponding_tool_use {
414+
if tool_name_list.contains(&tool_use.name.as_str()) {
415+
// If this tool matches of the tools in our list, this is not our
416+
// concern, error or not.
417+
continue;
418+
}
419+
if let ToolResultStatus::Error = result.status {
420+
// case 2 and 3
421+
tool_use.name = DUMMY_TOOL_NAME.to_string();
422+
tool_use.args = serde_json::json!({});
423+
} else {
424+
// case 1
425+
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
426+
// We should be able to find a match but if not we'll just treat it as
427+
// a dummy and move on
428+
if let Some(full_name) = full_name {
429+
tool_use.name = (*full_name).to_string();
430+
} else {
431+
tool_use.name = DUMMY_TOOL_NAME.to_string();
432+
tool_use.args = serde_json::json!({});
433+
}
434+
}
435+
}
436+
}
437+
}
438+
}
439+
367440
pub fn add_tool_results(&mut self, tool_results: Vec<ToolUseResult>) {
368441
debug_assert!(self.next_message.is_none());
369442
self.next_message = Some(UserMessage::new_tool_use_results(tool_results));
@@ -388,7 +461,6 @@ impl ConversationState {
388461
/// - `run_hooks` - whether hooks should be executed and included as context
389462
pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState {
390463
debug_assert!(self.next_message.is_some());
391-
self.update_state().await;
392464
self.enforce_conversation_invariants();
393465
self.history.drain(self.valid_history_range.1..);
394466
self.history.drain(..self.valid_history_range.0);
@@ -420,6 +492,7 @@ impl ConversationState {
420492
return;
421493
}
422494
self.tool_manager.update().await;
495+
// TODO: make this more targeted so we don't have to clone the entire list of tools
423496
self.tools = self
424497
.tool_manager
425498
.schema

crates/chat-cli/src/cli/chat/message.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,18 @@ impl From<AssistantMessage> for AssistantResponseMessage {
342342
}
343343
}
344344

345-
#[derive(Debug, Clone, Serialize, Deserialize)]
345+
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
346346
pub struct AssistantToolUse {
347347
/// The ID for the tool request.
348348
pub id: String,
349-
/// The name for the tool.
349+
/// The name for the tool as exposed to the model
350350
pub name: String,
351-
/// The input to pass to the tool.
351+
/// Original name of the tool
352+
pub orig_name: Option<String>,
353+
/// The input to pass to the tool as exposed to the model
352354
pub args: serde_json::Value,
355+
/// Original input passed to the tool
356+
pub orig_args: Option<serde_json::Value>,
353357
}
354358

355359
impl From<AssistantToolUse> for ToolUse {
@@ -368,6 +372,7 @@ impl From<ToolUse> for AssistantToolUse {
368372
id: value.tool_use_id,
369373
name: value.name,
370374
args: document_to_serde_value(value.input.into()),
375+
..Default::default()
371376
}
372377
}
373378
}

crates/chat-cli/src/cli/chat/parser.rs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ impl ResponseParser {
204204
// including the tool contents. Essentially, the tool was too large.
205205
// Timeouts have been seen as short as ~1 minute, so setting the threshold to 30 seconds.
206206
let time_elapsed = start.elapsed();
207+
let args = serde_json::Value::Object(
208+
[(
209+
"key".to_string(),
210+
serde_json::Value::String(
211+
"WARNING: the actual tool use arguments were too complicated to be generated".to_string(),
212+
),
213+
)]
214+
.into_iter()
215+
.collect(),
216+
);
207217
if self.peek().await?.is_none() && time_elapsed > Duration::from_secs(30) {
208218
error!(
209219
"Received an unexpected end of stream after spending ~{}s receiving tool events",
@@ -212,17 +222,9 @@ impl ResponseParser {
212222
self.tool_uses.push(AssistantToolUse {
213223
id: id.clone(),
214224
name: name.clone(),
215-
args: serde_json::Value::Object(
216-
[(
217-
"key".to_string(),
218-
serde_json::Value::String(
219-
"WARNING: the actual tool use arguments were too complicated to be generated"
220-
.to_string(),
221-
),
222-
)]
223-
.into_iter()
224-
.collect(),
225-
),
225+
orig_name: Some(name.clone()),
226+
args: args.clone(),
227+
orig_args: Some(args.clone()),
226228
});
227229
let message = Box::new(AssistantMessage::new_tool_use(
228230
Some(self.message_id.clone()),
@@ -242,7 +244,12 @@ impl ResponseParser {
242244
// if the tool just does not need any input
243245
_ => serde_json::json!({}),
244246
};
245-
Ok(AssistantToolUse { id, name, args })
247+
Ok(AssistantToolUse {
248+
id,
249+
name,
250+
args,
251+
..Default::default()
252+
})
246253
}
247254

248255
/// Returns the next event in the [SendMessageOutput] without consuming it.

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -878,13 +878,9 @@ impl ToolManager {
878878
})
879879
};
880880
let mut updated_servers = HashSet::<ToolOrigin>::new();
881-
for (_server_name, (tool_name_map, specs)) in new_tools {
882-
// In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there
883-
// will be incoming tools with names that are already in the tn map, we will be writing
884-
// over them (perhaps with the same information that they already had), and that's okay.
885-
// In an event where a server has removed tools, the tools that are no longer available
886-
// will linger in this map. This is also okay to not clean up as it does not affect the
887-
// look up of tool names that are still active.
881+
for (server_name, (tool_name_map, specs)) in new_tools {
882+
let target = format!("{server_name}{NAMESPACE_DELIMITER}");
883+
self.tn_map.retain(|k, _| !k.starts_with(&target));
888884
for (k, v) in tool_name_map {
889885
self.tn_map.insert(k, v);
890886
}

0 commit comments

Comments
 (0)