@@ -3,6 +3,7 @@ use std::collections::{
33 VecDeque ,
44} ;
55use std:: sync:: Arc ;
6+ use std:: sync:: atomic:: Ordering ;
67
78use crossterm:: style:: Color ;
89use crossterm:: {
@@ -41,6 +42,7 @@ use super::token_counter::{
4142 CharCount ,
4243 CharCounter ,
4344} ;
45+ use super :: tool_manager:: ToolManager ;
4446use super :: tools:: {
4547 InputSchema ,
4648 QueuedTool ,
@@ -90,6 +92,9 @@ pub struct ConversationState {
9092 pub tools : HashMap < ToolOrigin , Vec < Tool > > ,
9193 /// Context manager for handling sticky context files
9294 pub context_manager : Option < ContextManager > ,
95+ /// Tool manager for handling tool and mcp related activities
96+ #[ serde( skip) ]
97+ pub tool_manager : ToolManager ,
9398 /// Cached value representing the length of the user context message.
9499 context_message_length : Option < usize > ,
95100 /// Stores the latest conversation summary created by /compact
@@ -105,6 +110,7 @@ impl ConversationState {
105110 tool_config : HashMap < String , ToolSpec > ,
106111 profile : Option < String > ,
107112 updates : Option < SharedWriter > ,
113+ tool_manager : ToolManager ,
108114 ) -> Self {
109115 // Initialize context manager
110116 let context_manager = match ContextManager :: new ( ctx, None ) . await {
@@ -143,6 +149,7 @@ impl ConversationState {
143149 acc
144150 } ) ,
145151 context_manager,
152+ tool_manager,
146153 context_message_length : None ,
147154 latest_summary : None ,
148155 updates,
@@ -310,29 +317,49 @@ impl ConversationState {
310317 }
311318
312319 // Here we also need to make sure that the tool result corresponds to one of the tools
313- // in the list. Otherwise we will see validation error from the backend. We would only
314- // do this if the last message is a tool call that has failed.
320+ // in the list. Otherwise we will see validation error from the backend. There are three
321+ // such circumstances where intervention would be needed:
322+ // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
323+ // which case we would automatically resolve this tool call to its correct name. This will NOT
324+ // result in an error in its tool result. The intervention here is to substitute the partial name
325+ // with its full name.
326+ // 2. The model had decided to call a tool with its partial name AND there are multiple tools it
327+ // could be referring to, in which case we WILL return an error in the tool result. The
328+ // intervention here is to substitute the ambiguous, partial name with a dummy.
329+ // 3. The model had decided to call a tool that does not exist. The intervention here is to
330+ // substitute the non-existent tool name with a dummy.
315331 let tool_use_results = user_msg. tool_use_results ( ) ;
316332 if let Some ( tool_use_results) = tool_use_results {
317- let tool_name_list = self
318- . tools
319- . values ( )
320- . flatten ( )
321- . map ( |Tool :: ToolSpecification ( spec) | spec. name . as_str ( ) )
322- . collect :: < Vec < _ > > ( ) ;
333+ // Note that we need to use the keys in tool manager's tn_map as the keys are the
334+ // actual tool names as exposed to the model and the backend. If we use the actual
335+ // names as they are recognized by their respective servers, we risk concluding
336+ // with false positives.
337+ let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
323338 for result in tool_use_results {
324- if let ToolResultStatus :: Error = result. status {
325- let tool_use_id = result. tool_use_id . as_str ( ) ;
326- let _ = tool_uses
327- . iter_mut ( )
328- . filter ( |tool_use| tool_use. id == tool_use_id)
329- . map ( |tool_use| {
330- let tool_name = tool_use. name . as_str ( ) ;
331- if !tool_name_list. contains ( & tool_name) {
332- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
333- }
334- } )
335- . collect :: < Vec < _ > > ( ) ;
339+ let tool_use_id = result. tool_use_id . as_str ( ) ;
340+ let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
341+ if let Some ( tool_use) = corresponding_tool_use {
342+ if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
343+ // If this tool matches of the tools in our list, this is not our
344+ // concern, error or not.
345+ continue ;
346+ }
347+ if let ToolResultStatus :: Error = result. status {
348+ // case 2 and 3
349+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
350+ tool_use. args = serde_json:: json!( { } ) ;
351+ } else {
352+ // case 1
353+ let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
354+ // We should be able to find a match but if not we'll just treat it as
355+ // a dummy and move on
356+ if let Some ( full_name) = full_name {
357+ tool_use. name = ( * full_name) . to_string ( ) ;
358+ } else {
359+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
360+ tool_use. args = serde_json:: json!( { } ) ;
361+ }
362+ }
336363 }
337364 }
338365 }
@@ -363,6 +390,7 @@ impl ConversationState {
363390 /// - `run_hooks` - whether hooks should be executed and included as context
364391 pub async fn as_sendable_conversation_state ( & mut self , run_hooks : bool ) -> FigConversationState {
365392 debug_assert ! ( self . next_message. is_some( ) ) ;
393+ self . update_state ( ) . await ;
366394 self . enforce_conversation_invariants ( ) ;
367395 self . history . drain ( self . valid_history_range . 1 ..) ;
368396 self . history . drain ( ..self . valid_history_range . 0 ) ;
@@ -388,6 +416,30 @@ impl ConversationState {
388416 . expect ( "unable to construct conversation state" )
389417 }
390418
419+ pub async fn update_state ( & mut self ) {
420+ let needs_update = self . tool_manager . has_new_stuff . load ( Ordering :: Acquire ) ;
421+ if !needs_update {
422+ return ;
423+ }
424+ self . tool_manager . update ( ) . await ;
425+ self . tools = self
426+ . tool_manager
427+ . schema
428+ . values ( )
429+ . fold ( HashMap :: < ToolOrigin , Vec < Tool > > :: new ( ) , |mut acc, v| {
430+ let tool = Tool :: ToolSpecification ( ToolSpecification {
431+ name : v. name . clone ( ) ,
432+ description : v. description . clone ( ) ,
433+ input_schema : v. input_schema . clone ( ) . into ( ) ,
434+ } ) ;
435+ acc. entry ( v. tool_origin . clone ( ) )
436+ . and_modify ( |tools| tools. push ( tool. clone ( ) ) )
437+ . or_insert ( vec ! [ tool] ) ;
438+ acc
439+ } ) ;
440+ self . tool_manager . has_new_stuff . store ( false , Ordering :: Release ) ;
441+ }
442+
391443 /// Returns a conversation state representation which reflects the exact conversation to send
392444 /// back to the model.
393445 pub async fn backend_conversation_state ( & mut self , run_hooks : bool , quiet : bool ) -> BackendConversationState < ' _ > {
@@ -843,8 +895,6 @@ mod tests {
843895 } ;
844896 use crate :: cli:: chat:: tool_manager:: ToolManager ;
845897 use crate :: database:: Database ;
846- use crate :: platform:: Env ;
847- use crate :: telemetry:: TelemetryThread ;
848898
849899 fn assert_conversation_state_invariants ( state : FigConversationState , assertion_iteration : usize ) {
850900 if let Some ( Some ( msg) ) = state. history . as_ref ( ) . map ( |h| h. first ( ) ) {
@@ -936,17 +986,16 @@ mod tests {
936986
937987 #[ tokio:: test]
938988 async fn test_conversation_state_history_handling_truncation ( ) {
939- let env = Env :: new ( ) ;
940989 let mut database = Database :: new ( ) . await . unwrap ( ) ;
941- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
942990
943991 let mut tool_manager = ToolManager :: default ( ) ;
944992 let mut conversation_state = ConversationState :: new (
945993 Context :: new ( ) ,
946994 "fake_conv_id" ,
947- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
995+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
948996 None ,
949997 None ,
998+ tool_manager,
950999 )
9511000 . await ;
9521001
@@ -964,18 +1013,18 @@ mod tests {
9641013
9651014 #[ tokio:: test]
9661015 async fn test_conversation_state_history_handling_with_tool_results ( ) {
967- let env = Env :: new ( ) ;
9681016 let mut database = Database :: new ( ) . await . unwrap ( ) ;
969- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
9701017
9711018 // Build a long conversation history of tool use results.
9721019 let mut tool_manager = ToolManager :: default ( ) ;
1020+ let tool_config = tool_manager. load_tools ( & database) . await . unwrap ( ) ;
9731021 let mut conversation_state = ConversationState :: new (
9741022 Context :: new ( ) ,
9751023 "fake_conv_id" ,
976- tool_manager . load_tools ( & database , & telemetry ) . await . unwrap ( ) ,
1024+ tool_config . clone ( ) ,
9771025 None ,
9781026 None ,
1027+ tool_manager. clone ( ) ,
9791028 )
9801029 . await ;
9811030 conversation_state. set_next_user_message ( "start" . to_string ( ) ) . await ;
@@ -1002,9 +1051,10 @@ mod tests {
10021051 let mut conversation_state = ConversationState :: new (
10031052 Context :: new ( ) ,
10041053 "fake_conv_id" ,
1005- tool_manager . load_tools ( & database , & telemetry ) . await . unwrap ( ) ,
1054+ tool_config . clone ( ) ,
10061055 None ,
10071056 None ,
1057+ tool_manager. clone ( ) ,
10081058 )
10091059 . await ;
10101060 conversation_state. set_next_user_message ( "start" . to_string ( ) ) . await ;
@@ -1035,9 +1085,7 @@ mod tests {
10351085
10361086 #[ tokio:: test]
10371087 async fn test_conversation_state_with_context_files ( ) {
1038- let env = Env :: new ( ) ;
10391088 let mut database = Database :: new ( ) . await . unwrap ( ) ;
1040- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
10411089
10421090 let ctx = Context :: builder ( ) . with_test_home ( ) . await . unwrap ( ) . build_fake ( ) ;
10431091 ctx. fs ( ) . write ( AMAZONQ_FILENAME , "test context" ) . await . unwrap ( ) ;
@@ -1046,9 +1094,10 @@ mod tests {
10461094 let mut conversation_state = ConversationState :: new (
10471095 ctx,
10481096 "fake_conv_id" ,
1049- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
1097+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
10501098 None ,
10511099 None ,
1100+ tool_manager,
10521101 )
10531102 . await ;
10541103
@@ -1085,9 +1134,7 @@ mod tests {
10851134 async fn test_conversation_state_additional_context ( ) {
10861135 // tracing_subscriber::fmt::try_init().ok();
10871136
1088- let env = Env :: new ( ) ;
10891137 let mut database = Database :: new ( ) . await . unwrap ( ) ;
1090- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
10911138
10921139 let mut tool_manager = ToolManager :: default ( ) ;
10931140 let ctx = Context :: builder ( ) . with_test_home ( ) . await . unwrap ( ) . build_fake ( ) ;
@@ -1116,9 +1163,10 @@ mod tests {
11161163 let mut conversation_state = ConversationState :: new (
11171164 ctx,
11181165 "fake_conv_id" ,
1119- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
1166+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
11201167 None ,
11211168 Some ( SharedWriter :: stdout ( ) ) ,
1169+ tool_manager,
11221170 )
11231171 . await ;
11241172
0 commit comments