@@ -27,7 +27,6 @@ use tracing::{
2727 debug,
2828 error,
2929 info,
30- trace,
3130 warn,
3231} ;
3332
@@ -137,75 +136,64 @@ impl ConversationState {
137136 }
138137
139138 /// Updates the history so that, when non-empty, the following invariants are in place:
140- /// 1. The history length is <= MAX_CONVERSATION_STATE_HISTORY_LEN if the next user message does
141- /// not contain tool results. Oldest messages are dropped.
139+ /// 1. The history length is ` <= MAX_CONVERSATION_STATE_HISTORY_LEN`. Oldest messages are
140+ /// dropped.
142141 /// 2. The first message is from the user, and does not contain tool results. Oldest messages
143142 /// are dropped.
144143 /// 3. The last message is from the assistant. The last message is dropped if it is from the
145144 /// user.
146145 pub fn fix_history ( & mut self ) {
147- if self . history . is_empty ( ) {
148- return ;
146+ // Trim the conversation history by finding the second oldest message from the user without
147+ // tool results - this will be the new oldest message in the history.
148+ if self . history . len ( ) > MAX_CONVERSATION_STATE_HISTORY_LEN {
149+ match self
150+ . history
151+ . iter ( )
152+ . enumerate ( )
153+ // Skip the first message which should be from the user.
154+ . skip ( 1 )
155+ . find ( |( _, m) | -> bool {
156+ match m {
157+ ChatMessage :: UserInputMessage ( m) => {
158+ matches ! (
159+ m. user_input_message_context. as_ref( ) ,
160+ Some ( ctx) if ctx. tool_results. as_ref( ) . is_none_or( |v| v. is_empty( ) )
161+ )
162+ } ,
163+ ChatMessage :: AssistantResponseMessage ( _) => false ,
164+ }
165+ } )
166+ . map ( |v| v. 0 )
167+ {
168+ Some ( i) => {
169+ debug ! ( "removing the first {i} elements in the history" ) ;
170+ self . history . drain ( ..i) ;
171+ } ,
172+ None => {
173+ debug ! ( "no valid starting user message found in the history, clearing" ) ;
174+ self . history . clear ( ) ;
175+
176+ // Edge case: if the next message contains tool results, then we have to just
177+ // abandon them.
178+ match & mut self . next_message {
179+ Some ( UserInputMessage {
180+ ref mut content,
181+ user_input_message_context : Some ( ctx) ,
182+ ..
183+ } ) if ctx. tool_results . as_ref ( ) . is_some_and ( |r| !r. is_empty ( ) ) => {
184+ * content = "The conversation history has overflowed, clearing state" . to_string ( ) ;
185+ ctx. tool_results . take ( ) ;
186+ } ,
187+ _ => { } ,
188+ }
189+ } ,
190+ }
149191 }
150192
151- // Invariant (3).
152193 if let Some ( ChatMessage :: UserInputMessage ( msg) ) = self . history . iter ( ) . last ( ) {
153194 debug ! ( ?msg, "last message in history is from the user, dropping" ) ;
154195 self . history . pop_back ( ) ;
155196 }
156-
157- // Check if the next message contains tool results - if it does, then return early.
158- // Required in the case that the entire history consists of tool results; every message is
159- // therefore required to avoid validation errors in the backend.
160- match self . next_message . as_ref ( ) {
161- Some ( UserInputMessage {
162- user_input_message_context : Some ( ctx) ,
163- ..
164- } ) if ctx. tool_results . as_ref ( ) . is_none_or ( |r| r. is_empty ( ) ) => {
165- debug ! (
166- curr_history_len = self . history. len( ) ,
167- max_history_len = MAX_CONVERSATION_STATE_HISTORY_LEN ,
168- "next user message does not contain tool results, removing messages if required"
169- ) ;
170- } ,
171- _ => {
172- debug ! ( "next user message contains tool results, not modifying the history" ) ;
173- return ;
174- } ,
175- }
176-
177- // Invariant (1).
178- while self . history . len ( ) > MAX_CONVERSATION_STATE_HISTORY_LEN {
179- self . history . pop_front ( ) ;
180- }
181-
182- // Invariant (2).
183- match self
184- . history
185- . iter ( )
186- . enumerate ( )
187- . find ( |( _, m) | -> bool {
188- match m {
189- ChatMessage :: UserInputMessage ( m) => {
190- matches ! (
191- m. user_input_message_context. as_ref( ) ,
192- Some ( ctx) if ctx. tool_results. as_ref( ) . is_none_or( |v| v. is_empty( ) )
193- )
194- } ,
195- ChatMessage :: AssistantResponseMessage ( _) => false ,
196- }
197- } )
198- . map ( |v| v. 0 )
199- {
200- Some ( i) => {
201- trace ! ( "removing the first {i} elements in the history" ) ;
202- self . history . drain ( ..i) ;
203- } ,
204- None => {
205- trace ! ( "no valid starting user message found in the history, clearing" ) ;
206- self . history . clear ( ) ;
207- } ,
208- }
209197 }
210198
211199 pub fn add_tool_results ( & mut self , tool_results : Vec < ToolResult > ) {
@@ -229,6 +217,7 @@ impl ConversationState {
229217 self . next_message = Some ( msg) ;
230218 }
231219
220+ /// Sets the next user message with "cancelled" tool results.
232221 pub fn abandon_tool_use ( & mut self , tools_to_be_abandoned : Vec < ( String , super :: tools:: Tool ) > , deny_input : String ) {
233222 debug_assert ! ( self . next_message. is_none( ) ) ;
234223 let tool_results = tools_to_be_abandoned
@@ -260,6 +249,38 @@ impl ConversationState {
260249 self . next_message = Some ( msg) ;
261250 }
262251
252+ /// Sets the next user message with "interrupted" tool results.
253+ pub fn interrupt_tool_use ( & mut self , interrupted_tools : Vec < ( String , super :: tools:: Tool ) > , deny_input : String ) {
254+ debug_assert ! ( self . next_message. is_none( ) ) ;
255+ let tool_results = interrupted_tools
256+ . into_iter ( )
257+ . map ( |( tool_use_id, _) | ToolResult {
258+ tool_use_id,
259+ content : vec ! [ ToolResultContentBlock :: Text (
260+ "Tool use was interrupted by the user" . to_string( ) ,
261+ ) ] ,
262+ status : fig_api_client:: model:: ToolResultStatus :: Error ,
263+ } )
264+ . collect :: < Vec < _ > > ( ) ;
265+ let user_input_message_context = UserInputMessageContext {
266+ shell_state : None ,
267+ env_state : Some ( build_env_state ( ) ) ,
268+ tool_results : Some ( tool_results) ,
269+ tools : if self . tools . is_empty ( ) {
270+ None
271+ } else {
272+ Some ( self . tools . clone ( ) )
273+ } ,
274+ ..Default :: default ( )
275+ } ;
276+ let msg = UserInputMessage {
277+ content : deny_input,
278+ user_input_message_context : Some ( user_input_message_context) ,
279+ user_intent : None ,
280+ } ;
281+ self . next_message = Some ( msg) ;
282+ }
283+
263284 /// Returns a [FigConversationState] capable of being sent by
264285 /// [fig_api_client::StreamingClient] while preparing the current conversation state to be sent
265286 /// in the next message.
@@ -344,6 +365,7 @@ mod tests {
344365 use fig_api_client:: model:: {
345366 AssistantResponseMessage ,
346367 ToolResultStatus ,
368+ ToolUse ,
347369 } ;
348370
349371 use super :: * ;
@@ -365,72 +387,83 @@ mod tests {
365387 println ! ( "{env_state:?}" ) ;
366388 }
367389
390+ fn assert_conversation_state_invariants ( state : FigConversationState , i : usize ) {
391+ if let Some ( Some ( msg) ) = state. history . as_ref ( ) . map ( |h| h. first ( ) ) {
392+ assert ! (
393+ matches!( msg, ChatMessage :: UserInputMessage ( _) ) ,
394+ "{i}: First message in the history must be from the user, instead found: {:?}" ,
395+ msg
396+ ) ;
397+ }
398+ if let Some ( Some ( msg) ) = state. history . as_ref ( ) . map ( |h| h. last ( ) ) {
399+ assert ! (
400+ matches!( msg, ChatMessage :: AssistantResponseMessage ( _) ) ,
401+ "{i}: Last message in the history must be from the assistant, instead found: {:?}" ,
402+ msg
403+ ) ;
404+ // If the last message from the assistant contains tool uses, then the next user
405+ // message must contain tool results.
406+ match ( state. user_input_message . user_input_message_context , msg) {
407+ (
408+ Some ( ctx) ,
409+ ChatMessage :: AssistantResponseMessage ( AssistantResponseMessage {
410+ tool_uses : Some ( tool_uses) ,
411+ ..
412+ } ) ,
413+ ) if !tool_uses. is_empty ( ) => {
414+ assert ! (
415+ ctx. tool_results. is_some_and( |r| !r. is_empty( ) ) ,
416+ "The user input message must contain tool results when the last assistant message contains tool uses"
417+ ) ;
418+ } ,
419+ _ => { } ,
420+ }
421+ }
422+
423+ let actual_history_len = state. history . unwrap_or_default ( ) . len ( ) ;
424+ assert ! (
425+ actual_history_len <= MAX_CONVERSATION_STATE_HISTORY_LEN ,
426+ "history should not extend past the max limit of {}, instead found length {}" ,
427+ MAX_CONVERSATION_STATE_HISTORY_LEN ,
428+ actual_history_len
429+ ) ;
430+ }
431+
368432 #[ tokio:: test]
369- async fn test_conversation_state_history_handling ( ) {
433+ async fn test_conversation_state_history_handling_truncation ( ) {
370434 let mut conversation_state = ConversationState :: new ( load_tools ( ) . unwrap ( ) ) ;
371435
372436 // First, build a large conversation history. We need to ensure that the order is always
373437 // User -> Assistant -> User -> Assistant ...and so on.
374438 conversation_state. append_new_user_message ( "start" . to_string ( ) ) ;
375- for i in 0 ..=100 {
439+ for i in 0 ..=( MAX_CONVERSATION_STATE_HISTORY_LEN + 100 ) {
376440 let s = conversation_state. as_sendable_conversation_state ( ) ;
377- assert ! (
378- s. history
379- . as_ref( )
380- . is_none_or( |h| h. first( ) . is_none_or( |m| matches!( m, ChatMessage :: UserInputMessage ( _) ) ) ) ,
381- "First message in the history must be from the user"
382- ) ;
383- assert ! (
384- s. history. as_ref( ) . is_none_or( |h| h
385- . last( )
386- . is_none_or( |m| matches!( m, ChatMessage :: AssistantResponseMessage ( _) ) ) ) ,
387- "Last message in the history must be from the assistant"
388- ) ;
441+ assert_conversation_state_invariants ( s, i) ;
389442 conversation_state. push_assistant_message ( AssistantResponseMessage {
390443 message_id : None ,
391444 content : i. to_string ( ) ,
392445 tool_uses : None ,
393446 } ) ;
394447 conversation_state. append_new_user_message ( i. to_string ( ) ) ;
395448 }
396-
397- let s = conversation_state. as_sendable_conversation_state ( ) ;
398- assert_eq ! (
399- s. history. as_ref( ) . unwrap( ) . len( ) ,
400- MAX_CONVERSATION_STATE_HISTORY_LEN ,
401- "history should be capped at {}" ,
402- MAX_CONVERSATION_STATE_HISTORY_LEN
403- ) ;
404- let first_msg = s. history . as_ref ( ) . unwrap ( ) . first ( ) . unwrap ( ) ;
405- match first_msg {
406- ChatMessage :: UserInputMessage ( _) => { } ,
407- other @ ChatMessage :: AssistantResponseMessage ( _) => {
408- panic ! ( "First message should be from the user, instead found {:?}" , other)
409- } ,
410- }
411- let last_msg = s. history . as_ref ( ) . unwrap ( ) . iter ( ) . last ( ) . unwrap ( ) ;
412- match last_msg {
413- ChatMessage :: AssistantResponseMessage ( assistant_response_message) => {
414- assert_eq ! ( assistant_response_message. content, "100" ) ;
415- } ,
416- other @ ChatMessage :: UserInputMessage ( _) => {
417- panic ! ( "Last message should be from the assistant, instead found {:?}" , other)
418- } ,
419- }
420449 }
421450
422451 #[ tokio:: test]
423452 async fn test_conversation_state_history_handling_with_tool_results ( ) {
424- let mut conversation_state = ConversationState :: new ( load_tools ( ) . unwrap ( ) ) ;
425-
426453 // Build a long conversation history of tool use results.
454+ let mut conversation_state = ConversationState :: new ( load_tools ( ) . unwrap ( ) ) ;
427455 conversation_state. append_new_user_message ( "start" . to_string ( ) ) ;
428456 for i in 0 ..=( MAX_CONVERSATION_STATE_HISTORY_LEN + 100 ) {
429- let _ = conversation_state. as_sendable_conversation_state ( ) ;
457+ let s = conversation_state. as_sendable_conversation_state ( ) ;
458+ assert_conversation_state_invariants ( s, i) ;
430459 conversation_state. push_assistant_message ( AssistantResponseMessage {
431460 message_id : None ,
432461 content : i. to_string ( ) ,
433- tool_uses : None ,
462+ tool_uses : Some ( vec ! [ ToolUse {
463+ tool_use_id: "tool_id" . to_string( ) ,
464+ name: "tool name" . to_string( ) ,
465+ input: aws_smithy_types:: Document :: Null ,
466+ } ] ) ,
434467 } ) ;
435468 conversation_state. add_tool_results ( vec ! [ ToolResult {
436469 tool_use_id: "tool_id" . to_string( ) ,
@@ -439,13 +472,35 @@ mod tests {
439472 } ] ) ;
440473 }
441474
442- let s = conversation_state. as_sendable_conversation_state ( ) ;
443- let actual_history_len = s. history . as_ref ( ) . unwrap ( ) . len ( ) ;
444- assert ! (
445- actual_history_len > MAX_CONVERSATION_STATE_HISTORY_LEN ,
446- "history should extend past the max limit of {}, instead found length {}" ,
447- MAX_CONVERSATION_STATE_HISTORY_LEN ,
448- actual_history_len
449- ) ;
475+ // Build a long conversation history of user messages mixed in with tool results.
476+ let mut conversation_state = ConversationState :: new ( load_tools ( ) . unwrap ( ) ) ;
477+ conversation_state. append_new_user_message ( "start" . to_string ( ) ) ;
478+ for i in 0 ..=( MAX_CONVERSATION_STATE_HISTORY_LEN + 100 ) {
479+ let s = conversation_state. as_sendable_conversation_state ( ) ;
480+ assert_conversation_state_invariants ( s, i) ;
481+ if i % 3 == 0 {
482+ conversation_state. push_assistant_message ( AssistantResponseMessage {
483+ message_id : None ,
484+ content : i. to_string ( ) ,
485+ tool_uses : Some ( vec ! [ ToolUse {
486+ tool_use_id: "tool_id" . to_string( ) ,
487+ name: "tool name" . to_string( ) ,
488+ input: aws_smithy_types:: Document :: Null ,
489+ } ] ) ,
490+ } ) ;
491+ conversation_state. add_tool_results ( vec ! [ ToolResult {
492+ tool_use_id: "tool_id" . to_string( ) ,
493+ content: vec![ ] ,
494+ status: ToolResultStatus :: Success ,
495+ } ] ) ;
496+ } else {
497+ conversation_state. push_assistant_message ( AssistantResponseMessage {
498+ message_id : None ,
499+ content : i. to_string ( ) ,
500+ tool_uses : None ,
501+ } ) ;
502+ conversation_state. append_new_user_message ( i. to_string ( ) ) ;
503+ }
504+ }
450505 }
451506}
0 commit comments