diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index a04b05cb38..c18faae30b 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2382,6 +2382,9 @@ use crate::features::Features; #[cfg(test)] pub(crate) use tests::make_session_and_context; +#[cfg(test)] +pub(crate) use tests::make_session_and_context_with_rx; + #[cfg(test)] mod tests { use super::*; @@ -2658,7 +2661,7 @@ mod tests { // Like make_session_and_context, but returns Arc and the event receiver // so tests can assert on emitted events. - fn make_session_and_context_with_rx() -> ( + pub(crate) fn make_session_and_context_with_rx() -> ( Arc, Arc, async_channel::Receiver, diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 4cb4d4a06a..796331d1e8 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -13,6 +13,8 @@ use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::protocol::Submission; use codex_protocol::user_input::UserInput; +use std::time::Duration; +use tokio::time::timeout; use tokio_util::sync::CancellationToken; use crate::AuthManager; @@ -60,14 +62,13 @@ pub(crate) async fn run_codex_conversation_interactive( let parent_ctx_clone = Arc::clone(&parent_ctx); let codex_for_events = Arc::clone(&codex); tokio::spawn(async move { - let _ = forward_events( + forward_events( codex_for_events, tx_sub, parent_session_clone, parent_ctx_clone, - cancel_token_events.clone(), + cancel_token_events, ) - .or_cancel(&cancel_token_events) .await; }); @@ -156,53 +157,92 @@ async fn forward_events( parent_ctx: Arc, cancel_token: CancellationToken, ) { - while let Ok(event) = codex.next_event().await { - match event { - // ignore all legacy delta events - Event { - id: _, - msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_), - } => continue, - Event { - id: _, - msg: EventMsg::SessionConfigured(_), - } => continue, - Event { - id, - msg: EventMsg::ExecApprovalRequest(event), - } => { - // Initiate approval via parent session; do not surface to consumer. - handle_exec_approval( - &codex, - id, - &parent_session, - &parent_ctx, - event, - &cancel_token, - ) - .await; - } - Event { - id, - msg: EventMsg::ApplyPatchApprovalRequest(event), - } => { - handle_patch_approval( - &codex, - id, - &parent_session, - &parent_ctx, - event, - &cancel_token, - ) - .await; + let cancelled = cancel_token.cancelled(); + tokio::pin!(cancelled); + + loop { + tokio::select! { + _ = &mut cancelled => { + shutdown_delegate(&codex).await; + break; } - other => { - let _ = tx_sub.send(other).await; + event = codex.next_event() => { + let event = match event { + Ok(event) => event, + Err(_) => break, + }; + match event { + // ignore all legacy delta events + Event { + id: _, + msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_), + } => {} + Event { + id: _, + msg: EventMsg::SessionConfigured(_), + } => {} + Event { + id, + msg: EventMsg::ExecApprovalRequest(event), + } => { + // Initiate approval via parent session; do not surface to consumer. + handle_exec_approval( + &codex, + id, + &parent_session, + &parent_ctx, + event, + &cancel_token, + ) + .await; + } + Event { + id, + msg: EventMsg::ApplyPatchApprovalRequest(event), + } => { + handle_patch_approval( + &codex, + id, + &parent_session, + &parent_ctx, + event, + &cancel_token, + ) + .await; + } + other => { + match tx_sub.send(other).or_cancel(&cancel_token).await { + Ok(Ok(())) => {} + _ => { + shutdown_delegate(&codex).await; + break; + } + } + } + } } } } } +/// Ask the delegate to stop and drain its events so background sends do not hit a closed channel. +async fn shutdown_delegate(codex: &Codex) { + let _ = codex.submit(Op::Interrupt).await; + let _ = codex.submit(Op::Shutdown {}).await; + + let _ = timeout(Duration::from_millis(500), async { + while let Ok(event) = codex.next_event().await { + if matches!( + event.msg, + EventMsg::TurnAborted(_) | EventMsg::TaskComplete(_) + ) { + break; + } + } + }) + .await; +} + /// Forward ops from a caller to a sub-agent, respecting cancellation. async fn forward_ops( codex: Arc, @@ -298,3 +338,85 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use async_channel::bounded; + use codex_protocol::models::ResponseItem; + use codex_protocol::protocol::RawResponseItemEvent; + use codex_protocol::protocol::TurnAbortReason; + use codex_protocol::protocol::TurnAbortedEvent; + use pretty_assertions::assert_eq; + + #[tokio::test] + async fn forward_events_cancelled_while_send_blocked_shuts_down_delegate() { + let (tx_events, rx_events) = bounded(1); + let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); + let codex = Arc::new(Codex { + next_id: AtomicU64::new(0), + tx_sub, + rx_event: rx_events, + }); + + let (session, ctx, _rx_evt) = crate::codex::make_session_and_context_with_rx(); + + let (tx_out, rx_out) = bounded(1); + tx_out + .send(Event { + id: "full".to_string(), + msg: EventMsg::TurnAborted(TurnAbortedEvent { + reason: TurnAbortReason::Interrupted, + }), + }) + .await + .unwrap(); + + let cancel = CancellationToken::new(); + let forward = tokio::spawn(forward_events( + Arc::clone(&codex), + tx_out.clone(), + session, + ctx, + cancel.clone(), + )); + + tx_events + .send(Event { + id: "evt".to_string(), + msg: EventMsg::RawResponseItem(RawResponseItemEvent { + item: ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-1".to_string(), + name: "tool".to_string(), + input: "{}".to_string(), + }, + }), + }) + .await + .unwrap(); + + drop(tx_events); + cancel.cancel(); + timeout(std::time::Duration::from_millis(1000), forward) + .await + .expect("forward_events hung") + .expect("forward_events join error"); + + let received = rx_out.recv().await.expect("prefilled event missing"); + assert_eq!("full", received.id); + let mut ops = Vec::new(); + while let Ok(sub) = rx_sub.try_recv() { + ops.push(sub.op); + } + assert!( + ops.iter().any(|op| matches!(op, Op::Interrupt)), + "expected Interrupt op after cancellation" + ); + assert!( + ops.iter().any(|op| matches!(op, Op::Shutdown)), + "expected Shutdown op after cancellation" + ); + } +}