Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 207 additions & 28 deletions codex-rs/app-server-test-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,14 @@ use codex_app_server_protocol::SendUserMessageParams;
use codex_app_server_protocol::SendUserMessageResponse;
use codex_app_server_protocol::ServerNotification;
use codex_app_server_protocol::ServerRequest;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadResumeResponse;
use codex_app_server_protocol::ThreadRollbackParams;
use codex_app_server_protocol::ThreadRollbackResponse;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::Turn;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::TurnStatus;
Expand Down Expand Up @@ -113,6 +119,9 @@ enum CliCommand {
TestLogin,
/// Fetch the current account rate limits from the Codex app-server.
GetAccountRateLimits,
/// Send multiple turns, roll back the most recent turn, and verify the thread history changed.
#[command(name = "thread-rollback")]
ThreadRollback,
}

fn main() -> Result<()> {
Expand All @@ -134,6 +143,7 @@ fn main() -> Result<()> {
} => send_follow_up_v2(codex_bin, first_message, follow_up_message),
CliCommand::TestLogin => test_login(codex_bin),
CliCommand::GetAccountRateLimits => get_account_rate_limits(codex_bin),
CliCommand::ThreadRollback => thread_rollback(codex_bin),
}
}

Expand Down Expand Up @@ -213,10 +223,7 @@ fn send_message_v2_with_policies(
turn_params.approval_policy = approval_policy;
turn_params.sandbox_policy = sandbox_policy;

let turn_response = client.turn_start(turn_params)?;
println!("< turn/start response: {turn_response:?}");

client.stream_turn(&thread_response.thread.id, &turn_response.turn.id)?;
let _ = client.run_turn(turn_params)?;

Ok(())
}
Expand All @@ -234,27 +241,8 @@ fn send_follow_up_v2(
let thread_response = client.thread_start(ThreadStartParams::default())?;
println!("< thread/start response: {thread_response:?}");

let first_turn_params = TurnStartParams {
thread_id: thread_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: first_message,
}],
..Default::default()
};
let first_turn_response = client.turn_start(first_turn_params)?;
println!("< turn/start response (initial): {first_turn_response:?}");
client.stream_turn(&thread_response.thread.id, &first_turn_response.turn.id)?;

let follow_up_params = TurnStartParams {
thread_id: thread_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: follow_up_message,
}],
..Default::default()
};
let follow_up_response = client.turn_start(follow_up_params)?;
println!("< turn/start response (follow-up): {follow_up_response:?}");
client.stream_turn(&thread_response.thread.id, &follow_up_response.turn.id)?;
let _ = client.run_turn_text(&thread_response.thread.id, first_message)?;
let _ = client.run_turn_text(&thread_response.thread.id, follow_up_message)?;

Ok(())
}
Expand Down Expand Up @@ -301,6 +289,143 @@ fn get_account_rate_limits(codex_bin: String) -> Result<()> {
Ok(())
}

/// Drives an end-to-end check of `thread/rollback` against a spawned app-server.
///
/// Steps, in order:
/// 1. Spawn a client, initialize it, and start a fresh thread.
/// 2. Run two turns ("Say pineapple", then "Say banana").
/// 3. Roll back one turn and ask the model what it said last; a success or
///    failure line is printed depending on whether the answer mentions
///    "pineapple". A failed live check is only reported on stdout; it does
///    not abort the function.
/// 4. Spawn a second client, resume the same thread, and validate the
///    returned history with `verify_resumed_thread_after_rollback`.
///
/// Returns an error if any request fails, if the probing turn yields no agent
/// message, or if the resumed history fails verification.
fn thread_rollback(codex_bin: String) -> Result<()> {
    // The prompts are both sent to the server and later looked up in the
    // resumed history, so they are named once here.
    const FIRST_PROMPT: &str = "Say pineapple";
    const ROLLED_BACK_PROMPT: &str = "Say banana";
    const PROBE_PROMPT: &str = "What was the last word you said?";
    const EXPECTED_WORD: &str = "pineapple";

    // The binary path is needed again for the resume client further down.
    let mut client = CodexClient::spawn(codex_bin.clone())?;

    let initialize = client.initialize()?;
    println!("< initialize response: {initialize:?}");

    let thread_response = client.thread_start(ThreadStartParams::default())?;
    println!("< thread/start response: {thread_response:?}");
    let thread_id = thread_response.thread.id;

    let _ = client.run_turn_text(&thread_id, FIRST_PROMPT)?;
    let _ = client.run_turn_text(&thread_id, ROLLED_BACK_PROMPT)?;

    // Drop the most recent turn (the "banana" one).
    let rollback_params = ThreadRollbackParams {
        thread_id: thread_id.clone(),
        num_turns: 1,
    };
    let rollback_response = client.thread_rollback(rollback_params)?;
    println!("< thread/rollback response: {rollback_response:?}");

    let answer = client
        .run_turn_text(&thread_id, PROBE_PROMPT)?
        .context("turn completed without an agent message item")?;

    let verdict = if answer.to_lowercase().contains(EXPECTED_WORD) {
        "Rollback success!"
    } else {
        "Rollback did not work as expected!"
    };
    println!("{verdict}");

    // A separate client resumes the thread so the history comes from
    // `thread/resume` rather than from the session that produced it.
    let mut resume_client = CodexClient::spawn(codex_bin)?;
    let initialize = resume_client.initialize()?;
    println!("< initialize response (resume client): {initialize:?}");

    let resume_params = ThreadResumeParams {
        thread_id,
        ..Default::default()
    };
    let resume_response = resume_client.thread_resume(resume_params)?;
    println!("< thread/resume response: {resume_response:?}");

    verify_resumed_thread_after_rollback(
        &resume_response,
        FIRST_PROMPT,
        ROLLED_BACK_PROMPT,
        PROBE_PROMPT,
        EXPECTED_WORD,
    )?;
    println!("Resume verification success!");

    Ok(())
}

/// Checks that a resumed thread's history reflects a rollback.
///
/// Walks every turn in `resume.thread.turns` and enforces three conditions:
///
/// * no user message contains `rolled_back_prompt`;
/// * some turn has a user message containing `expected_first_prompt` and an
///   agent message containing `expected_word`;
/// * some turn has a user message containing `expected_follow_up_prompt` and an
///   agent message containing `expected_word`.
///
/// Prompt matching is a case-sensitive substring test. The `expected_word`
/// match is case-insensitive: both the agent text and `expected_word` are
/// lowercased before comparison, so callers may pass the word in any case.
///
/// # Errors
///
/// Returns an error naming the offending prompt when the rolled-back prompt
/// is present, or when either expected prompt/answer pair is missing.
fn verify_resumed_thread_after_rollback(
    resume: &ThreadResumeResponse,
    expected_first_prompt: &str,
    rolled_back_prompt: &str,
    expected_follow_up_prompt: &str,
    expected_word: &str,
) -> Result<()> {
    // Agent text is lowercased before matching, so the needle must be
    // lowercased as well or a mixed-case `expected_word` could never match.
    let expected_word_lower = expected_word.to_lowercase();

    let mut saw_expected_first_turn = false;
    let mut saw_expected_follow_up_turn = false;

    for turn in &resume.thread.turns {
        let user_messages = turn_user_messages(turn);

        // The rolled-back prompt must not appear anywhere in the history.
        if user_messages
            .iter()
            .any(|message| message.contains(rolled_back_prompt))
        {
            bail!(
                "thread/resume returned a rolled back prompt: {rolled_back_prompt:?} (thread {})",
                resume.thread.id
            );
        }

        // Both expectations need the same agent-side condition; evaluate it
        // once per turn instead of once per expectation.
        let answered_with_expected_word = turn_agent_messages(turn)
            .iter()
            .any(|message| message.to_lowercase().contains(expected_word_lower.as_str()));
        if !answered_with_expected_word {
            continue;
        }

        if user_messages
            .iter()
            .any(|message| message.contains(expected_first_prompt))
        {
            saw_expected_first_turn = true;
        }

        if user_messages
            .iter()
            .any(|message| message.contains(expected_follow_up_prompt))
        {
            saw_expected_follow_up_turn = true;
        }
    }

    if !saw_expected_first_turn {
        bail!(
            "thread/resume did not include expected prompt {expected_first_prompt:?} with answer containing {expected_word:?}"
        );
    }

    if !saw_expected_follow_up_turn {
        bail!(
            "thread/resume did not include expected prompt {expected_follow_up_prompt:?} with answer containing {expected_word:?}"
        );
    }

    Ok(())
}

/// Collects the text of every `V2UserInput::Text` entry found inside the
/// `ThreadItem::UserMessage` items of `turn`, in item order.
///
/// Items of other kinds, and non-text inputs within a user message, are
/// ignored. Returns an empty vector when the turn has no textual user input.
fn turn_user_messages(turn: &Turn) -> Vec<String> {
    let mut texts = Vec::new();

    for item in turn.items.iter() {
        if let ThreadItem::UserMessage { content, .. } = item {
            for input in content.iter() {
                if let V2UserInput::Text { text } = input {
                    texts.push(text.clone());
                }
            }
        }
    }

    texts
}

/// Collects the text of every `ThreadItem::AgentMessage` in `turn`, in item
/// order. Items of any other kind are ignored.
fn turn_agent_messages(turn: &Turn) -> Vec<String> {
    let mut replies = Vec::new();

    for item in turn.items.iter() {
        if let ThreadItem::AgentMessage { text, .. } = item {
            replies.push(text.clone());
        }
    }

    replies
}

struct CodexClient {
child: Child,
stdin: Option<ChildStdin>,
Expand Down Expand Up @@ -422,6 +547,16 @@ impl CodexClient {
self.send_request(request, request_id, "thread/start")
}

/// Issues a `ClientRequest::ThreadResume` carrying `params` through
/// `send_request`, labelled "thread/resume", and returns its result.
fn thread_resume(&mut self, params: ThreadResumeParams) -> Result<ThreadResumeResponse> {
    let request_id = self.request_id();

    // The id is needed twice: embedded in the request and passed alongside it.
    self.send_request(
        ClientRequest::ThreadResume {
            request_id: request_id.clone(),
            params,
        },
        request_id,
        "thread/resume",
    )
}

fn turn_start(&mut self, params: TurnStartParams) -> Result<TurnStartResponse> {
let request_id = self.request_id();
let request = ClientRequest::TurnStart {
Expand All @@ -432,6 +567,39 @@ impl CodexClient {
self.send_request(request, request_id, "turn/start")
}

/// Starts a turn with `params`, prints the `turn/start` response, and then
/// hands the turn to `stream_turn`, returning whatever `stream_turn` returns.
///
/// Any error from `turn_start` or `stream_turn` is propagated.
fn run_turn(&mut self, params: TurnStartParams) -> Result<Option<String>> {
    // `params` is moved into `turn_start`, so keep a copy of the thread id
    // for the streaming call that follows.
    let owning_thread = params.thread_id.clone();

    let turn_response = self.turn_start(params)?;
    println!("< turn/start response: {turn_response:?}");

    let started_turn_id = &turn_response.turn.id;
    self.stream_turn(&owning_thread, started_turn_id)
}

/// Convenience wrapper around `run_turn` for a single plain-text message.
///
/// Builds `TurnStartParams` for `thread_id` whose only input is
/// `user_message` as a `V2UserInput::Text`, leaving every other field at its
/// default, and returns the result of `run_turn`.
fn run_turn_text(
    &mut self,
    thread_id: &str,
    user_message: impl Into<String>,
) -> Result<Option<String>> {
    let owned_thread_id = thread_id.to_string();
    let text = user_message.into();
    let input = vec![V2UserInput::Text { text }];

    self.run_turn(TurnStartParams {
        thread_id: owned_thread_id,
        input,
        ..Default::default()
    })
}

/// Issues a `ClientRequest::ThreadRollback` carrying `params` through
/// `send_request`, labelled "thread/rollback", and returns its result.
fn thread_rollback(&mut self, params: ThreadRollbackParams) -> Result<ThreadRollbackResponse> {
    let request_id = self.request_id();

    // The id is needed twice: embedded in the request and passed alongside it.
    self.send_request(
        ClientRequest::ThreadRollback {
            request_id: request_id.clone(),
            params,
        },
        request_id,
        "thread/rollback",
    )
}

fn login_chat_gpt(&mut self) -> Result<LoginChatGptResponse> {
let request_id = self.request_id();
let request = ClientRequest::LoginChatGpt {
Expand Down Expand Up @@ -526,7 +694,9 @@ impl CodexClient {
}
}

fn stream_turn(&mut self, thread_id: &str, turn_id: &str) -> Result<()> {
fn stream_turn(&mut self, thread_id: &str, turn_id: &str) -> Result<Option<String>> {
let mut last_agent_message = None::<String>;

loop {
let notification = self.next_notification()?;

Expand Down Expand Up @@ -561,7 +731,16 @@ impl CodexClient {
println!("\n< item started: {:?}", payload.item);
}
ServerNotification::ItemCompleted(payload) => {
println!("< item completed: {:?}", payload.item);
if payload.thread_id == thread_id && payload.turn_id == turn_id {
if let ThreadItem::AgentMessage { text, .. } = payload.item {
last_agent_message = Some(text);
println!("< agent message completed >");
} else {
println!("< item completed: {:?}", payload.item);
}
} else {
println!("< item completed: {:?}", payload.item);
}
}
ServerNotification::TurnCompleted(payload) => {
if payload.turn.id == turn_id {
Expand All @@ -583,7 +762,7 @@ impl CodexClient {
}
}

Ok(())
Ok(last_agent_message)
}

fn extract_event(
Expand Down
Loading