1- use std:: collections:: vec_deque:: IterMut ;
21use std:: collections:: {
32 HashMap ,
43 VecDeque ,
@@ -33,6 +32,7 @@ use super::hooks::{
3332} ;
3433use super :: message:: {
3534 AssistantMessage ,
35+ AssistantToolUse ,
3636 ToolUseResult ,
3737 ToolUseResultBlock ,
3838 UserMessage ,
@@ -314,127 +314,113 @@ impl ConversationState {
314314 tool_uses. iter ( ) . map ( |t| t. id . as_str ( ) ) ,
315315 ) ;
316316 }
317-
318- // Here we also need to make sure that the tool result corresponds to one of the tools
319- // in the list. Otherwise we will see validation error from the backend. There are three
320- // such circumstances where intervention would be needed:
321- // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
322- // which case we would automatically resolve this tool call to its correct name. This will NOT
323- // result in an error in its tool result. The intervention here is to substitute the partial name
324- // with its full name.
325- // 2. The model had decided to call a tool with its partial name AND there are multiple tools it
326- // could be referring to, in which case we WILL return an error in the tool result. The
327- // intervention here is to substitute the ambiguous, partial name with a dummy.
328- // 3. The model had decided to call a tool that does not exist. The intervention here is to
329- // substitute the non-existent tool name with a dummy.
330- let tool_use_results = user_msg. tool_use_results ( ) ;
331- if let Some ( tool_use_results) = tool_use_results {
332- // Note that we need to use the keys in tool manager's tn_map as the keys are the
333- // actual tool names as exposed to the model and the backend. If we use the actual
334- // names as they are recognized by their respective servers, we risk concluding
335- // with false positives.
336- let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
337- for result in tool_use_results {
338- let tool_use_id = result. tool_use_id . as_str ( ) ;
339- let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
340- if let Some ( tool_use) = corresponding_tool_use {
341- if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
342- // If this tool matches of the tools in our list, this is not our
343- // concern, error or not.
344- continue ;
345- }
346- if let ToolResultStatus :: Error = result. status {
347- // case 2 and 3
348- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
349- tool_use. args = serde_json:: json!( { } ) ;
350- } else {
351- // case 1
352- let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
353- // We should be able to find a match but if not we'll just treat it as
354- // a dummy and move on
355- if let Some ( full_name) = full_name {
356- tool_use. name = ( * full_name) . to_string ( ) ;
357- } else {
358- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
359- tool_use. args = serde_json:: json!( { } ) ;
360- }
361- }
362- }
363- }
364- }
317+ self . enforce_tool_use_history_invariants ( ) ;
365318 }
366319 }
367320
368- // Here we also need to make sure that the tool result corresponds to one of the tools
369- // in the list. Otherwise we will see validation error from the backend. There are three
370- // such circumstances where intervention would be needed:
371- // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in
372- // which case we would automatically resolve this tool call to its correct name. This will NOT
373- // result in an error in its tool result. The intervention here is to substitute the partial name
374- // with its full name.
375- // 2. The model had decided to call a tool with its partial name AND there are multiple tools it
376- // could be referring to, in which case we WILL return an error in the tool result. The
377- // intervention here is to substitute the ambiguous, partial name with a dummy.
378- // 3. The model had decided to call a tool that does not exist. The intervention here is to
379- // substitute the non-existent tool name with a dummy.
380- fn enforce_tool_use_invariants ( & mut self , history_of_interest : & mut Vec < ( UserMessage , AssistantMessage ) > ) {
321+ /// Here we also need to make sure that the tool result corresponds to one of the tools
322+ /// in the list. Otherwise we will see validation error from the backend. There are three
323+ /// such circumstances where intervention would be needed:
324+ /// 1. The model had decided to call a tool with its partial name AND there is only one such
325+ /// tool, in which case we would automatically resolve this tool call to its correct name.
326+ /// This will NOT result in an error in its tool result. The intervention here is to
327+ /// substitute the partial name with its full name.
328+ /// 2. The model had decided to call a tool with its partial name AND there are multiple tools
329+ /// it could be referring to, in which case we WILL return an error in the tool result. The
330+ /// intervention here is to substitute the ambiguous, partial name with a dummy.
331+ /// 3. The model had decided to call a tool that does not exist. The intervention here is to
332+ /// substitute the non-existent tool name with a dummy.
333+ pub fn enforce_tool_use_history_invariants ( & mut self ) {
381334 let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
382- let mut tool_uses = history_of_interest
383- . iter_mut ( )
384- . filter_map ( |( _user_msg, asst_msg) | {
385- if let AssistantMessage :: ToolUse { ref mut tool_uses, .. } = asst_msg {
386- Some ( tool_uses)
387- } else {
388- None
335+ // We need to first determine what the range of interest is. There are two places where we
336+ // would call this function:
337+ // 1. When there are changes to the list of available tools, in which case we comb through the
338+ // entire conversation
339+ // 2. When we send a message, in which case we only examine the most recent entry
340+ let ( tool_use_results, mut tool_uses) =
341+ if let ( Some ( ( _, AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) ) , Some ( user_msg) ) = (
342+ self . history
343+ . range_mut ( self . valid_history_range . 0 ..self . valid_history_range . 1 )
344+ . last ( ) ,
345+ & mut self . next_message ,
346+ ) {
347+ let tool_use_results = user_msg
348+ . tool_use_results ( )
349+ . map_or ( Vec :: new ( ) , |results| results. iter ( ) . collect :: < Vec < _ > > ( ) ) ;
350+ let tool_uses = tool_uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
351+ ( tool_use_results, tool_uses)
352+ } else {
353+ self . history
354+ . iter_mut ( )
355+ . filter_map ( |( user_msg, asst_msg) | {
356+ if let ( Some ( tool_use_results) , AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) =
357+ ( user_msg. tool_use_results ( ) , asst_msg)
358+ {
359+ Some ( ( tool_use_results, tool_uses) )
360+ } else {
361+ None
362+ }
363+ } )
364+ . fold (
365+ ( Vec :: < & ToolUseResult > :: new ( ) , Vec :: < & mut AssistantToolUse > :: new ( ) ) ,
366+ |( mut tool_use_results, mut tool_uses) , ( results, uses) | {
367+ let mut results = results. iter ( ) . collect :: < Vec < _ > > ( ) ;
368+ let mut uses = uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
369+ tool_use_results. append ( & mut results) ;
370+ tool_uses. append ( & mut uses) ;
371+ ( tool_use_results, tool_uses)
372+ } ,
373+ )
374+ } ;
375+ // Replace tool uses associated with tools that does not exist / no longer exists with
376+ // dummy (i.e. put them to sleep / dormant)
377+ for result in tool_use_results {
378+ let tool_use_id = result. tool_use_id . as_str ( ) ;
379+ let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
380+ if let Some ( tool_use) = corresponding_tool_use {
381+ if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
382+ // If this tool matches of the tools in our list, this is not our
383+ // concern, error or not.
384+ continue ;
389385 }
390- } )
391- . flatten ( ) ;
392- let tool_use_results = if let Some ( user_msg) = & self . next_message {
393- // We only check to verify the last message if [Self::next_message] is set
394- user_msg. tool_use_results ( ) . map ( |arr| arr. iter ( ) . collect :: < Vec < _ > > ( ) )
395- } else {
396- // Otherwise, we check the entire conversation
397- Some (
398- history_of_interest
399- . iter ( )
400- . filter_map ( |( user_msg, _) | user_msg. tool_use_results ( ) )
401- . flatten ( )
402- . collect :: < Vec < _ > > ( ) ,
403- )
404- } ;
405- if let Some ( tool_use_results) = tool_use_results {
406- // Note that we need to use the keys in tool manager's tn_map as the keys are the
407- // actual tool names as exposed to the model and the backend. If we use the actual
408- // names as they are recognized by their respective servers, we risk concluding
409- // with false positives.
410- for result in tool_use_results {
411- let tool_use_id = result. tool_use_id . as_str ( ) ;
412- let corresponding_tool_use = tool_uses. find ( |tool_use| tool_use_id == tool_use. id ) ;
413- if let Some ( tool_use) = corresponding_tool_use {
414- if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
415- // If this tool matches of the tools in our list, this is not our
416- // concern, error or not.
417- continue ;
418- }
419- if let ToolResultStatus :: Error = result. status {
420- // case 2 and 3
386+ if let ToolResultStatus :: Error = result. status {
387+ // case 2 and 3
388+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
389+ tool_use. args = serde_json:: json!( { } ) ;
390+ } else {
391+ // case 1
392+ let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
393+ // We should be able to find a match but if not we'll just treat it as
394+ // a dummy and move on
395+ if let Some ( full_name) = full_name {
396+ tool_use. name = ( * full_name) . to_string ( ) ;
397+ } else {
421398 tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
422399 tool_use. args = serde_json:: json!( { } ) ;
423- } else {
424- // case 1
425- let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
426- // We should be able to find a match but if not we'll just treat it as
427- // a dummy and move on
428- if let Some ( full_name) = full_name {
429- tool_use. name = ( * full_name) . to_string ( ) ;
430- } else {
431- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
432- tool_use. args = serde_json:: json!( { } ) ;
433- }
434400 }
435401 }
436402 }
437403 }
404+ // Revive tools that were previously dormant if they now corresponds to one of the tools in
405+ // our list of available tools. Note that this check only works because tn_map does NOT
406+ // contain names of native tools.
407+ for tool_use in tool_uses {
408+ if tool_use. name == DUMMY_TOOL_NAME
409+ && tool_use
410+ . orig_name
411+ . as_ref ( )
412+ . is_some_and ( |name| tool_name_list. contains ( & ( * name) . as_str ( ) ) )
413+ {
414+ tool_use. name = tool_use
415+ . orig_name
416+ . as_ref ( )
417+ . map_or ( DUMMY_TOOL_NAME . to_string ( ) , |name| name. clone ( ) ) ;
418+ tool_use. args = tool_use
419+ . orig_args
420+ . as_ref ( )
421+ . map_or ( serde_json:: json!( { } ) , |args| args. clone ( ) ) ;
422+ }
423+ }
438424 }
439425
440426 pub fn add_tool_results ( & mut self , tool_results : Vec < ToolUseResult > ) {
@@ -492,7 +478,7 @@ impl ConversationState {
492478 return ;
493479 }
494480 self . tool_manager . update ( ) . await ;
495- // TODO: make this more targetted so we don't have to clone the entire list of tools
481+ // TODO: make this more targeted so we don't have to clone the entire list of tools
496482 self . tools = self
497483 . tool_manager
498484 . schema
@@ -509,6 +495,10 @@ impl ConversationState {
509495 acc
510496 } ) ;
511497 self . tool_manager . has_new_stuff . store ( false , Ordering :: Release ) ;
498+ // We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
499+ // here as well because when it's being called in [Self::enforce_conversation_invariants]
500+ // it is only checking the last entry.
501+ self . enforce_tool_use_history_invariants ( ) ;
512502 }
513503
514504 /// Returns a conversation state representation which reflects the exact conversation to send
@@ -1108,6 +1098,7 @@ mod tests {
11081098 id: "tool_id" . to_string( ) ,
11091099 name: "tool name" . to_string( ) ,
11101100 args: serde_json:: Value :: Null ,
1101+ ..Default :: default ( )
11111102 } ] ) ,
11121103 & mut database,
11131104 ) ;
@@ -1138,6 +1129,7 @@ mod tests {
11381129 id: "tool_id" . to_string( ) ,
11391130 name: "tool name" . to_string( ) ,
11401131 args: serde_json:: Value :: Null ,
1132+ ..Default :: default ( )
11411133 } ] ) ,
11421134 & mut database,
11431135 ) ;
0 commit comments