@@ -3,6 +3,7 @@ use std::collections::{
33 VecDeque ,
44} ;
55use std:: sync:: Arc ;
6+ use std:: sync:: atomic:: Ordering ;
67
78use crossterm:: style:: Color ;
89use crossterm:: {
@@ -41,6 +42,7 @@ use super::token_counter::{
4142 CharCount ,
4243 CharCounter ,
4344} ;
45+ use super :: tool_manager:: ToolManager ;
4446use super :: tools:: {
4547 InputSchema ,
4648 QueuedTool ,
@@ -90,6 +92,9 @@ pub struct ConversationState {
9092 pub tools : HashMap < ToolOrigin , Vec < Tool > > ,
9193 /// Context manager for handling sticky context files
9294 pub context_manager : Option < ContextManager > ,
95+ /// Tool manager for handling tool and mcp related activities
96+ #[ serde( skip) ]
97+ pub tool_manager : ToolManager ,
9398 /// Cached value representing the length of the user context message.
9499 context_message_length : Option < usize > ,
95100 /// Stores the latest conversation summary created by /compact
@@ -105,6 +110,7 @@ impl ConversationState {
105110 tool_config : HashMap < String , ToolSpec > ,
106111 profile : Option < String > ,
107112 updates : Option < SharedWriter > ,
113+ tool_manager : ToolManager ,
108114 ) -> Self {
109115 // Initialize context manager
110116 let context_manager = match ContextManager :: new ( ctx, None ) . await {
@@ -143,6 +149,7 @@ impl ConversationState {
143149 acc
144150 } ) ,
145151 context_manager,
152+ tool_manager,
146153 context_message_length : None ,
147154 latest_summary : None ,
148155 updates,
@@ -213,9 +220,7 @@ impl ConversationState {
213220 warn ! ( "input must not be empty when adding new messages" ) ;
214221 "Empty prompt" . to_string ( )
215222 } else {
216- let now = chrono:: Utc :: now ( ) ;
217- let formatted_time = now. format ( "%Y-%m-%d %H:%M:%S" ) . to_string ( ) ;
218- format ! ( "{}\n \n <currentTimeUTC>\n {}\n </currentTimeUTC>" , input, formatted_time)
223+ input
219224 } ;
220225
221226 let msg = UserMessage :: new_prompt ( input) ;
@@ -310,29 +315,49 @@ impl ConversationState {
310315 }
311316
312317 // Here we also need to make sure that the tool result corresponds to one of the tools
313- // in the list. Otherwise we will see validation error from the backend. We would only
314- // do this if the last message is a tool call that has failed.
318+ // in the list. Otherwise we will see validation error from the backend. There are three
319+ // such circumstances where intervention would be needed:
320+ // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
321+ // which case we would automatically resolve this tool call to its correct name. This will NOT
322+ // result in an error in its tool result. The intervention here is to substitute the partial name
323+ // with its full name.
324+ // 2. The model had decided to call a tool with its partial name AND there are multiple tools it
325+ // could be referring to, in which case we WILL return an error in the tool result. The
326+ // intervention here is to substitute the ambiguous, partial name with a dummy.
327+ // 3. The model had decided to call a tool that does not exist. The intervention here is to
328+ // substitute the non-existent tool name with a dummy.
315329 let tool_use_results = user_msg. tool_use_results ( ) ;
316330 if let Some ( tool_use_results) = tool_use_results {
317- let tool_name_list = self
318- . tools
319- . values ( )
320- . flatten ( )
321- . map ( |Tool :: ToolSpecification ( spec) | spec. name . as_str ( ) )
322- . collect :: < Vec < _ > > ( ) ;
331+ // Note that we need to use the keys in tool manager's tn_map as the keys are the
332+ // actual tool names as exposed to the model and the backend. If we use the actual
333+ // names as they are recognized by their respective servers, we risk concluding
334+ // with false positives.
335+ let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
323336 for result in tool_use_results {
324- if let ToolResultStatus :: Error = result. status {
325- let tool_use_id = result. tool_use_id . as_str ( ) ;
326- let _ = tool_uses
327- . iter_mut ( )
328- . filter ( |tool_use| tool_use. id == tool_use_id)
329- . map ( |tool_use| {
330- let tool_name = tool_use. name . as_str ( ) ;
331- if !tool_name_list. contains ( & tool_name) {
332- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
333- }
334- } )
335- . collect :: < Vec < _ > > ( ) ;
337+ let tool_use_id = result. tool_use_id . as_str ( ) ;
338+ let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
339+ if let Some ( tool_use) = corresponding_tool_use {
340+ if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
341+ // If this tool matches of the tools in our list, this is not our
342+ // concern, error or not.
343+ continue ;
344+ }
345+ if let ToolResultStatus :: Error = result. status {
346+ // case 2 and 3
347+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
348+ tool_use. args = serde_json:: json!( { } ) ;
349+ } else {
350+ // case 1
351+ let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
352+ // We should be able to find a match but if not we'll just treat it as
353+ // a dummy and move on
354+ if let Some ( full_name) = full_name {
355+ tool_use. name = ( * full_name) . to_string ( ) ;
356+ } else {
357+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
358+ tool_use. args = serde_json:: json!( { } ) ;
359+ }
360+ }
336361 }
337362 }
338363 }
@@ -363,6 +388,7 @@ impl ConversationState {
363388 /// - `run_hooks` - whether hooks should be executed and included as context
364389 pub async fn as_sendable_conversation_state ( & mut self , run_hooks : bool ) -> FigConversationState {
365390 debug_assert ! ( self . next_message. is_some( ) ) ;
391+ self . update_state ( ) . await ;
366392 self . enforce_conversation_invariants ( ) ;
367393 self . history . drain ( self . valid_history_range . 1 ..) ;
368394 self . history . drain ( ..self . valid_history_range . 0 ) ;
@@ -388,6 +414,30 @@ impl ConversationState {
388414 . expect ( "unable to construct conversation state" )
389415 }
390416
417+ pub async fn update_state ( & mut self ) {
418+ let needs_update = self . tool_manager . has_new_stuff . load ( Ordering :: Acquire ) ;
419+ if !needs_update {
420+ return ;
421+ }
422+ self . tool_manager . update ( ) . await ;
423+ self . tools = self
424+ . tool_manager
425+ . schema
426+ . values ( )
427+ . fold ( HashMap :: < ToolOrigin , Vec < Tool > > :: new ( ) , |mut acc, v| {
428+ let tool = Tool :: ToolSpecification ( ToolSpecification {
429+ name : v. name . clone ( ) ,
430+ description : v. description . clone ( ) ,
431+ input_schema : v. input_schema . clone ( ) . into ( ) ,
432+ } ) ;
433+ acc. entry ( v. tool_origin . clone ( ) )
434+ . and_modify ( |tools| tools. push ( tool. clone ( ) ) )
435+ . or_insert ( vec ! [ tool] ) ;
436+ acc
437+ } ) ;
438+ self . tool_manager . has_new_stuff . store ( false , Ordering :: Release ) ;
439+ }
440+
391441 /// Returns a conversation state representation which reflects the exact conversation to send
392442 /// back to the model.
393443 pub async fn backend_conversation_state ( & mut self , run_hooks : bool , quiet : bool ) -> BackendConversationState < ' _ > {
@@ -843,8 +893,6 @@ mod tests {
843893 } ;
844894 use crate :: cli:: chat:: tool_manager:: ToolManager ;
845895 use crate :: database:: Database ;
846- use crate :: platform:: Env ;
847- use crate :: telemetry:: TelemetryThread ;
848896
849897 fn assert_conversation_state_invariants ( state : FigConversationState , assertion_iteration : usize ) {
850898 if let Some ( Some ( msg) ) = state. history . as_ref ( ) . map ( |h| h. first ( ) ) {
@@ -936,17 +984,16 @@ mod tests {
936984
937985 #[ tokio:: test]
938986 async fn test_conversation_state_history_handling_truncation ( ) {
939- let env = Env :: new ( ) ;
940987 let mut database = Database :: new ( ) . await . unwrap ( ) ;
941- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
942988
943989 let mut tool_manager = ToolManager :: default ( ) ;
944990 let mut conversation_state = ConversationState :: new (
945991 Context :: new ( ) ,
946992 "fake_conv_id" ,
947- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
993+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
948994 None ,
949995 None ,
996+ tool_manager,
950997 )
951998 . await ;
952999
@@ -964,18 +1011,18 @@ mod tests {
9641011
9651012 #[ tokio:: test]
9661013 async fn test_conversation_state_history_handling_with_tool_results ( ) {
967- let env = Env :: new ( ) ;
9681014 let mut database = Database :: new ( ) . await . unwrap ( ) ;
969- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
9701015
9711016 // Build a long conversation history of tool use results.
9721017 let mut tool_manager = ToolManager :: default ( ) ;
1018+ let tool_config = tool_manager. load_tools ( & database) . await . unwrap ( ) ;
9731019 let mut conversation_state = ConversationState :: new (
9741020 Context :: new ( ) ,
9751021 "fake_conv_id" ,
976- tool_manager . load_tools ( & database , & telemetry ) . await . unwrap ( ) ,
1022+ tool_config . clone ( ) ,
9771023 None ,
9781024 None ,
1025+ tool_manager. clone ( ) ,
9791026 )
9801027 . await ;
9811028 conversation_state. set_next_user_message ( "start" . to_string ( ) ) . await ;
@@ -1002,9 +1049,10 @@ mod tests {
10021049 let mut conversation_state = ConversationState :: new (
10031050 Context :: new ( ) ,
10041051 "fake_conv_id" ,
1005- tool_manager . load_tools ( & database , & telemetry ) . await . unwrap ( ) ,
1052+ tool_config . clone ( ) ,
10061053 None ,
10071054 None ,
1055+ tool_manager. clone ( ) ,
10081056 )
10091057 . await ;
10101058 conversation_state. set_next_user_message ( "start" . to_string ( ) ) . await ;
@@ -1035,9 +1083,7 @@ mod tests {
10351083
10361084 #[ tokio:: test]
10371085 async fn test_conversation_state_with_context_files ( ) {
1038- let env = Env :: new ( ) ;
10391086 let mut database = Database :: new ( ) . await . unwrap ( ) ;
1040- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
10411087
10421088 let ctx = Context :: builder ( ) . with_test_home ( ) . await . unwrap ( ) . build_fake ( ) ;
10431089 ctx. fs ( ) . write ( AMAZONQ_FILENAME , "test context" ) . await . unwrap ( ) ;
@@ -1046,9 +1092,10 @@ mod tests {
10461092 let mut conversation_state = ConversationState :: new (
10471093 ctx,
10481094 "fake_conv_id" ,
1049- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
1095+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
10501096 None ,
10511097 None ,
1098+ tool_manager,
10521099 )
10531100 . await ;
10541101
@@ -1085,9 +1132,7 @@ mod tests {
10851132 async fn test_conversation_state_additional_context ( ) {
10861133 // tracing_subscriber::fmt::try_init().ok();
10871134
1088- let env = Env :: new ( ) ;
10891135 let mut database = Database :: new ( ) . await . unwrap ( ) ;
1090- let telemetry = TelemetryThread :: new ( & env, & mut database) . await . unwrap ( ) ;
10911136
10921137 let mut tool_manager = ToolManager :: default ( ) ;
10931138 let ctx = Context :: builder ( ) . with_test_home ( ) . await . unwrap ( ) . build_fake ( ) ;
@@ -1116,9 +1161,10 @@ mod tests {
11161161 let mut conversation_state = ConversationState :: new (
11171162 ctx,
11181163 "fake_conv_id" ,
1119- tool_manager. load_tools ( & database, & telemetry ) . await . unwrap ( ) ,
1164+ tool_manager. load_tools ( & database) . await . unwrap ( ) ,
11201165 None ,
11211166 Some ( SharedWriter :: stdout ( ) ) ,
1167+ tool_manager,
11221168 )
11231169 . await ;
11241170
0 commit comments