Skip to content

Commit ec6fa9c

Browse files
authored
feat: Background server load (#1775)
* first commit * adds messenger trait * supplies server init with messenger * makes messenger to be used by dynamic dispatch instead * loads tools in the background * makes necessary changes for refactor * removes mcp client crate * makes initial server loading interruptable * formats * adds atomic bool to signal when new things are added * moves tool manager to conversation state * makes main chat loop update state if applicable * enables list changed for prompts and tools * adds copy change to server loading task * makes server init timeout configurable * uses tn map keys as list of tools for dummy substitute * adds tip for background loading and init timeout * updates tools info per try chat loop * shows servers still loading in /tools * makes timeout fut resolve immediately for tests * refines conversation invariant logic with regards to tool calls with wrong names * alias pkce to all uppercase * fixes test for oauth ser deser * puts timeout on telemetry finish * bumps telemetry finish timeout to 1 second and surface errors other than timeout * only surface error for telemetry finish if it's not a timeout
1 parent 06b1bfb commit ec6fa9c

File tree

18 files changed

+1372
-591
lines changed

18 files changed

+1372
-591
lines changed

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::collections::{
33
VecDeque,
44
};
55
use std::sync::Arc;
6+
use std::sync::atomic::Ordering;
67

78
use crossterm::style::Color;
89
use crossterm::{
@@ -41,6 +42,7 @@ use super::token_counter::{
4142
CharCount,
4243
CharCounter,
4344
};
45+
use super::tool_manager::ToolManager;
4446
use super::tools::{
4547
InputSchema,
4648
QueuedTool,
@@ -90,6 +92,9 @@ pub struct ConversationState {
9092
pub tools: HashMap<ToolOrigin, Vec<Tool>>,
9193
/// Context manager for handling sticky context files
9294
pub context_manager: Option<ContextManager>,
95+
/// Tool manager for handling tool and mcp related activities
96+
#[serde(skip)]
97+
pub tool_manager: ToolManager,
9398
/// Cached value representing the length of the user context message.
9499
context_message_length: Option<usize>,
95100
/// Stores the latest conversation summary created by /compact
@@ -105,6 +110,7 @@ impl ConversationState {
105110
tool_config: HashMap<String, ToolSpec>,
106111
profile: Option<String>,
107112
updates: Option<SharedWriter>,
113+
tool_manager: ToolManager,
108114
) -> Self {
109115
// Initialize context manager
110116
let context_manager = match ContextManager::new(ctx, None).await {
@@ -143,6 +149,7 @@ impl ConversationState {
143149
acc
144150
}),
145151
context_manager,
152+
tool_manager,
146153
context_message_length: None,
147154
latest_summary: None,
148155
updates,
@@ -310,29 +317,49 @@ impl ConversationState {
310317
}
311318

312319
// Here we also need to make sure that the tool result corresponds to one of the tools
313-
// in the list. Otherwise we will see validation error from the backend. We would only
314-
// do this if the last message is a tool call that has failed.
320+
// in the list. Otherwise we will see validation error from the backend. There are three
321+
// such circumstances where intervention would be needed:
322+
// 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
323+
// which case we would automatically resolve this tool call to its correct name. This will NOT
324+
// result in an error in its tool result. The intervention here is to substitute the partial name
325+
// with its full name.
326+
// 2. The model had decided to call a tool with its partial name AND there are multiple tools it
327+
// could be referring to, in which case we WILL return an error in the tool result. The
328+
// intervention here is to substitute the ambiguous, partial name with a dummy.
329+
// 3. The model had decided to call a tool that does not exist. The intervention here is to
330+
// substitute the non-existent tool name with a dummy.
315331
let tool_use_results = user_msg.tool_use_results();
316332
if let Some(tool_use_results) = tool_use_results {
317-
let tool_name_list = self
318-
.tools
319-
.values()
320-
.flatten()
321-
.map(|Tool::ToolSpecification(spec)| spec.name.as_str())
322-
.collect::<Vec<_>>();
333+
// Note that we need to use the keys in tool manager's tn_map as the keys are the
334+
// actual tool names as exposed to the model and the backend. If we use the actual
335+
// names as they are recognized by their respective servers, we risk concluding
336+
// with false positives.
337+
let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::<Vec<_>>();
323338
for result in tool_use_results {
324-
if let ToolResultStatus::Error = result.status {
325-
let tool_use_id = result.tool_use_id.as_str();
326-
let _ = tool_uses
327-
.iter_mut()
328-
.filter(|tool_use| tool_use.id == tool_use_id)
329-
.map(|tool_use| {
330-
let tool_name = tool_use.name.as_str();
331-
if !tool_name_list.contains(&tool_name) {
332-
tool_use.name = DUMMY_TOOL_NAME.to_string();
333-
}
334-
})
335-
.collect::<Vec<_>>();
339+
let tool_use_id = result.tool_use_id.as_str();
340+
let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id);
341+
if let Some(tool_use) = corresponding_tool_use {
342+
if tool_name_list.contains(&tool_use.name.as_str()) {
343+
// If this tool matches one of the tools in our list, this is not our
344+
// concern, error or not.
345+
continue;
346+
}
347+
if let ToolResultStatus::Error = result.status {
348+
// case 2 and 3
349+
tool_use.name = DUMMY_TOOL_NAME.to_string();
350+
tool_use.args = serde_json::json!({});
351+
} else {
352+
// case 1
353+
let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name));
354+
// We should be able to find a match but if not we'll just treat it as
355+
// a dummy and move on
356+
if let Some(full_name) = full_name {
357+
tool_use.name = (*full_name).to_string();
358+
} else {
359+
tool_use.name = DUMMY_TOOL_NAME.to_string();
360+
tool_use.args = serde_json::json!({});
361+
}
362+
}
336363
}
337364
}
338365
}
@@ -363,6 +390,7 @@ impl ConversationState {
363390
/// - `run_hooks` - whether hooks should be executed and included as context
364391
pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState {
365392
debug_assert!(self.next_message.is_some());
393+
self.update_state().await;
366394
self.enforce_conversation_invariants();
367395
self.history.drain(self.valid_history_range.1..);
368396
self.history.drain(..self.valid_history_range.0);
@@ -388,6 +416,30 @@ impl ConversationState {
388416
.expect("unable to construct conversation state")
389417
}
390418

419+
pub async fn update_state(&mut self) {
420+
let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire);
421+
if !needs_update {
422+
return;
423+
}
424+
self.tool_manager.update().await;
425+
self.tools = self
426+
.tool_manager
427+
.schema
428+
.values()
429+
.fold(HashMap::<ToolOrigin, Vec<Tool>>::new(), |mut acc, v| {
430+
let tool = Tool::ToolSpecification(ToolSpecification {
431+
name: v.name.clone(),
432+
description: v.description.clone(),
433+
input_schema: v.input_schema.clone().into(),
434+
});
435+
acc.entry(v.tool_origin.clone())
436+
.and_modify(|tools| tools.push(tool.clone()))
437+
.or_insert(vec![tool]);
438+
acc
439+
});
440+
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
441+
}
442+
391443
/// Returns a conversation state representation which reflects the exact conversation to send
392444
/// back to the model.
393445
pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> {
@@ -843,8 +895,6 @@ mod tests {
843895
};
844896
use crate::cli::chat::tool_manager::ToolManager;
845897
use crate::database::Database;
846-
use crate::platform::Env;
847-
use crate::telemetry::TelemetryThread;
848898

849899
fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) {
850900
if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) {
@@ -936,17 +986,16 @@ mod tests {
936986

937987
#[tokio::test]
938988
async fn test_conversation_state_history_handling_truncation() {
939-
let env = Env::new();
940989
let mut database = Database::new().await.unwrap();
941-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
942990

943991
let mut tool_manager = ToolManager::default();
944992
let mut conversation_state = ConversationState::new(
945993
Context::new(),
946994
"fake_conv_id",
947-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
995+
tool_manager.load_tools(&database).await.unwrap(),
948996
None,
949997
None,
998+
tool_manager,
950999
)
9511000
.await;
9521001

@@ -964,18 +1013,18 @@ mod tests {
9641013

9651014
#[tokio::test]
9661015
async fn test_conversation_state_history_handling_with_tool_results() {
967-
let env = Env::new();
9681016
let mut database = Database::new().await.unwrap();
969-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
9701017

9711018
// Build a long conversation history of tool use results.
9721019
let mut tool_manager = ToolManager::default();
1020+
let tool_config = tool_manager.load_tools(&database).await.unwrap();
9731021
let mut conversation_state = ConversationState::new(
9741022
Context::new(),
9751023
"fake_conv_id",
976-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1024+
tool_config.clone(),
9771025
None,
9781026
None,
1027+
tool_manager.clone(),
9791028
)
9801029
.await;
9811030
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1002,9 +1051,10 @@ mod tests {
10021051
let mut conversation_state = ConversationState::new(
10031052
Context::new(),
10041053
"fake_conv_id",
1005-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1054+
tool_config.clone(),
10061055
None,
10071056
None,
1057+
tool_manager.clone(),
10081058
)
10091059
.await;
10101060
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1035,9 +1085,7 @@ mod tests {
10351085

10361086
#[tokio::test]
10371087
async fn test_conversation_state_with_context_files() {
1038-
let env = Env::new();
10391088
let mut database = Database::new().await.unwrap();
1040-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
10411089

10421090
let ctx = Context::builder().with_test_home().await.unwrap().build_fake();
10431091
ctx.fs().write(AMAZONQ_FILENAME, "test context").await.unwrap();
@@ -1046,9 +1094,10 @@ mod tests {
10461094
let mut conversation_state = ConversationState::new(
10471095
ctx,
10481096
"fake_conv_id",
1049-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1097+
tool_manager.load_tools(&database).await.unwrap(),
10501098
None,
10511099
None,
1100+
tool_manager,
10521101
)
10531102
.await;
10541103

@@ -1085,9 +1134,7 @@ mod tests {
10851134
async fn test_conversation_state_additional_context() {
10861135
// tracing_subscriber::fmt::try_init().ok();
10871136

1088-
let env = Env::new();
10891137
let mut database = Database::new().await.unwrap();
1090-
let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap();
10911138

10921139
let mut tool_manager = ToolManager::default();
10931140
let ctx = Context::builder().with_test_home().await.unwrap().build_fake();
@@ -1116,9 +1163,10 @@ mod tests {
11161163
let mut conversation_state = ConversationState::new(
11171164
ctx,
11181165
"fake_conv_id",
1119-
tool_manager.load_tools(&database, &telemetry).await.unwrap(),
1166+
tool_manager.load_tools(&database).await.unwrap(),
11201167
None,
11211168
Some(SharedWriter::stdout()),
1169+
tool_manager,
11221170
)
11231171
.await;
11241172

0 commit comments

Comments
 (0)