Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions codex-rs/core/src/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,113 @@ impl<S> AggregatedChatStream<S> {
Self::new(inner, AggregateMode::Streaming)
}
}

#[cfg(test)]
mod adapter_tests {
use super::*;
use futures::StreamExt;
use futures::stream;

fn ev_text_delta(s: &str) -> Result<ResponseEvent> {
Ok(ResponseEvent::OutputTextDelta(s.to_string()))
}

fn ev_completed() -> Result<ResponseEvent> {
Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
})
}

// Helper kept for future tests; silence dead_code warning when unused.
#[allow(dead_code)]
fn ev_item_done_full_message(text: &str) -> Result<ResponseEvent> {
Ok(ResponseEvent::OutputItemDone(ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}))
}

#[tokio::test]
async fn streaming_mode_forwards_text_deltas_and_emits_final_message() {
// Arrange a stream: two deltas, then completed
let src = stream::iter(vec![
ev_text_delta("Hi "),
ev_text_delta("there"),
ev_completed(),
]);

let mut s = AggregatedChatStream::streaming_mode(src);
let mut out = Vec::new();
while let Some(ev) = s.next().await {
match ev {
Ok(e) => out.push(e),
Err(_) => panic!("stream error"),
}
if out.len() > 4 {
break;
}
}

// Expect: two deltas forwarded, then aggregated final message, then completed
assert!(matches!(out.get(0), Some(ResponseEvent::OutputTextDelta(d)) if d == "Hi "));
assert!(matches!(out.get(1), Some(ResponseEvent::OutputTextDelta(d)) if d == "there"));
// The next should be a full assistant message with concatenated text
match out.get(2) {
Some(ResponseEvent::OutputItemDone(ResponseItem::Message {
content, role, ..
})) => {
assert_eq!(role, "assistant");
let text = content.iter().find_map(|c| match c {
ContentItem::OutputText { text } => Some(text.clone()),
_ => None,
});
assert_eq!(text.as_deref(), Some("Hi there"));
}
other => panic!("unexpected third event: {other:?}"),
}
assert!(matches!(out.get(3), Some(ResponseEvent::Completed { .. })));
}

#[tokio::test]
async fn aggregate_mode_suppresses_text_deltas_and_emits_only_final_message() {
// Arrange a stream: two deltas, then completed
let src = stream::iter(vec![
ev_text_delta("Hi "),
ev_text_delta("there"),
ev_completed(),
]);

let mut s = AggregateStreamExt::aggregate(src);
let mut out = Vec::new();
while let Some(ev) = s.next().await {
match ev {
Ok(e) => out.push(e),
Err(_) => panic!("stream error"),
}
if out.len() > 2 {
break;
}
}

// Expect: no deltas, only final full message then completed
assert_eq!(out.len(), 2);
match out.get(0) {
Some(ResponseEvent::OutputItemDone(ResponseItem::Message {
content, role, ..
})) => {
assert_eq!(role, "assistant");
let text = content.iter().find_map(|c| match c {
ContentItem::OutputText { text } => Some(text.clone()),
_ => None,
});
assert_eq!(text.as_deref(), Some("Hi there"));
}
other => panic!("unexpected first event: {other:?}"),
}
assert!(matches!(out.get(1), Some(ResponseEvent::Completed { .. })));
}
}
14 changes: 5 additions & 9 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use tracing::trace;
use tracing::warn;
use uuid::Uuid;

use crate::chat_completions::AggregateStreamExt;
use crate::chat_completions::stream_chat_completions;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
Expand Down Expand Up @@ -98,14 +97,11 @@ impl ModelClient {
)
.await?;

// Wrap it with the aggregation adapter so callers see *only*
// the final assistant message per turn (matching the
// behaviour of the Responses API).
let mut aggregated = if self.config.show_raw_agent_reasoning {
crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream)
} else {
response_stream.aggregate()
};
// Use streaming mode so normal assistant text deltas are forwarded live
// while still accumulating to emit a single final OutputItemDone at turn end.
// Reasoning visibility remains controlled separately by config downstream.
let mut aggregated =
crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream);

// Bridge the aggregated stream back into a standard
// `ResponseStream` by forwarding events through a channel.
Expand Down
137 changes: 137 additions & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,143 @@ impl Session {
}
}

#[cfg(test)]
mod mapper_tests {
use super::*;
use tempfile::TempDir;

fn make_session(show_raw: bool) -> (Arc<Session>, async_channel::Receiver<Event>) {
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use std::sync::Arc as StdArc;

let (tx_event, rx_event) = async_channel::unbounded::<Event>();

// Create a minimal Config suitable for tests
let codex_home = TempDir::new().expect("temp dir");
let cfg: StdArc<Config> = StdArc::new(
Config::load_from_base_config_with_overrides(
crate::config::ConfigToml::default(),
crate::config::ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)
.expect("config"),
);
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: None,
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
requires_openai_auth: false,
};
let client = ModelClient::new(
cfg.clone(),
None,
provider,
cfg.model_reasoning_effort,
cfg.model_reasoning_summary,
uuid::Uuid::nil(),
);

let sess = Arc::new(Session {
client,
tx_event,
ctrl_c: Arc::new(Notify::new()),
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
base_instructions: None,
user_instructions: None,
approval_policy: AskForApproval::OnRequest,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
writable_roots: Mutex::new(Vec::new()),
disable_response_storage: true,
tools_config: ToolsConfig::new(
&cfg.model_family,
AskForApproval::OnRequest,
SandboxPolicy::new_read_only_policy(),
cfg.include_plan_tool,
),
mcp_connection_manager: McpConnectionManager::default(),
notify: None,
rollout: Mutex::new(None),
state: Mutex::new(State::default()),
codex_linux_sandbox_exe: None,
user_shell: shell::Shell::Unknown,
show_raw_agent_reasoning: show_raw,
});
(sess, rx_event)
}

#[tokio::test]
async fn normal_text_deltas_always_emit_agent_message_delta() {
let (sess, rx) = make_session(false);
let sub_id = "s".to_string();

let delta = "hi".to_string();
{
let mut st = sess.state.lock().unwrap();
st.history.append_assistant_text(&delta);
}
let event = Event {
id: sub_id.clone(),
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
};
sess.tx_event.send(event).await.ok();

let out = rx.recv().await.expect("no event");
match out.msg {
EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }) => {
assert_eq!(delta, "hi");
}
other => panic!("unexpected event: {other:?}"),
}
}

#[tokio::test]
async fn reasoning_content_delta_gated_by_flag() {
let (sess_off, rx_off) = make_session(false);
let sub_id = "s".to_string();

let delta = "think".to_string();
if sess_off.show_raw_agent_reasoning {
let event = Event {
id: sub_id.clone(),
msg: EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent {
delta: delta.clone(),
}),
};
sess_off.tx_event.send(event).await.ok();
}
assert!(rx_off.try_recv().is_err());

let (sess_on, rx_on) = make_session(true);
if sess_on.show_raw_agent_reasoning {
let event = Event {
id: sub_id.clone(),
msg: EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent {
delta: delta.clone(),
}),
};
sess_on.tx_event.send(event).await.ok();
}
let out = rx_on.recv().await.expect("no event");
match out.msg {
EventMsg::AgentReasoningRawContentDelta(AgentReasoningRawContentDeltaEvent {
delta,
}) => {
assert_eq!(delta, "think");
}
other => panic!("unexpected event: {other:?}"),
}
}
}
/// Mutable state of the agent
#[derive(Default)]
struct State {
Expand Down
Loading