1+ use std:: collections:: vec_deque:: IterMut ;
12use std:: collections:: {
23 HashMap ,
34 VecDeque ,
@@ -364,6 +365,78 @@ impl ConversationState {
364365 }
365366 }
366367
368+ // Here we also need to make sure that the tool result corresponds to one of the tools
369+ // in the list. Otherwise we will see validation error from the backend. There are three
370+ // such circumstances where intervention would be needed:
371+ // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
372+ // which case we would automatically resolve this tool call to its correct name. This will NOT
373+ // result in an error in its tool result. The intervention here is to substitute the partial name
374+ // with its full name.
375+ // 2. The model had decided to call a tool with its partial name AND there are multiple tools it
376+ // could be referring to, in which case we WILL return an error in the tool result. The
377+ // intervention here is to substitute the ambiguous, partial name with a dummy.
378+ // 3. The model had decided to call a tool that does not exist. The intervention here is to
379+ // substitute the non-existent tool name with a dummy.
380+ fn enforce_tool_use_invariants ( & mut self , history_of_interest : & mut Vec < ( UserMessage , AssistantMessage ) > ) {
381+ let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
382+ let mut tool_uses = history_of_interest
383+ . iter_mut ( )
384+ . filter_map ( |( _user_msg, asst_msg) | {
385+ if let AssistantMessage :: ToolUse { ref mut tool_uses, .. } = asst_msg {
386+ Some ( tool_uses)
387+ } else {
388+ None
389+ }
390+ } )
391+ . flatten ( ) ;
392+ let tool_use_results = if let Some ( user_msg) = & self . next_message {
393+ // We only check to verify the last message if [Self::next_message] is set
394+ user_msg. tool_use_results ( ) . map ( |arr| arr. iter ( ) . collect :: < Vec < _ > > ( ) )
395+ } else {
396+ // Otherwise, we check the entire conversation
397+ Some (
398+ history_of_interest
399+ . iter ( )
400+ . filter_map ( |( user_msg, _) | user_msg. tool_use_results ( ) )
401+ . flatten ( )
402+ . collect :: < Vec < _ > > ( ) ,
403+ )
404+ } ;
405+ if let Some ( tool_use_results) = tool_use_results {
406+ // Note that we need to use the keys in tool manager's tn_map as the keys are the
407+ // actual tool names as exposed to the model and the backend. If we use the actual
408+ // names as they are recognized by their respective servers, we risk concluding
409+ // with false positives.
410+ for result in tool_use_results {
411+ let tool_use_id = result. tool_use_id . as_str ( ) ;
412+ let corresponding_tool_use = tool_uses. find ( |tool_use| tool_use_id == tool_use. id ) ;
413+ if let Some ( tool_use) = corresponding_tool_use {
414+ if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
415+ // If this tool matches of the tools in our list, this is not our
416+ // concern, error or not.
417+ continue ;
418+ }
419+ if let ToolResultStatus :: Error = result. status {
420+ // case 2 and 3
421+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
422+ tool_use. args = serde_json:: json!( { } ) ;
423+ } else {
424+ // case 1
425+ let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
426+ // We should be able to find a match but if not we'll just treat it as
427+ // a dummy and move on
428+ if let Some ( full_name) = full_name {
429+ tool_use. name = ( * full_name) . to_string ( ) ;
430+ } else {
431+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
432+ tool_use. args = serde_json:: json!( { } ) ;
433+ }
434+ }
435+ }
436+ }
437+ }
438+ }
439+
367440 pub fn add_tool_results ( & mut self , tool_results : Vec < ToolUseResult > ) {
368441 debug_assert ! ( self . next_message. is_none( ) ) ;
369442 self . next_message = Some ( UserMessage :: new_tool_use_results ( tool_results) ) ;
@@ -388,7 +461,6 @@ impl ConversationState {
388461 /// - `run_hooks` - whether hooks should be executed and included as context
389462 pub async fn as_sendable_conversation_state ( & mut self , run_hooks : bool ) -> FigConversationState {
390463 debug_assert ! ( self . next_message. is_some( ) ) ;
391- self . update_state ( ) . await ;
392464 self . enforce_conversation_invariants ( ) ;
393465 self . history . drain ( self . valid_history_range . 1 ..) ;
394466 self . history . drain ( ..self . valid_history_range . 0 ) ;
@@ -420,6 +492,7 @@ impl ConversationState {
420492 return ;
421493 }
422494 self . tool_manager . update ( ) . await ;
495+ // TODO: make this more targetted so we don't have to clone the entire list of tools
423496 self . tools = self
424497 . tool_manager
425498 . schema
0 commit comments