77
88use std:: collections:: { HashMap , VecDeque } ;
99
10- use genai:: chat:: { ChatMessage , ContentPart , MessageContent , ToolCall , ToolResponse } ;
10+ use genai:: chat:: {
11+ ChatMessage , ChatRole , ContentPart , MessageContent , MessageOptions , ToolCall , ToolResponse ,
12+ } ;
1113use serde:: { Deserialize , Serialize } ;
1214use ts_rs:: TS ;
1315
@@ -102,6 +104,10 @@ pub enum Event {
102104
103105 /// User explicitly cancelled the current operation.
104106 Cancel ,
107+
108+ /// Update the system prompt (replaces conversation[0]).
109+ /// Used when transitioning from initial generation to edit mode.
110+ UpdateSystemPrompt ( ChatMessage ) ,
105111}
106112
107113// ============================================================================
@@ -214,15 +220,46 @@ pub struct Agent {
214220
215221impl Agent {
216222 /// Create a new agent in Idle state.
217- pub fn new ( ) -> Self {
223+ pub fn new ( mut system_prompt : ChatMessage ) -> Self {
224+ system_prompt. options = Some ( MessageOptions {
225+ cache_control : Some ( genai:: chat:: CacheControl :: Ephemeral ) ,
226+ } ) ;
227+
228+ let context = Context {
229+ conversation : vec ! [ system_prompt] ,
230+ ..Context :: default ( )
231+ } ;
232+
218233 Agent {
219234 state : State :: Idle ,
220- context : Context :: default ( ) ,
235+ context,
221236 }
222237 }
223238
224239 /// Create an agent from saved state.
225- pub fn from_saved ( state : State , context : Context ) -> Self {
240+ pub fn from_saved ( state : State , mut context : Context , mut system_prompt : ChatMessage ) -> Self {
241+ system_prompt. options = Some ( MessageOptions {
242+ cache_control : Some ( genai:: chat:: CacheControl :: Ephemeral ) ,
243+ } ) ;
244+
245+ // Ensure the provided system_prompt is present as the first message in conversation.
246+ if let Some ( first) = context. conversation . get_mut ( 0 ) {
247+ let first_role = WrappedChatRole ( first. role . clone ( ) ) ;
248+
249+ if first_role == WrappedChatRole ( ChatRole :: System ) {
250+ * first = system_prompt. clone ( ) ;
251+ } else {
252+ // Role doesn't match, so prepend
253+ let mut new_conversation = vec ! [ system_prompt. clone( ) ] ;
254+ new_conversation. extend ( context. conversation . clone ( ) ) ;
255+ // Replace context.conversation with new one
256+ context. conversation = new_conversation;
257+ }
258+ } else {
259+ // conversation is empty, just insert
260+ context. conversation . push ( system_prompt. clone ( ) ) ;
261+ }
262+
226263 Agent { state, context }
227264 }
228265
@@ -434,11 +471,42 @@ impl Agent {
434471 Transition :: single ( Effect :: Cancelled )
435472 }
436473
474+ // ================================================================
475+ // Update system prompt
476+ // ================================================================
477+ ( _, Event :: UpdateSystemPrompt ( msg) ) => {
478+ // Replace conversation[0] with the new system prompt
479+ if !self . context . conversation . is_empty ( ) {
480+ self . context . conversation [ 0 ] = msg;
481+ }
482+ Transition :: none ( )
483+ }
484+
437485 ( _, _) => Transition :: none ( ) ,
438486 }
439487 }
440488}
441489
490+ struct WrappedChatRole ( genai:: chat:: ChatRole ) ;
491+
492+ impl PartialEq for WrappedChatRole {
493+ fn eq ( & self , other : & Self ) -> bool {
494+ use genai:: chat:: ChatRole ;
495+
496+ fn eq_chat_role ( a : & ChatRole , b : & ChatRole ) -> bool {
497+ matches ! (
498+ ( a, b) ,
499+ ( ChatRole :: System , ChatRole :: System )
500+ | ( ChatRole :: User , ChatRole :: User )
501+ | ( ChatRole :: Assistant , ChatRole :: Assistant )
502+ | ( ChatRole :: Tool , ChatRole :: Tool )
503+ )
504+ }
505+
506+ eq_chat_role ( & self . 0 , & other. 0 )
507+ }
508+ }
509+
442510// ============================================================================
443511// Tests
444512// ============================================================================
@@ -463,19 +531,19 @@ mod tests {
463531
464532 #[ test]
465533 fn idle_to_sending_on_user_message ( ) {
466- let mut agent = Agent :: new ( ) ;
534+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
467535 assert_eq ! ( agent. state( ) , & State :: Idle ) ;
468536
469537 let t = agent. handle ( Event :: UserMessage ( user_msg ( "hello" ) ) ) ;
470538
471539 assert_eq ! ( agent. state( ) , & State :: Sending ) ;
472540 assert_eq ! ( t. effects, vec![ Effect :: StartRequest ] ) ;
473- assert_eq ! ( agent. context( ) . conversation. len( ) , 1 ) ;
541+ assert_eq ! ( agent. context( ) . conversation. len( ) , 2 ) ; // system + user
474542 }
475543
476544 #[ test]
477545 fn messages_queued_during_sending ( ) {
478- let mut agent = Agent :: new ( ) ;
546+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
479547 agent. handle ( Event :: UserMessage ( user_msg ( "first" ) ) ) ;
480548
481549 let t = agent. handle ( Event :: UserMessage ( user_msg ( "second" ) ) ) ;
@@ -487,7 +555,7 @@ mod tests {
487555
488556 #[ test]
489557 fn sending_to_streaming_on_stream_start ( ) {
490- let mut agent = Agent :: new ( ) ;
558+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
491559 agent. handle ( Event :: UserMessage ( user_msg ( "hello" ) ) ) ;
492560
493561 let t = agent. handle ( Event :: StreamStart ) ;
@@ -498,7 +566,7 @@ mod tests {
498566
499567 #[ test]
500568 fn streaming_emits_chunks ( ) {
501- let mut agent = Agent :: new ( ) ;
569+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
502570 agent. handle ( Event :: UserMessage ( user_msg ( "hello" ) ) ) ;
503571 agent. handle ( Event :: StreamStart ) ;
504572
@@ -518,7 +586,7 @@ mod tests {
518586
519587 #[ test]
520588 fn stream_end_no_tools_goes_idle ( ) {
521- let mut agent = Agent :: new ( ) ;
589+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
522590 agent. handle ( Event :: UserMessage ( user_msg ( "hello" ) ) ) ;
523591 agent. handle ( Event :: StreamStart ) ;
524592 agent. handle ( Event :: StreamChunk ( StreamChunk {
@@ -529,12 +597,12 @@ mod tests {
529597
530598 assert_eq ! ( agent. state( ) , & State :: Idle ) ;
531599 assert_eq ! ( t. effects, vec![ Effect :: ResponseComplete ] ) ;
532- assert_eq ! ( agent. context( ) . conversation. len( ) , 2 ) ; // user + assistant
600+ assert_eq ! ( agent. context( ) . conversation. len( ) , 3 ) ; // system + user + assistant
533601 }
534602
535603 #[ test]
536604 fn stream_end_with_queued_message_goes_sending ( ) {
537- let mut agent = Agent :: new ( ) ;
605+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
538606 agent. handle ( Event :: UserMessage ( user_msg ( "first" ) ) ) ;
539607 agent. handle ( Event :: UserMessage ( user_msg ( "second" ) ) ) ; // queued
540608 agent. handle ( Event :: StreamStart ) ;
@@ -550,12 +618,12 @@ mod tests {
550618 vec![ Effect :: ResponseComplete , Effect :: StartRequest ]
551619 ) ;
552620 assert ! ( agent. context( ) . queued_messages. is_empty( ) ) ;
553- assert_eq ! ( agent. context( ) . conversation. len( ) , 3 ) ; // first + response + second
621+ assert_eq ! ( agent. context( ) . conversation. len( ) , 4 ) ; // system + first + response + second
554622 }
555623
556624 #[ test]
557625 fn stream_end_with_tools_goes_to_pending_tools ( ) {
558- let mut agent = Agent :: new ( ) ;
626+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
559627 agent. handle ( Event :: UserMessage ( user_msg ( "what time is it?" ) ) ) ;
560628 agent. handle ( Event :: StreamStart ) ;
561629 agent. handle ( Event :: StreamChunk ( StreamChunk {
@@ -580,7 +648,7 @@ mod tests {
580648
581649 #[ test]
582650 fn tool_result_completes_and_starts_request ( ) {
583- let mut agent = Agent :: new ( ) ;
651+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
584652 agent. handle ( Event :: UserMessage ( user_msg ( "what time is it?" ) ) ) ;
585653 agent. handle ( Event :: StreamStart ) ;
586654 agent. handle ( Event :: StreamEnd {
@@ -600,13 +668,13 @@ mod tests {
600668 assert ! ( agent. context( ) . pending_tools. is_empty( ) ) ;
601669 // Tool results are now pushed to conversation as ToolResponse messages
602670 assert ! ( agent. context( ) . tool_results. is_empty( ) ) ;
603- // Conversation: user msg, assistant msg (with tool call), tool response
604- assert_eq ! ( agent. context( ) . conversation. len( ) , 3 ) ;
671+ // Conversation: system, user msg, assistant msg (with tool call), tool response
672+ assert_eq ! ( agent. context( ) . conversation. len( ) , 4 ) ;
605673 }
606674
607675 #[ test]
608676 fn multiple_tools_waits_for_all ( ) {
609- let mut agent = Agent :: new ( ) ;
677+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
610678 agent. handle ( Event :: UserMessage ( user_msg ( "query" ) ) ) ;
611679 agent. handle ( Event :: StreamStart ) ;
612680 agent. handle ( Event :: StreamEnd {
@@ -640,7 +708,7 @@ mod tests {
640708
641709 #[ test]
642710 fn cancel_from_streaming ( ) {
643- let mut agent = Agent :: new ( ) ;
711+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
644712 agent. handle ( Event :: UserMessage ( user_msg ( "hello" ) ) ) ;
645713 agent. handle ( Event :: StreamStart ) ;
646714
@@ -652,7 +720,7 @@ mod tests {
652720
653721 #[ test]
654722 fn request_failed_goes_idle ( ) {
655- let mut agent = Agent :: new ( ) ;
723+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
656724 agent. handle ( Event :: UserMessage ( user_msg ( "hello" ) ) ) ;
657725
658726 let t = agent. handle ( Event :: RequestFailed {
@@ -670,7 +738,7 @@ mod tests {
670738
671739 #[ test]
672740 fn assistant_message_contains_tool_calls ( ) {
673- let mut agent = Agent :: new ( ) ;
741+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
674742 agent. handle ( Event :: UserMessage ( user_msg ( "what time is it?" ) ) ) ;
675743 agent. handle ( Event :: StreamStart ) ;
676744 agent. handle ( Event :: StreamChunk ( StreamChunk {
@@ -683,7 +751,8 @@ mod tests {
683751 } ) ;
684752
685753 // Check the assistant message in conversation has both text and tool call
686- let assistant_msg = & agent. context ( ) . conversation [ 1 ] ;
754+ // conversation[0] = system, [1] = user, [2] = assistant
755+ let assistant_msg = & agent. context ( ) . conversation [ 2 ] ;
687756 assert ! ( matches!( assistant_msg. role, ChatRole :: Assistant ) ) ;
688757 let parts = assistant_msg. content . clone ( ) . into_parts ( ) ;
689758 assert_eq ! ( parts. len( ) , 2 ) ;
@@ -695,7 +764,7 @@ mod tests {
695764
696765 #[ test]
697766 fn queued_messages_stay_queued_until_tools_complete ( ) {
698- let mut agent = Agent :: new ( ) ;
767+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
699768 agent. handle ( Event :: UserMessage ( user_msg ( "first" ) ) ) ;
700769 agent. handle ( Event :: UserMessage ( user_msg ( "second" ) ) ) ; // queued
701770 agent. handle ( Event :: StreamStart ) ;
@@ -711,8 +780,8 @@ mod tests {
711780 // Messages stay queued until tools complete
712781 assert_eq ! ( agent. state( ) , & State :: PendingTools ) ;
713782 assert_eq ! ( agent. context( ) . queued_messages. len( ) , 1 ) ;
714- // Conversation: first msg, assistant msg
715- assert_eq ! ( agent. context( ) . conversation. len( ) , 2 ) ;
783+ // Conversation: system, first msg, assistant msg
784+ assert_eq ! ( agent. context( ) . conversation. len( ) , 3 ) ;
716785
717786 // Tool completes - now queued messages are drained
718787 agent. handle ( Event :: ToolResult ( ToolResult {
@@ -722,16 +791,16 @@ mod tests {
722791
723792 assert_eq ! ( agent. state( ) , & State :: Sending ) ;
724793 assert ! ( agent. context( ) . queued_messages. is_empty( ) ) ;
725- // Conversation: first msg, assistant msg, tool response, second msg
726- assert_eq ! ( agent. context( ) . conversation. len( ) , 4 ) ;
727- // Verify the fourth message is the queued "second"
728- let fourth_msg = & agent. context ( ) . conversation [ 3 ] ;
729- assert ! ( matches!( fourth_msg . role, ChatRole :: User ) ) ;
794+ // Conversation: system, first msg, assistant msg, tool response, second msg
795+ assert_eq ! ( agent. context( ) . conversation. len( ) , 5 ) ;
796+ // Verify the fifth message is the queued "second"
797+ let fifth_msg = & agent. context ( ) . conversation [ 4 ] ;
798+ assert ! ( matches!( fifth_msg . role, ChatRole :: User ) ) ;
730799 }
731800
732801 #[ test]
733802 fn user_message_during_pending_tools_gets_queued ( ) {
734- let mut agent = Agent :: new ( ) ;
803+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
735804 agent. handle ( Event :: UserMessage ( user_msg ( "query" ) ) ) ;
736805 agent. handle ( Event :: StreamStart ) ;
737806 agent. handle ( Event :: StreamEnd {
@@ -761,13 +830,13 @@ mod tests {
761830 vec![ Effect :: ToolResultReceived , Effect :: StartRequest ]
762831 ) ;
763832 assert ! ( agent. context( ) . queued_messages. is_empty( ) ) ;
764- // Conversation: user msg, assistant msg, tool response, queued msg
765- assert_eq ! ( agent. context( ) . conversation. len( ) , 4 ) ;
833+ // Conversation: system, user msg, assistant msg, tool response, queued msg
834+ assert_eq ! ( agent. context( ) . conversation. len( ) , 5 ) ;
766835 }
767836
768837 #[ test]
769838 fn cancel_pending_tools_generates_cancelled_responses ( ) {
770- let mut agent = Agent :: new ( ) ;
839+ let mut agent = Agent :: new ( ChatMessage :: system ( "you are a helpful assistant" ) ) ;
771840 agent. handle ( Event :: UserMessage ( user_msg ( "query" ) ) ) ;
772841 agent. handle ( Event :: StreamStart ) ;
773842 agent. handle ( Event :: StreamEnd {
@@ -787,7 +856,7 @@ mod tests {
787856 assert_eq ! ( t. effects, vec![ Effect :: Cancelled ] ) ;
788857 assert ! ( agent. context( ) . pending_tools. is_empty( ) ) ;
789858 // Tool results were pushed to conversation as error responses
790- // Conversation: user msg, assistant msg, tool response, tool response
791- assert_eq ! ( agent. context( ) . conversation. len( ) , 4 ) ;
859+ // Conversation: system, user msg, assistant msg, tool response, tool response
860+ assert_eq ! ( agent. context( ) . conversation. len( ) , 5 ) ;
792861 }
793862}
0 commit comments