diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index dae140bc02..052e6d4f11 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -636,3 +636,113 @@ impl AggregatedChatStream { Self::new(inner, AggregateMode::Streaming) } } + +#[cfg(test)] +mod adapter_tests { + use super::*; + use futures::StreamExt; + use futures::stream; + + fn ev_text_delta(s: &str) -> Result { + Ok(ResponseEvent::OutputTextDelta(s.to_string())) + } + + fn ev_completed() -> Result { + Ok(ResponseEvent::Completed { + response_id: String::new(), + token_usage: None, + }) + } + + // Helper kept for future tests; silence dead_code warning when unused. + #[allow(dead_code)] + fn ev_item_done_full_message(text: &str) -> Result { + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + })) + } + + #[tokio::test] + async fn streaming_mode_forwards_text_deltas_and_emits_final_message() { + // Arrange a stream: two deltas, then completed + let src = stream::iter(vec![ + ev_text_delta("Hi "), + ev_text_delta("there"), + ev_completed(), + ]); + + let mut s = AggregatedChatStream::streaming_mode(src); + let mut out = Vec::new(); + while let Some(ev) = s.next().await { + match ev { + Ok(e) => out.push(e), + Err(_) => panic!("stream error"), + } + if out.len() > 4 { + break; + } + } + + // Expect: two deltas forwarded, then aggregated final message, then completed + assert!(matches!(out.get(0), Some(ResponseEvent::OutputTextDelta(d)) if d == "Hi ")); + assert!(matches!(out.get(1), Some(ResponseEvent::OutputTextDelta(d)) if d == "there")); + // The next should be a full assistant message with concatenated text + match out.get(2) { + Some(ResponseEvent::OutputItemDone(ResponseItem::Message { + content, role, .. + })) => { + assert_eq!(role, "assistant"); + let text = content.iter().find_map(|c| match c { + ContentItem::OutputText { text } => Some(text.clone()), + _ => None, + }); + assert_eq!(text.as_deref(), Some("Hi there")); + } + other => panic!("unexpected third event: {other:?}"), + } + assert!(matches!(out.get(3), Some(ResponseEvent::Completed { .. }))); + } + + #[tokio::test] + async fn aggregate_mode_suppresses_text_deltas_and_emits_only_final_message() { + // Arrange a stream: two deltas, then completed + let src = stream::iter(vec![ + ev_text_delta("Hi "), + ev_text_delta("there"), + ev_completed(), + ]); + + let mut s = AggregateStreamExt::aggregate(src); + let mut out = Vec::new(); + while let Some(ev) = s.next().await { + match ev { + Ok(e) => out.push(e), + Err(_) => panic!("stream error"), + } + if out.len() > 2 { + break; + } + } + + // Expect: no deltas, only final full message then completed + assert_eq!(out.len(), 2); + match out.get(0) { + Some(ResponseEvent::OutputItemDone(ResponseItem::Message { + content, role, .. + })) => { + assert_eq!(role, "assistant"); + let text = content.iter().find_map(|c| match c { + ContentItem::OutputText { text } => Some(text.clone()), + _ => None, + }); + assert_eq!(text.as_deref(), Some("Hi there")); + } + other => panic!("unexpected first event: {other:?}"), + } + assert!(matches!(out.get(1), Some(ResponseEvent::Completed { .. }))); + } +} diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 0caf1170a6..cb0818ae32 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -19,7 +19,6 @@ use tracing::trace; use tracing::warn; use uuid::Uuid; -use crate::chat_completions::AggregateStreamExt; use crate::chat_completions::stream_chat_completions; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; @@ -98,14 +97,11 @@ impl ModelClient { ) .await?; - // Wrap it with the aggregation adapter so callers see *only* - // the final assistant message per turn (matching the - // behaviour of the Responses API). - let mut aggregated = if self.config.show_raw_agent_reasoning { - crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream) - } else { - response_stream.aggregate() - }; + // Use streaming mode so normal assistant text deltas are forwarded live + // while still accumulating to emit a single final OutputItemDone at turn end. + // Reasoning visibility remains controlled separately by config downstream. + let mut aggregated = + crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream); // Bridge the aggregated stream back into a standard // `ResponseStream` by forwarding events through a channel. diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 3bf6288fdf..c67ade833f 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -246,6 +246,143 @@ impl Session { } } +#[cfg(test)] +mod mapper_tests { + use super::*; + use tempfile::TempDir; + + fn make_session(show_raw: bool) -> (Arc, async_channel::Receiver) { + use crate::model_provider_info::ModelProviderInfo; + use crate::model_provider_info::WireApi; + use std::sync::Arc as StdArc; + + let (tx_event, rx_event) = async_channel::unbounded::(); + + // Create a minimal Config suitable for tests + let codex_home = TempDir::new().expect("temp dir"); + let cfg: StdArc = StdArc::new( + Config::load_from_base_config_with_overrides( + crate::config::ConfigToml::default(), + crate::config::ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .expect("config"), + ); + let provider = ModelProviderInfo { + name: "test".to_string(), + base_url: None, + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(1000), + requires_openai_auth: false, + }; + let client = ModelClient::new( + cfg.clone(), + None, + provider, + cfg.model_reasoning_effort, + cfg.model_reasoning_summary, + uuid::Uuid::nil(), + ); + + let sess = Arc::new(Session { + client, + tx_event, + ctrl_c: Arc::new(Notify::new()), + cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + base_instructions: None, + user_instructions: None, + approval_policy: AskForApproval::OnRequest, + sandbox_policy: SandboxPolicy::new_read_only_policy(), + shell_environment_policy: ShellEnvironmentPolicy::default(), + writable_roots: Mutex::new(Vec::new()), + disable_response_storage: true, + tools_config: ToolsConfig::new( + &cfg.model_family, + AskForApproval::OnRequest, + SandboxPolicy::new_read_only_policy(), + cfg.include_plan_tool, + ), + mcp_connection_manager: McpConnectionManager::default(), + notify: None, + rollout: Mutex::new(None), + state: Mutex::new(State::default()), + codex_linux_sandbox_exe: None, + user_shell: shell::Shell::Unknown, + show_raw_agent_reasoning: show_raw, + }); + (sess, rx_event) + } + + #[tokio::test] + async fn normal_text_deltas_always_emit_agent_message_delta() { + let (sess, rx) = make_session(false); + let sub_id = "s".to_string(); + + let delta = "hi".to_string(); + { + let mut st = sess.state.lock().unwrap(); + st.history.append_assistant_text(&delta); + } + let event = Event { + id: sub_id.clone(), + msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }), + }; + sess.tx_event.send(event).await.ok(); + + let out = rx.recv().await.expect("no event"); + match out.msg { + EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => { + assert_eq!(delta, "hi"); + } + other => panic!("unexpected event: {other:?}"), + } + } + + #[tokio::test] + async fn reasoning_content_delta_gated_by_flag() { + let (sess_off, rx_off) = make_session(false); + let sub_id = "s".to_string(); + + let delta = "think".to_string(); + if sess_off.show_raw_agent_reasoning { + let event = Event { + id: sub_id.clone(), + msg: EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent { + delta: delta.clone(), + }), + }; + sess_off.tx_event.send(event).await.ok(); + } + assert!(rx_off.try_recv().is_err()); + + let (sess_on, rx_on) = make_session(true); + if sess_on.show_raw_agent_reasoning { + let event = Event { + id: sub_id.clone(), + msg: EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent { + delta: delta.clone(), + }), + }; + sess_on.tx_event.send(event).await.ok(); + } + let out = rx_on.recv().await.expect("no event"); + match out.msg { + EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent { + delta, + }) => { + assert_eq!(delta, "think"); + } + other => panic!("unexpected event: {other:?}"), + } + } +} /// Mutable state of the agent #[derive(Default)] struct State {