@@ -3,6 +3,7 @@ use std::collections::{
33 VecDeque ,
44} ;
55use std:: sync:: Arc ;
6+ use std:: sync:: atomic:: Ordering ;
67
78use crossterm:: style:: Color ;
89use crossterm:: {
@@ -41,6 +42,7 @@ use super::token_counter::{
4142 CharCount ,
4243 CharCounter ,
4344} ;
45+ use super :: tool_manager:: ToolManager ;
4446use super :: tools:: {
4547 InputSchema ,
4648 QueuedTool ,
@@ -90,6 +92,9 @@ pub struct ConversationState {
9092 pub tools : HashMap < ToolOrigin , Vec < Tool > > ,
9193 /// Context manager for handling sticky context files
9294 pub context_manager : Option < ContextManager > ,
95+ /// Tool manager for handling tool and mcp related activities
96+ #[ serde( skip) ]
97+ pub tool_manager : ToolManager ,
9398 /// Cached value representing the length of the user context message.
9499 context_message_length : Option < usize > ,
95100 /// Stores the latest conversation summary created by /compact
@@ -105,6 +110,7 @@ impl ConversationState {
105110 tool_config : HashMap < String , ToolSpec > ,
106111 profile : Option < String > ,
107112 updates : Option < SharedWriter > ,
113+ tool_manager : ToolManager ,
108114 ) -> Self {
109115 // Initialize context manager
110116 let context_manager = match ContextManager :: new ( ctx, None ) . await {
@@ -143,6 +149,7 @@ impl ConversationState {
143149 acc
144150 } ) ,
145151 context_manager,
152+ tool_manager,
146153 context_message_length : None ,
147154 latest_summary : None ,
148155 updates,
@@ -308,29 +315,49 @@ impl ConversationState {
308315 }
309316
310317 // Here we also need to make sure that the tool result corresponds to one of the tools
311- // in the list. Otherwise we will see validation error from the backend. We would only
312- // do this if the last message is a tool call that has failed.
318+ // in the list. Otherwise we will see validation error from the backend. There are three
319+ // such circumstances where intervention would be needed:
320+ // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
321+ // which case we would automatically resolve this tool call to its correct name. This will NOT
322+ // result in an error in its tool result. The intervention here is to substitute the partial name
323+ // with its full name.
324+ // 2. The model had decided to call a tool with its partial name AND there are multiple tools it
325+ // could be referring to, in which case we WILL return an error in the tool result. The
326+ // intervention here is to substitute the ambiguous, partial name with a dummy.
327+ // 3. The model had decided to call a tool that does not exist. The intervention here is to
328+ // substitute the non-existent tool name with a dummy.
313329 let tool_use_results = user_msg. tool_use_results ( ) ;
314330 if let Some ( tool_use_results) = tool_use_results {
315- let tool_name_list = self
316- . tools
317- . values ( )
318- . flatten ( )
319- . map ( |Tool :: ToolSpecification ( spec) | spec. name . as_str ( ) )
320- . collect :: < Vec < _ > > ( ) ;
331+ // Note that we need to use the keys in tool manager's tn_map as the keys are the
332+ // actual tool names as exposed to the model and the backend. If we use the actual
333+ // names as they are recognized by their respective servers, we risk concluding
334+ // with false positives.
335+ let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
321336 for result in tool_use_results {
322- if let ToolResultStatus :: Error = result. status {
323- let tool_use_id = result. tool_use_id . as_str ( ) ;
324- let _ = tool_uses
325- . iter_mut ( )
326- . filter ( |tool_use| tool_use. id == tool_use_id)
327- . map ( |tool_use| {
328- let tool_name = tool_use. name . as_str ( ) ;
329- if !tool_name_list. contains ( & tool_name) {
330- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
331- }
332- } )
333- . collect :: < Vec < _ > > ( ) ;
337+ let tool_use_id = result. tool_use_id . as_str ( ) ;
338+ let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
339+ if let Some ( tool_use) = corresponding_tool_use {
340+ if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
341+ // If this tool matches of the tools in our list, this is not our
342+ // concern, error or not.
343+ continue ;
344+ }
345+ if let ToolResultStatus :: Error = result. status {
346+ // case 2 and 3
347+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
348+ tool_use. args = serde_json:: json!( { } ) ;
349+ } else {
350+ // case 1
351+ let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
352+ // We should be able to find a match but if not we'll just treat it as
353+ // a dummy and move on
354+ if let Some ( full_name) = full_name {
355+ tool_use. name = ( * full_name) . to_string ( ) ;
356+ } else {
357+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
358+ tool_use. args = serde_json:: json!( { } ) ;
359+ }
360+ }
334361 }
335362 }
336363 }
@@ -361,6 +388,7 @@ impl ConversationState {
361388 /// - `run_hooks` - whether hooks should be executed and included as context
362389 pub async fn as_sendable_conversation_state ( & mut self , run_hooks : bool ) -> FigConversationState {
363390 debug_assert ! ( self . next_message. is_some( ) ) ;
391+ self . update_state ( ) . await ;
364392 self . enforce_conversation_invariants ( ) ;
365393 self . history . drain ( self . valid_history_range . 1 ..) ;
366394 self . history . drain ( ..self . valid_history_range . 0 ) ;
@@ -386,6 +414,30 @@ impl ConversationState {
386414 . expect ( "unable to construct conversation state" )
387415 }
388416
417+ pub async fn update_state ( & mut self ) {
418+ let needs_update = self . tool_manager . has_new_stuff . load ( Ordering :: Acquire ) ;
419+ if !needs_update {
420+ return ;
421+ }
422+ self . tool_manager . update ( ) . await ;
423+ self . tools = self
424+ . tool_manager
425+ . schema
426+ . values ( )
427+ . fold ( HashMap :: < ToolOrigin , Vec < Tool > > :: new ( ) , |mut acc, v| {
428+ let tool = Tool :: ToolSpecification ( ToolSpecification {
429+ name : v. name . clone ( ) ,
430+ description : v. description . clone ( ) ,
431+ input_schema : v. input_schema . clone ( ) . into ( ) ,
432+ } ) ;
433+ acc. entry ( v. tool_origin . clone ( ) )
434+ . and_modify ( |tools| tools. push ( tool. clone ( ) ) )
435+ . or_insert ( vec ! [ tool] ) ;
436+ acc
437+ } ) ;
438+ self . tool_manager . has_new_stuff . store ( false , Ordering :: Release ) ;
439+ }
440+
389441 /// Returns a conversation state representation which reflects the exact conversation to send
390442 /// back to the model.
391443 pub async fn backend_conversation_state ( & mut self , run_hooks : bool , quiet : bool ) -> BackendConversationState < ' _ > {
@@ -841,8 +893,6 @@ mod tests {
841893 } ;
842894 use crate :: cli:: chat:: tool_manager:: ToolManager ;
843895 use crate :: database:: Database ;
844- use crate :: platform:: Env ;
845- use crate :: telemetry:: TelemetryThread ;
846896
847897 fn assert_conversation_state_invariants ( state : FigConversationState , assertion_iteration : usize ) {
848898 if let Some ( Some ( msg) ) = state. history . as_ref ( ) . map ( |h| h. first ( ) ) {
@@ -934,17 +984,16 @@ mod tests {
934984
935985 #[ tokio:: test]
936986 async fn test_conversation_state_history_handling_truncation ( ) {
937- let env = Env :: new ( ) ;
938987 let mut database = Database :: new ( ) . await . unwrap ( ) ;
939- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
940988
941989 let mut tool_manager = ToolManager :: default ( ) ;
942990 let mut conversation_state = ConversationState :: new (
943991 Context :: new ( ) ,
944992 "fake_conv_id" ,
945- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
993+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
946994 None ,
947995 None ,
996+ tool_manager,
948997 )
949998 . await ;
950999
@@ -962,18 +1011,18 @@ mod tests {
9621011
9631012 #[ tokio:: test]
9641013 async fn test_conversation_state_history_handling_with_tool_results ( ) {
965- let env = Env :: new ( ) ;
9661014 let mut database = Database :: new ( ) . await . unwrap ( ) ;
967- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
9681015
9691016 // Build a long conversation history of tool use results.
9701017 let mut tool_manager = ToolManager :: default ( ) ;
1018+ let tool_config = tool_manager. load_tools ( & database) . await . unwrap ( ) ;
9711019 let mut conversation_state = ConversationState :: new (
9721020 Context :: new ( ) ,
9731021 "fake_conv_id" ,
974- tool_manager . load_tools ( & database , & telemetry ) . await . unwrap ( ) ,
1022+ tool_config . clone ( ) ,
9751023 None ,
9761024 None ,
1025+ tool_manager. clone ( ) ,
9771026 )
9781027 . await ;
9791028 conversation_state. set_next_user_message ( "start" . to_string ( ) ) . await ;
@@ -1000,9 +1049,10 @@ mod tests {
10001049 let mut conversation_state = ConversationState :: new (
10011050 Context :: new ( ) ,
10021051 "fake_conv_id" ,
1003- tool_manager . load_tools ( & database , & telemetry ) . await . unwrap ( ) ,
1052+ tool_config . clone ( ) ,
10041053 None ,
10051054 None ,
1055+ tool_manager. clone ( ) ,
10061056 )
10071057 . await ;
10081058 conversation_state. set_next_user_message ( "start" . to_string ( ) ) . await ;
@@ -1033,9 +1083,7 @@ mod tests {
10331083
10341084 #[ tokio:: test]
10351085 async fn test_conversation_state_with_context_files ( ) {
1036- let env = Env :: new ( ) ;
10371086 let mut database = Database :: new ( ) . await . unwrap ( ) ;
1038- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
10391087
10401088 let ctx = Context :: builder ( ) . with_test_home ( ) . await . unwrap ( ) . build_fake ( ) ;
10411089 ctx. fs ( ) . write ( AMAZONQ_FILENAME , "test context" ) . await . unwrap ( ) ;
@@ -1044,9 +1092,10 @@ mod tests {
10441092 let mut conversation_state = ConversationState :: new (
10451093 ctx,
10461094 "fake_conv_id" ,
1047- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
1095+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
10481096 None ,
10491097 None ,
1098+ tool_manager,
10501099 )
10511100 . await ;
10521101
@@ -1083,9 +1132,7 @@ mod tests {
10831132 async fn test_conversation_state_additional_context ( ) {
10841133 // tracing_subscriber::fmt::try_init().ok();
10851134
1086- let env = Env :: new ( ) ;
10871135 let mut database = Database :: new ( ) . await . unwrap ( ) ;
1088- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
10891136
10901137 let mut tool_manager = ToolManager :: default ( ) ;
10911138 let ctx = Context :: builder ( ) . with_test_home ( ) . await . unwrap ( ) . build_fake ( ) ;
@@ -1114,9 +1161,10 @@ mod tests {
11141161 let mut conversation_state = ConversationState :: new (
11151162 ctx,
11161163 "fake_conv_id" ,
1117- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
1164+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
11181165 None ,
11191166 Some ( SharedWriter :: stdout ( ) ) ,
1167+ tool_manager,
11201168 )
11211169 . await ;
11221170
0 commit comments