11use std:: collections:: {
22 HashMap ,
3+ HashSet ,
34 VecDeque ,
45} ;
56use std:: sync:: Arc ;
@@ -32,7 +33,6 @@ use super::hooks::{
3233} ;
3334use super :: message:: {
3435 AssistantMessage ,
35- AssistantToolUse ,
3636 ToolUseResult ,
3737 ToolUseResultBlock ,
3838 UserMessage ,
@@ -60,7 +60,6 @@ use crate::api_client::model::{
6060 ToolInputSchema ,
6161 ToolResult ,
6262 ToolResultContentBlock ,
63- ToolResultStatus ,
6463 ToolSpecification ,
6564 ToolUse ,
6665 UserInputMessage ,
@@ -347,7 +346,7 @@ impl ConversationState {
347346 }
348347 }
349348
350- self . enforce_tool_use_history_invariants ( true ) ;
349+ self . enforce_tool_use_history_invariants ( ) ;
351350 }
352351
353352 /// Here we also need to make sure that the tool result corresponds to one of the tools
@@ -362,105 +361,51 @@ impl ConversationState {
362361 /// intervention here is to substitute the ambiguous, partial name with a dummy.
363362 /// 3. The model had decided to call a tool that does not exist. The intervention here is to
364363 /// substitute the non-existent tool name with a dummy.
365- pub fn enforce_tool_use_history_invariants ( & mut self , last_only : bool ) {
366- let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
367- // We need to first determine what the range of interest is. There are two places where we
368- // would call this function:
369- // 1. When there are changes to the list of available tools, in which case we comb through the
370- // entire conversation
371- // 2. When we send a message, in which case we only examine the most recent entry
372- let ( tool_use_results, mut tool_uses) = if last_only {
373- if let ( Some ( ( _, AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) ) , Some ( user_msg) ) = (
374- self . history
375- . range_mut ( self . valid_history_range . 0 ..self . valid_history_range . 1 )
376- . last ( ) ,
377- & mut self . next_message ,
378- ) {
379- let tool_use_results = user_msg
380- . tool_use_results ( )
381- . map_or ( Vec :: new ( ) , |results| results. iter ( ) . collect :: < Vec < _ > > ( ) ) ;
382- let tool_uses = tool_uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
383- ( tool_use_results, tool_uses)
384- } else {
385- ( Vec :: new ( ) , Vec :: new ( ) )
386- }
387- } else {
388- let tool_use_results = self . next_message . as_ref ( ) . map_or ( Vec :: new ( ) , |user_msg| {
389- user_msg
390- . tool_use_results ( )
391- . map_or ( Vec :: new ( ) , |results| results. iter ( ) . collect :: < Vec < _ > > ( ) )
392- } ) ;
393- self . history
394- . iter_mut ( )
395- . filter_map ( |( user_msg, asst_msg) | {
396- if let ( Some ( tool_use_results) , AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) =
397- ( user_msg. tool_use_results ( ) , asst_msg)
398- {
399- Some ( ( tool_use_results, tool_uses) )
400- } else {
401- None
402- }
364+ pub fn enforce_tool_use_history_invariants ( & mut self ) {
365+ let tool_names: HashSet < _ > = self
366+ . tools
367+ . values ( )
368+ . flat_map ( |tools| {
369+ tools. iter ( ) . map ( |tool| match tool {
370+ Tool :: ToolSpecification ( tool_specification) => tool_specification. name . as_str ( ) ,
403371 } )
404- . fold (
405- ( tool_use_results, Vec :: < & mut AssistantToolUse > :: new ( ) ) ,
406- |( mut tool_use_results, mut tool_uses) , ( results, uses) | {
407- let mut results = results. iter ( ) . collect :: < Vec < _ > > ( ) ;
408- let mut uses = uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
409- tool_use_results. append ( & mut results) ;
410- tool_uses. append ( & mut uses) ;
411- ( tool_use_results, tool_uses)
412- } ,
413- )
414- } ;
372+ } )
373+ . collect ( ) ;
415374
416- // Replace tool uses associated with tools that does not exist / no longer exists with
417- // dummy (i.e. put them to sleep / dormant)
418- for result in tool_use_results {
419- let tool_use_id = result. tool_use_id . as_str ( ) ;
420- let corresponding_tool_use = tool_uses. iter_mut ( ) . find ( |tool_use| tool_use_id == tool_use. id ) ;
421- if let Some ( tool_use) = corresponding_tool_use {
422- if tool_name_list. contains ( & tool_use. name . as_str ( ) ) {
423- // If this tool matches of the tools in our list, this is not our
424- // concern, error or not.
425- continue ;
426- }
427- if let ToolResultStatus :: Error = result. status {
428- // case 2 and 3
429- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
430- tool_use. args = serde_json:: json!( { } ) ;
431- } else {
432- // case 1
433- let full_name = tool_name_list. iter ( ) . find ( |name| name. ends_with ( & tool_use. name ) ) ;
434- // We should be able to find a match but if not we'll just treat it as
435- // a dummy and move on
436- if let Some ( full_name) = full_name {
437- tool_use. name = ( * full_name) . to_string ( ) ;
438- } else {
439- tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
440- tool_use. args = serde_json:: json!( { } ) ;
375+ for ( _, assistant) in & mut self . history {
376+ if let AssistantMessage :: ToolUse { ref mut tool_uses, .. } = assistant {
377+ for tool_use in tool_uses {
378+ if tool_names. contains ( tool_use. name . as_str ( ) ) {
379+ continue ;
441380 }
442- }
443- }
444- }
445381
446- // Revive tools that were previously dormant if they now corresponds to one of the tools in
447- // our list of available tools. Note that this check only works because tn_map does NOT
448- // contain names of native tools.
449- for tool_use in tool_uses {
450- if tool_use. name == DUMMY_TOOL_NAME
451- && tool_use
452- . orig_name
453- . as_ref ( )
454- . is_some_and ( |name| tool_name_list. contains ( & ( * name) . as_str ( ) ) )
455- {
456- tool_use. name = tool_use
457- . orig_name
458- . as_ref ( )
459- . map_or ( DUMMY_TOOL_NAME . to_string ( ) , |name| name. clone ( ) ) ;
460- tool_use. args = tool_use
461- . orig_args
462- . as_ref ( )
463- . map_or ( serde_json:: json!( { } ) , |args| args. clone ( ) ) ;
382+ if tool_names. contains ( tool_use. orig_name . as_str ( ) ) {
383+ tool_use. name = tool_use. orig_name . clone ( ) ;
384+ tool_use. args = tool_use. orig_args . clone ( ) ;
385+ continue ;
386+ }
387+
388+ let names: Vec < & str > = tool_names
389+ . iter ( )
390+ . filter_map ( |name| {
391+ if name. ends_with ( & tool_use. name ) {
392+ Some ( * name)
393+ } else {
394+ None
395+ }
396+ } )
397+ . collect ( ) ;
398+
399+ // There's only one tool use matching, so we can just replace it with the
400+ // found name.
401+ if names. len ( ) == 1 {
402+ tool_use. name = ( * names. first ( ) . unwrap ( ) ) . to_string ( ) ;
403+ continue ;
404+ }
405+
406+ // Otherwise, we have to replace it with a dummy.
407+ tool_use. name = DUMMY_TOOL_NAME . to_string ( ) ;
408+ }
464409 }
465410 }
466411 }
@@ -540,12 +485,13 @@ impl ConversationState {
540485 // We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
541486 // here as well because when it's being called in [Self::enforce_conversation_invariants]
542487 // it is only checking the last entry.
543- self . enforce_tool_use_history_invariants ( false ) ;
488+ self . enforce_tool_use_history_invariants ( ) ;
544489 }
545490
546491 /// Returns a conversation state representation which reflects the exact conversation to send
547492 /// back to the model.
548493 pub async fn backend_conversation_state ( & mut self , run_hooks : bool , quiet : bool ) -> BackendConversationState < ' _ > {
494+ self . update_state ( false ) . await ;
549495 self . enforce_conversation_invariants ( ) ;
550496
551497 // Run hooks and add to conversation start and next user message.
0 commit comments