Skip to content

Commit 200f16c

Browse files
authored
fix: Persistant conv invariant (#1822)
* verifies conversation invariants on conversation deserialization * adds logic to check tool use history invariants * modifies history invariant function to take parameter to explicitly control range of history to examine
1 parent eff496e commit 200f16c

File tree

5 files changed

+152
-67
lines changed

5 files changed

+152
-67
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 121 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use super::hooks::{
3232
};
3333
use super::message::{
3434
AssistantMessage,
35+
AssistantToolUse,
3536
ToolUseResult,
3637
ToolUseResultBlock,
3738
UserMessage,
@@ -344,53 +345,122 @@ impl ConversationState {
344345
tool_uses.iter().map(|t| t.id.as_str()),
345346
);
346347
}
348+
}
347349

348-
// Here we also need to make sure that the tool result corresponds to one of the tools
349-
// in the list. Otherwise we will see validation error from the backend. There are three
350-
// such circumstances where intervention would be needed:
351-
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
352-
// which case we would automatically resolve this tool call to its correct name. This will NOT
353-
// result in an error in its tool result. The intervention here is to substitute the partial name
354-
// with its full name.
355-
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
356-
// could be referring to, in which case we WILL return an error in the tool result. The
357-
// intervention here is to substitute the ambiguous, partial name with a dummy.
358-
// 3. The model had decided to call a tool that does not exist. The intervention here is to
359-
// substitute the non-existent tool name with a dummy.
360-
let tool_use_results = user_msg.tool_use_results();
361-
if let Some(tool_use_results) = tool_use_results {
362-
// Note that we need to use the keys in tool manager's tn_map as the keys are the
363-
// actual tool names as exposed to the model and the backend. If we use the actual
364-
// names as they are recognized by their respective servers, we risk concluding
365-
// with false positives.
366-
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
367-
for result in tool_use_results {
368-
let tool_use_id = result.tool_use_id.as_str();
369-
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
370-
if let Some(tool_use) = corresponding_tool_use {
371-
if tool_name_list.contains(&tool_use.name.as_str()) {
372-
// If this tool matches of the tools in our list, this is not our
373-
// concern, error or not.
374-
continue;
375-
}
376-
if let ToolResultStatus::Error = result.status {
377-
// case 2 and 3
378-
tool_use.name = DUMMY_TOOL_NAME.to_string();
379-
tool_use.args = serde_json::json!({});
380-
} else {
381-
// case 1
382-
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
383-
// We should be able to find a match but if not we'll just treat it as
384-
// a dummy and move on
385-
if let Some(full_name) = full_name {
386-
tool_use.name = (*full_name).to_string();
387-
} else {
388-
tool_use.name = DUMMY_TOOL_NAME.to_string();
389-
tool_use.args = serde_json::json!({});
390-
}
391-
}
350+
self.enforce_tool_use_history_invariants(true);
351+
}
352+
353+
/// Here we also need to make sure that the tool result corresponds to one of the tools
354+
/// in the list. Otherwise we will see validation error from the backend. There are three
355+
/// such circumstances where intervention would be needed:
356+
/// 1. The model had decided to call a tool with its partial name AND there is only one such
357+
/// tool, in which case we would automatically resolve this tool call to its correct name.
358+
/// This will NOT result in an error in its tool result. The intervention here is to
359+
/// substitute the partial name with its full name.
360+
/// 2. The model had decided to call a tool with its partial name AND there are multiple tools
361+
/// it could be referring to, in which case we WILL return an error in the tool result. The
362+
/// intervention here is to substitute the ambiguous, partial name with a dummy.
363+
/// 3. The model had decided to call a tool that does not exist. The intervention here is to
364+
/// substitute the non-existent tool name with a dummy.
365+
pub fn enforce_tool_use_history_invariants(&mut self, last_only: bool) {
366+
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
367+
// We need to first determine what the range of interest is. There are two places where we
368+
// would call this function:
369+
// 1. When there are changes to the list of available tools, in which case we comb through the
370+
// entire conversation
371+
// 2. When we send a message, in which case we only examine the most recent entry
372+
let (tool_use_results, mut tool_uses) = if last_only {
373+
if let (Some((_, AssistantMessage::ToolUse { ref mut tool_uses, .. })), Some(user_msg)) = (
374+
self.history
375+
.range_mut(self.valid_history_range.0..self.valid_history_range.1)
376+
.last(),
377+
&mut self.next_message,
378+
) {
379+
let tool_use_results = user_msg
380+
.tool_use_results()
381+
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>());
382+
let tool_uses = tool_uses.iter_mut().collect::<Vec<_>>();
383+
(tool_use_results, tool_uses)
384+
} else {
385+
(Vec::new(), Vec::new())
386+
}
387+
} else {
388+
let tool_use_results = self.next_message.as_ref().map_or(Vec::new(), |user_msg| {
389+
user_msg
390+
.tool_use_results()
391+
.map_or(Vec::new(), |results| results.iter().collect::<Vec<_>>())
392+
});
393+
self.history
394+
.iter_mut()
395+
.filter_map(|(user_msg, asst_msg)| {
396+
if let (Some(tool_use_results), AssistantMessage::ToolUse { ref mut tool_uses, .. }) =
397+
(user_msg.tool_use_results(), asst_msg)
398+
{
399+
Some((tool_use_results, tool_uses))
400+
} else {
401+
None
392402
}
403+
})
404+
.fold(
405+
(tool_use_results, Vec::<&mut AssistantToolUse>::new()),
406+
|(mut tool_use_results, mut tool_uses), (results, uses)| {
407+
let mut results = results.iter().collect::<Vec<_>>();
408+
let mut uses = uses.iter_mut().collect::<Vec<_>>();
409+
tool_use_results.append(&mut results);
410+
tool_uses.append(&mut uses);
411+
(tool_use_results, tool_uses)
412+
},
413+
)
414+
};
415+
416+
// Replace tool uses associated with tools that does not exist / no longer exists with
417+
// dummy (i.e. put them to sleep / dormant)
418+
for result in tool_use_results {
419+
let tool_use_id = result.tool_use_id.as_str();
420+
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
421+
if let Some(tool_use) = corresponding_tool_use {
422+
if tool_name_list.contains(&tool_use.name.as_str()) {
423+
// If this tool matches of the tools in our list, this is not our
424+
// concern, error or not.
425+
continue;
393426
}
427+
if let ToolResultStatus::Error = result.status {
428+
// case 2 and 3
429+
tool_use.name = DUMMY_TOOL_NAME.to_string();
430+
tool_use.args = serde_json::json!({});
431+
} else {
432+
// case 1
433+
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
434+
// We should be able to find a match but if not we'll just treat it as
435+
// a dummy and move on
436+
if let Some(full_name) = full_name {
437+
tool_use.name = (*full_name).to_string();
438+
} else {
439+
tool_use.name = DUMMY_TOOL_NAME.to_string();
440+
tool_use.args = serde_json::json!({});
441+
}
442+
}
443+
}
444+
}
445+
446+
// Revive tools that were previously dormant if they now corresponds to one of the tools in
447+
// our list of available tools. Note that this check only works because tn_map does NOT
448+
// contain names of native tools.
449+
for tool_use in tool_uses {
450+
if tool_use.name == DUMMY_TOOL_NAME
451+
&& tool_use
452+
.orig_name
453+
.as_ref()
454+
.is_some_and(|name| tool_name_list.contains(&(*name).as_str()))
455+
{
456+
tool_use.name = tool_use
457+
.orig_name
458+
.as_ref()
459+
.map_or(DUMMY_TOOL_NAME.to_string(), |name| name.clone());
460+
tool_use.args = tool_use
461+
.orig_args
462+
.as_ref()
463+
.map_or(serde_json::json!({}), |args| args.clone());
394464
}
395465
}
396466
}
@@ -419,7 +489,6 @@ impl ConversationState {
419489
/// - `run_hooks` - whether hooks should be executed and included as context
420490
pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState {
421491
debug_assert!(self.next_message.is_some());
422-
self.update_state().await;
423492
self.enforce_conversation_invariants();
424493
self.history.drain(self.valid_history_range.1..);
425494
self.history.drain(..self.valid_history_range.0);
@@ -451,6 +520,7 @@ impl ConversationState {
451520
return;
452521
}
453522
self.tool_manager.update().await;
523+
// TODO: make this more targeted so we don't have to clone the entire list of tools
454524
self.tools = self
455525
.tool_manager
456526
.schema
@@ -467,6 +537,10 @@ impl ConversationState {
467537
acc
468538
});
469539
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
540+
// We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
541+
// here as well because when it's being called in [Self::enforce_conversation_invariants]
542+
// it is only checking the last entry.
543+
self.enforce_tool_use_history_invariants(false);
470544
}
471545

472546
/// Returns a conversation state representation which reflects the exact conversation to send
@@ -1066,6 +1140,7 @@ mod tests {
10661140
id: "tool_id".to_string(),
10671141
name: "tool name".to_string(),
10681142
args: serde_json::Value::Null,
1143+
..Default::default()
10691144
}]),
10701145
&mut database,
10711146
);
@@ -1096,6 +1171,7 @@ mod tests {
10961171
id: "tool_id".to_string(),
10971172
name: "tool name".to_string(),
10981173
args: serde_json::Value::Null,
1174+
..Default::default()
10991175
}]),
11001176
&mut database,
11011177
);

crates/chat-cli/src/cli/chat/message.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,18 @@ impl From<AssistantMessage> for AssistantResponseMessage {
342342
}
343343
}
344344

345-
#[derive(Debug, Clone, Serialize, Deserialize)]
345+
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
346346
pub struct AssistantToolUse {
347347
/// The ID for the tool request.
348348
pub id: String,
349-
/// The name for the tool.
349+
/// The name for the tool as exposed to the model
350350
pub name: String,
351-
/// The input to pass to the tool.
351+
/// Original name of the tool
352+
pub orig_name: Option<String>,
353+
/// The input to pass to the tool as exposed to the model
352354
pub args: serde_json::Value,
355+
/// Original input passed to the tool
356+
pub orig_args: Option<serde_json::Value>,
353357
}
354358

355359
impl From<AssistantToolUse> for ToolUse {
@@ -368,6 +372,7 @@ impl From<ToolUse> for AssistantToolUse {
368372
id: value.tool_use_id,
369373
name: value.name,
370374
args: document_to_serde_value(value.input.into()),
375+
..Default::default()
371376
}
372377
}
373378
}

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ impl ChatContext {
572572
cs.reload_serialized_state(Arc::clone(&ctx), Some(output.clone())).await;
573573
input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned()));
574574
cs.tool_manager = tool_manager;
575+
cs.enforce_tool_use_history_invariants(false);
575576
cs
576577
} else {
577578
ConversationState::new(

crates/chat-cli/src/cli/chat/parser.rs

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ impl ResponseParser {
204204
// including the tool contents. Essentially, the tool was too large.
205205
// Timeouts have been seen as short as ~1 minute, so setting the time to 30.
206206
let time_elapsed = start.elapsed();
207+
let args = serde_json::Value::Object(
208+
[(
209+
"key".to_string(),
210+
serde_json::Value::String(
211+
"WARNING: the actual tool use arguments were too complicated to be generated".to_string(),
212+
),
213+
)]
214+
.into_iter()
215+
.collect(),
216+
);
207217
if self.peek().await?.is_none() && time_elapsed > Duration::from_secs(30) {
208218
error!(
209219
"Received an unexpected end of stream after spending ~{}s receiving tool events",
@@ -212,17 +222,9 @@ impl ResponseParser {
212222
self.tool_uses.push(AssistantToolUse {
213223
id: id.clone(),
214224
name: name.clone(),
215-
args: serde_json::Value::Object(
216-
[(
217-
"key".to_string(),
218-
serde_json::Value::String(
219-
"WARNING: the actual tool use arguments were too complicated to be generated"
220-
.to_string(),
221-
),
222-
)]
223-
.into_iter()
224-
.collect(),
225-
),
225+
orig_name: Some(name.clone()),
226+
args: args.clone(),
227+
orig_args: Some(args.clone()),
226228
});
227229
let message = Box::new(AssistantMessage::new_tool_use(
228230
Some(self.message_id.clone()),
@@ -242,7 +244,12 @@ impl ResponseParser {
242244
// if the tool just does not need any input
243245
_ => serde_json::json!({}),
244246
};
245-
Ok(AssistantToolUse { id, name, args })
247+
Ok(AssistantToolUse {
248+
id,
249+
name,
250+
args,
251+
..Default::default()
252+
})
246253
}
247254

248255
/// Returns the next event in the [SendMessageOutput] without consuming it.

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -909,13 +909,9 @@ impl ToolManager {
909909
})
910910
};
911911
let mut updated_servers = HashSet::<ToolOrigin>::new();
912-
for (_server_name, (tool_name_map, specs)) in new_tools {
913-
// In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there
914-
// will be incoming tools with names that are already in the tn map, we will be writing
915-
// over them (perhaps with the same information that they already had), and that's okay.
916-
// In an event where a server has removed tools, the tools that are no longer available
917-
// will linger in this map. This is also okay to not clean up as it does not affect the
918-
// look up of tool names that are still active.
912+
for (server_name, (tool_name_map, specs)) in new_tools {
913+
let target = format!("{server_name}{NAMESPACE_DELIMITER}");
914+
self.tn_map.retain(|k, _| !k.starts_with(&target));
919915
for (k, v) in tool_name_map {
920916
self.tn_map.insert(k, v);
921917
}

0 commit comments

Comments
 (0)