@@ -314,8 +314,9 @@ impl ConversationState {
314314 tool_uses. iter ( ) . map ( |t| t. id . as_str ( ) ) ,
315315 ) ;
316316 }
317- self . enforce_tool_use_history_invariants ( ) ;
318317 }
318+
319+ self . enforce_tool_use_history_invariants ( true ) ;
319320 }
320321
321322 /// Here we also need to make sure that the tool result corresponds to one of the tools
@@ -330,14 +331,14 @@ impl ConversationState {
330331 /// intervention here is to substitute the ambiguous, partial name with a dummy.
331332 /// 3. The model had decided to call a tool that does not exist. The intervention here is to
332333 /// substitute the non-existent tool name with a dummy.
333- pub fn enforce_tool_use_history_invariants ( & mut self ) {
334+ pub fn enforce_tool_use_history_invariants ( & mut self , last_only : bool ) {
334335 let tool_name_list = self . tool_manager . tn_map . keys ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
335336 // We need to first determine what the range of interest is. There are two places where we
336337 // would call this function:
337338 // 1. When there are changes to the list of available tools, in which case we comb through the
338339 // entire conversation
339340 // 2. When we send a message, in which case we only examine the most recent entry
340- let ( tool_use_results, mut tool_uses) =
341+ let ( tool_use_results, mut tool_uses) = if last_only {
341342 if let ( Some ( ( _, AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) ) , Some ( user_msg) ) = (
342343 self . history
343344 . range_mut ( self . valid_history_range . 0 ..self . valid_history_range . 1 )
@@ -350,28 +351,37 @@ impl ConversationState {
350351 let tool_uses = tool_uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
351352 ( tool_use_results, tool_uses)
352353 } else {
353- self . history
354- . iter_mut ( )
355- . filter_map ( |( user_msg, asst_msg) | {
356- if let ( Some ( tool_use_results) , AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) =
357- ( user_msg. tool_use_results ( ) , asst_msg)
358- {
359- Some ( ( tool_use_results, tool_uses) )
360- } else {
361- None
362- }
363- } )
364- . fold (
365- ( Vec :: < & ToolUseResult > :: new ( ) , Vec :: < & mut AssistantToolUse > :: new ( ) ) ,
366- |( mut tool_use_results, mut tool_uses) , ( results, uses) | {
367- let mut results = results. iter ( ) . collect :: < Vec < _ > > ( ) ;
368- let mut uses = uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
369- tool_use_results. append ( & mut results) ;
370- tool_uses. append ( & mut uses) ;
371- ( tool_use_results, tool_uses)
372- } ,
373- )
374- } ;
354+ ( Vec :: new ( ) , Vec :: new ( ) )
355+ }
356+ } else {
357+ let tool_use_results = self . next_message . as_ref ( ) . map_or ( Vec :: new ( ) , |user_msg| {
358+ user_msg
359+ . tool_use_results ( )
360+ . map_or ( Vec :: new ( ) , |results| results. iter ( ) . collect :: < Vec < _ > > ( ) )
361+ } ) ;
362+ self . history
363+ . iter_mut ( )
364+ . filter_map ( |( user_msg, asst_msg) | {
365+ if let ( Some ( tool_use_results) , AssistantMessage :: ToolUse { ref mut tool_uses, .. } ) =
366+ ( user_msg. tool_use_results ( ) , asst_msg)
367+ {
368+ Some ( ( tool_use_results, tool_uses) )
369+ } else {
370+ None
371+ }
372+ } )
373+ . fold (
374+ ( tool_use_results, Vec :: < & mut AssistantToolUse > :: new ( ) ) ,
375+ |( mut tool_use_results, mut tool_uses) , ( results, uses) | {
376+ let mut results = results. iter ( ) . collect :: < Vec < _ > > ( ) ;
377+ let mut uses = uses. iter_mut ( ) . collect :: < Vec < _ > > ( ) ;
378+ tool_use_results. append ( & mut results) ;
379+ tool_uses. append ( & mut uses) ;
380+ ( tool_use_results, tool_uses)
381+ } ,
382+ )
383+ } ;
384+
375385 // Replace tool uses associated with tools that does not exist / no longer exists with
376386 // dummy (i.e. put them to sleep / dormant)
377387 for result in tool_use_results {
@@ -401,6 +411,7 @@ impl ConversationState {
401411 }
402412 }
403413 }
414+
404415 // Revive tools that were previously dormant if they now corresponds to one of the tools in
405416 // our list of available tools. Note that this check only works because tn_map does NOT
406417 // contain names of native tools.
@@ -498,7 +509,7 @@ impl ConversationState {
498509 // We call this in [Self::enforce_conversation_invariants] as well. But we need to call it
499510 // here as well because when it's being called in [Self::enforce_conversation_invariants]
500511 // it is only checking the last entry.
501- self . enforce_tool_use_history_invariants ( ) ;
512+ self . enforce_tool_use_history_invariants ( false ) ;
502513 }
503514
504515 /// Returns a conversation state representation which reflects the exact conversation to send
0 commit comments