Skip to content

Commit 51a7f96

Browse files
committed
merge main
2 parents 5bf85d1 + da850bf commit 51a7f96

File tree

22 files changed

+1431
-648
lines changed

22 files changed

+1431
-648
lines changed

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/chat-cli/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ bstr = "1.12.0"
4545
bytes = "1.10.1"
4646
camino = { version = "1.1.3", features = ["serde1"] }
4747
cfg-if = "1.0.0"
48-
chrono = { version = "0.4.41", default-features = false, features = ["std"] }
4948
clap = { version = "4.5.32", features = [
5049
"deprecated",
5150
"derive",

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 84 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::collections::{
33
VecDeque,
44
};
55
use std::sync::Arc;
6+
use std::sync::atomic::Ordering;
67

78
use crossterm::style::Color;
89
use crossterm::{
@@ -41,6 +42,7 @@ use super::token_counter::{
4142
CharCount,
4243
CharCounter,
4344
};
45+
use super::tool_manager::ToolManager;
4446
use super::tools::{
4547
InputSchema,
4648
QueuedTool,
@@ -90,6 +92,9 @@ pub struct ConversationState {
9092
pub tools: HashMap<ToolOrigin, Vec<Tool>>,
9193
/// Context manager for handling sticky context files
9294
pub context_manager: Option<ContextManager>,
95+
/// Tool manager for handling tool and mcp related activities
96+
#[serde(skip)]
97+
pub tool_manager: ToolManager,
9398
/// Cached value representing the length of the user context message.
9499
context_message_length: Option<usize>,
95100
/// Stores the latest conversation summary created by /compact
@@ -105,6 +110,7 @@ impl ConversationState {
105110
tool_config: HashMap<String, ToolSpec>,
106111
profile: Option<String>,
107112
updates: Option<SharedWriter>,
113+
tool_manager: ToolManager,
108114
) -> Self {
109115
// Initialize context manager
110116
let context_manager = match ContextManager::new(ctx, None).await {
@@ -143,6 +149,7 @@ impl ConversationState {
143149
acc
144150
}),
145151
context_manager,
152+
tool_manager,
146153
context_message_length: None,
147154
latest_summary: None,
148155
updates,
@@ -213,9 +220,7 @@ impl ConversationState {
213220
warn!("input must not be empty when adding new messages");
214221
"Empty prompt".to_string()
215222
} else {
216-
let now = chrono::Utc::now();
217-
let formatted_time = now.format("%Y-%m-%d %H:%M:%S").to_string();
218-
format!("{}\n\n<currentTimeUTC>\n{}\n</currentTimeUTC>", input, formatted_time)
223+
input
219224
};
220225

221226
let msg = UserMessage::new_prompt(input);
@@ -310,29 +315,49 @@ impl ConversationState {
310315
}
311316

312317
// Here we also need to make sure that the tool result corresponds to one of the tools
313-
// in the list. Otherwise we will see validation error from the backend. We would only
314-
// do this if the last message is a tool call that has failed.
318+
// in the list. Otherwise we will see validation error from the backend. There are three
319+
// such circumstances where intervention would be needed:
320+
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
321+
// which case we would automatically resolve this tool call to its correct name. This will NOT
322+
// result in an error in its tool result. The intervention here is to substitute the partial name
323+
// with its full name.
324+
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
325+
// could be referring to, in which case we WILL return an error in the tool result. The
326+
// intervention here is to substitute the ambiguous, partial name with a dummy.
327+
// 3. The model had decided to call a tool that does not exist. The intervention here is to
328+
// substitute the non-existent tool name with a dummy.
315329
let tool_use_results = user_msg.tool_use_results();
316330
if let Some(tool_use_results) = tool_use_results {
317-
let tool_name_list = self
318-
.tools
319-
.values()
320-
.flatten()
321-
.map(|Tool::ToolSpecification(spec)| spec.name.as_str())
322-
.collect::<Vec<_>>();
331+
// Note that we need to use the keys in tool manager's tn_map as the keys are the
332+
// actual tool names as exposed to the model and the backend. If we use the actual
333+
// names as they are recognized by their respective servers, we risk concluding
334+
// with false positives.
335+
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
323336
for result in tool_use_results {
324-
if let ToolResultStatus::Error = result.status {
325-
let tool_use_id = result.tool_use_id.as_str();
326-
let _ = tool_uses
327-
.iter_mut()
328-
.filter(|tool_use| tool_use.id == tool_use_id)
329-
.map(|tool_use| {
330-
let tool_name = tool_use.name.as_str();
331-
if !tool_name_list.contains(&tool_name) {
332-
tool_use.name = DUMMY_TOOL_NAME.to_string();
333-
}
334-
})
335-
.collect::<Vec<_>>();
337+
let tool_use_id = result.tool_use_id.as_str();
338+
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
339+
if let Some(tool_use) = corresponding_tool_use {
340+
if tool_name_list.contains(&tool_use.name.as_str()) {
341+
// If this tool matches of the tools in our list, this is not our
342+
// concern, error or not.
343+
continue;
344+
}
345+
if let ToolResultStatus::Error = result.status {
346+
// case 2 and 3
347+
tool_use.name = DUMMY_TOOL_NAME.to_string();
348+
tool_use.args = serde_json::json!({});
349+
} else {
350+
// case 1
351+
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
352+
// We should be able to find a match but if not we'll just treat it as
353+
// a dummy and move on
354+
if let Some(full_name) = full_name {
355+
tool_use.name = (*full_name).to_string();
356+
} else {
357+
tool_use.name = DUMMY_TOOL_NAME.to_string();
358+
tool_use.args = serde_json::json!({});
359+
}
360+
}
336361
}
337362
}
338363
}
@@ -363,6 +388,7 @@ impl ConversationState {
363388
/// - `run_hooks` - whether hooks should be executed and included as context
364389
pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState {
365390
debug_assert!(self.next_message.is_some());
391+
self.update_state().await;
366392
self.enforce_conversation_invariants();
367393
self.history.drain(self.valid_history_range.1..);
368394
self.history.drain(..self.valid_history_range.0);
@@ -388,6 +414,30 @@ impl ConversationState {
388414
.expect("unable to construct conversation state")
389415
}
390416

417+
pub async fn update_state(&mut self) {
418+
let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire);
419+
if !needs_update {
420+
return;
421+
}
422+
self.tool_manager.update().await;
423+
self.tools = self
424+
.tool_manager
425+
.schema
426+
.values()
427+
.fold(HashMap::<ToolOrigin, Vec<Tool>>::new(), |mut acc, v| {
428+
let tool = Tool::ToolSpecification(ToolSpecification {
429+
name: v.name.clone(),
430+
description: v.description.clone(),
431+
input_schema: v.input_schema.clone().into(),
432+
});
433+
acc.entry(v.tool_origin.clone())
434+
.and_modify(|tools| tools.push(tool.clone()))
435+
.or_insert(vec![tool]);
436+
acc
437+
});
438+
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
439+
}
440+
391441
/// Returns a conversation state representation which reflects the exact conversation to send
392442
/// back to the model.
393443
pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> {
@@ -843,8 +893,6 @@ mod tests {
843893
};
844894
use crate::cli::chat::tool_manager::ToolManager;
845895
use crate::database::Database;
846-
use crate::platform::Env;
847-
use crate::telemetry::TelemetryThread;
848896

849897
fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) {
850898
if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) {
@@ -936,17 +984,16 @@ mod tests {
936984

937985
#[tokio::test]
938986
async fn test_conversation_state_history_handling_truncation() {
939-
let env = Env::new();
940987
let mut database = Database::new().await.unwrap();
941-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
942988

943989
let mut tool_manager = ToolManager::default();
944990
let mut conversation_state = ConversationState::new(
945991
Context::new(),
946992
"fake_conv_id",
947-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
993+
tool_manager.load_tools(&database).await.unwrap(),
948994
None,
949995
None,
996+
tool_manager,
950997
)
951998
.await;
952999

@@ -964,18 +1011,18 @@ mod tests {
9641011

9651012
#[tokio::test]
9661013
async fn test_conversation_state_history_handling_with_tool_results() {
967-
let env = Env::new();
9681014
let mut database = Database::new().await.unwrap();
969-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
9701015

9711016
// Build a long conversation history of tool use results.
9721017
let mut tool_manager = ToolManager::default();
1018+
let tool_config = tool_manager.load_tools(&database).await.unwrap();
9731019
let mut conversation_state = ConversationState::new(
9741020
Context::new(),
9751021
"fake_conv_id",
976-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1022+
tool_config.clone(),
9771023
None,
9781024
None,
1025+
tool_manager.clone(),
9791026
)
9801027
.await;
9811028
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1002,9 +1049,10 @@ mod tests {
10021049
let mut conversation_state = ConversationState::new(
10031050
Context::new(),
10041051
"fake_conv_id",
1005-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1052+
tool_config.clone(),
10061053
None,
10071054
None,
1055+
tool_manager.clone(),
10081056
)
10091057
.await;
10101058
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1035,9 +1083,7 @@ mod tests {
10351083

10361084
#[tokio::test]
10371085
async fn test_conversation_state_with_context_files() {
1038-
let env = Env::new();
10391086
let mut database = Database::new().await.unwrap();
1040-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
10411087

10421088
let ctx = Context::builder().with_test_home().await.unwrap().build_fake();
10431089
ctx.fs().write(AMAZONQ_FILENAME, "test context").await.unwrap();
@@ -1046,9 +1092,10 @@ mod tests {
10461092
let mut conversation_state = ConversationState::new(
10471093
ctx,
10481094
"fake_conv_id",
1049-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1095+
tool_manager.load_tools(&database).await.unwrap(),
10501096
None,
10511097
None,
1098+
tool_manager,
10521099
)
10531100
.await;
10541101

@@ -1085,9 +1132,7 @@ mod tests {
10851132
async fn test_conversation_state_additional_context() {
10861133
// tracing_subscriber::fmt::try_init().ok();
10871134

1088-
let env = Env::new();
10891135
let mut database = Database::new().await.unwrap();
1090-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
10911136

10921137
let mut tool_manager = ToolManager::default();
10931138
let ctx = Context::builder().with_test_home().await.unwrap().build_fake();
@@ -1116,9 +1161,10 @@ mod tests {
11161161
let mut conversation_state = ConversationState::new(
11171162
ctx,
11181163
"fake_conv_id",
1119-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1164+
tool_manager.load_tools(&database).await.unwrap(),
11201165
None,
11211166
Some(SharedWriter::stdout()),
1167+
tool_manager,
11221168
)
11231169
.await;
11241170

0 commit comments

Comments
 (0)