diff --git a/codex-rs/app-server-test-client/src/main.rs b/codex-rs/app-server-test-client/src/main.rs index b66c59d55a..166acb3cc4 100644 --- a/codex-rs/app-server-test-client/src/main.rs +++ b/codex-rs/app-server-test-client/src/main.rs @@ -43,8 +43,14 @@ use codex_app_server_protocol::SendUserMessageParams; use codex_app_server_protocol::SendUserMessageResponse; use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequest; +use codex_app_server_protocol::ThreadItem; +use codex_app_server_protocol::ThreadResumeParams; +use codex_app_server_protocol::ThreadResumeResponse; +use codex_app_server_protocol::ThreadRollbackParams; +use codex_app_server_protocol::ThreadRollbackResponse; use codex_app_server_protocol::ThreadStartParams; use codex_app_server_protocol::ThreadStartResponse; +use codex_app_server_protocol::Turn; use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; use codex_app_server_protocol::TurnStatus; @@ -113,6 +119,9 @@ enum CliCommand { TestLogin, /// Fetch the current account rate limits from the Codex app-server. GetAccountRateLimits, + /// Send multiple turns, roll back the most recent turn, and verify the thread history changed. + #[command(name = "thread-rollback")] + ThreadRollback, } fn main() -> Result<()> { @@ -134,6 +143,7 @@ fn main() -> Result<()> { } => send_follow_up_v2(codex_bin, first_message, follow_up_message), CliCommand::TestLogin => test_login(codex_bin), CliCommand::GetAccountRateLimits => get_account_rate_limits(codex_bin), + CliCommand::ThreadRollback => thread_rollback(codex_bin), } } @@ -213,10 +223,7 @@ fn send_message_v2_with_policies( turn_params.approval_policy = approval_policy; turn_params.sandbox_policy = sandbox_policy; - let turn_response = client.turn_start(turn_params)?; - println!("< turn/start response: {turn_response:?}"); - - client.stream_turn(&thread_response.thread.id, &turn_response.turn.id)?; + let _ = client.run_turn(turn_params)?; Ok(()) } @@ -234,27 +241,8 @@ fn send_follow_up_v2( let thread_response = client.thread_start(ThreadStartParams::default())?; println!("< thread/start response: {thread_response:?}"); - let first_turn_params = TurnStartParams { - thread_id: thread_response.thread.id.clone(), - input: vec![V2UserInput::Text { - text: first_message, - }], - ..Default::default() - }; - let first_turn_response = client.turn_start(first_turn_params)?; - println!("< turn/start response (initial): {first_turn_response:?}"); - client.stream_turn(&thread_response.thread.id, &first_turn_response.turn.id)?; - - let follow_up_params = TurnStartParams { - thread_id: thread_response.thread.id.clone(), - input: vec![V2UserInput::Text { - text: follow_up_message, - }], - ..Default::default() - }; - let follow_up_response = client.turn_start(follow_up_params)?; - println!("< turn/start response (follow-up): {follow_up_response:?}"); - client.stream_turn(&thread_response.thread.id, &follow_up_response.turn.id)?; + let _ = client.run_turn_text(&thread_response.thread.id, first_message)?; + let _ = client.run_turn_text(&thread_response.thread.id, follow_up_message)?; Ok(()) } @@ -301,6 +289,143 @@ fn get_account_rate_limits(codex_bin: String) -> Result<()> { Ok(()) } +fn thread_rollback(codex_bin: String) -> Result<()> { + let codex_bin_resume = codex_bin.clone(); + let mut client = CodexClient::spawn(codex_bin)?; + + let initialize = client.initialize()?; + println!("< initialize response: {initialize:?}"); + + let thread_response = client.thread_start(ThreadStartParams::default())?; + println!("< thread/start response: {thread_response:?}"); + let thread_id = thread_response.thread.id; + + let _ = client.run_turn_text(&thread_id, "Say pineapple")?; + let _ = client.run_turn_text(&thread_id, "Say banana")?; + + let rollback_response = client.thread_rollback(ThreadRollbackParams { + thread_id: thread_id.clone(), + num_turns: 1, + })?; + println!("< thread/rollback response: {rollback_response:?}"); + + let answer = client + .run_turn_text(&thread_id, "What was the last word you said?")? + .context("turn completed without an agent message item")?; + + if answer.to_lowercase().contains("pineapple") { + println!("Rollback success!"); + } else { + println!("Rollback did not work as expected!"); + } + + let mut resume_client = CodexClient::spawn(codex_bin_resume)?; + let initialize = resume_client.initialize()?; + println!("< initialize response (resume client): {initialize:?}"); + + let resume_response = resume_client.thread_resume(ThreadResumeParams { + thread_id: thread_id.clone(), + ..Default::default() + })?; + println!("< thread/resume response: {resume_response:?}"); + + verify_resumed_thread_after_rollback( + &resume_response, + "Say pineapple", + "Say banana", + "What was the last word you said?", + "pineapple", + )?; + println!("Resume verification success!"); + + Ok(()) +} + +fn verify_resumed_thread_after_rollback( + resume: &ThreadResumeResponse, + expected_first_prompt: &str, + rolled_back_prompt: &str, + expected_follow_up_prompt: &str, + expected_word: &str, +) -> Result<()> { + let mut saw_expected_first_turn = false; + let mut saw_expected_follow_up_turn = false; + + for turn in &resume.thread.turns { + let user_messages = turn_user_messages(turn); + let agent_messages = turn_agent_messages(turn); + + for user_message in &user_messages { + if user_message.contains(rolled_back_prompt) { + bail!( + "thread/resume returned a rolled back prompt: {rolled_back_prompt:?} (thread {})", + resume.thread.id + ); + } + } + + if user_messages + .iter() + .any(|message| message.contains(expected_first_prompt)) + && agent_messages + .iter() + .any(|message| message.to_lowercase().contains(expected_word)) + { + saw_expected_first_turn = true; + } + + if user_messages + .iter() + .any(|message| message.contains(expected_follow_up_prompt)) + && agent_messages + .iter() + .any(|message| message.to_lowercase().contains(expected_word)) + { + saw_expected_follow_up_turn = true; + } + } + + if !saw_expected_first_turn { + bail!( + "thread/resume did not include expected prompt {expected_first_prompt:?} with answer containing {expected_word:?}" + ); + } + + if !saw_expected_follow_up_turn { + bail!( + "thread/resume did not include expected prompt {expected_follow_up_prompt:?} with answer containing {expected_word:?}" + ); + } + + Ok(()) +} + +fn turn_user_messages(turn: &Turn) -> Vec { + turn.items + .iter() + .filter_map(|item| match item { + ThreadItem::UserMessage { content, .. } => Some(content), + _ => None, + }) + .flat_map(|content| { + content.iter().filter_map(|input| match input { + V2UserInput::Text { text } => Some(text.clone()), + _ => None, + }) + }) + .collect() +} + +fn turn_agent_messages(turn: &Turn) -> Vec { + turn.items + .iter() + .filter_map(|item| match item { + ThreadItem::AgentMessage { text, .. } => Some(text.clone()), + _ => None, + }) + .collect() +} + struct CodexClient { child: Child, stdin: Option, @@ -422,6 +547,16 @@ impl CodexClient { self.send_request(request, request_id, "thread/start") } + fn thread_resume(&mut self, params: ThreadResumeParams) -> Result { + let request_id = self.request_id(); + let request = ClientRequest::ThreadResume { + request_id: request_id.clone(), + params, + }; + + self.send_request(request, request_id, "thread/resume") + } + fn turn_start(&mut self, params: TurnStartParams) -> Result { let request_id = self.request_id(); let request = ClientRequest::TurnStart { @@ -432,6 +567,39 @@ impl CodexClient { self.send_request(request, request_id, "turn/start") } + fn run_turn(&mut self, params: TurnStartParams) -> Result> { + let thread_id = params.thread_id.clone(); + let turn_response = self.turn_start(params)?; + println!("< turn/start response: {turn_response:?}"); + self.stream_turn(&thread_id, &turn_response.turn.id) + } + + fn run_turn_text( + &mut self, + thread_id: &str, + user_message: impl Into, + ) -> Result> { + let turn_params = TurnStartParams { + thread_id: thread_id.to_string(), + input: vec![V2UserInput::Text { + text: user_message.into(), + }], + ..Default::default() + }; + + self.run_turn(turn_params) + } + + fn thread_rollback(&mut self, params: ThreadRollbackParams) -> Result { + let request_id = self.request_id(); + let request = ClientRequest::ThreadRollback { + request_id: request_id.clone(), + params, + }; + + self.send_request(request, request_id, "thread/rollback") + } + fn login_chat_gpt(&mut self) -> Result { let request_id = self.request_id(); let request = ClientRequest::LoginChatGpt { @@ -526,7 +694,9 @@ impl CodexClient { } } - fn stream_turn(&mut self, thread_id: &str, turn_id: &str) -> Result<()> { + fn stream_turn(&mut self, thread_id: &str, turn_id: &str) -> Result> { + let mut last_agent_message = None::; + loop { let notification = self.next_notification()?; @@ -561,7 +731,16 @@ impl CodexClient { println!("\n< item started: {:?}", payload.item); } ServerNotification::ItemCompleted(payload) => { - println!("< item completed: {:?}", payload.item); + if payload.thread_id == thread_id && payload.turn_id == turn_id { + if let ThreadItem::AgentMessage { text, .. } = payload.item { + last_agent_message = Some(text); + println!("< agent message completed >"); + } else { + println!("< item completed: {:?}", payload.item); + } + } else { + println!("< item completed: {:?}", payload.item); + } } ServerNotification::TurnCompleted(payload) => { if payload.turn.id == turn_id { @@ -583,7 +762,7 @@ impl CodexClient { } } - Ok(()) + Ok(last_agent_message) } fn extract_event(