Skip to content

Commit 8ccce51

Browse files
authored
Merge branch 'main' into add-github-action-for-nix
2 parents 1513455 + 268a10f commit 8ccce51

File tree

8 files changed

+261
-7
lines changed

8 files changed

+261
-7
lines changed

codex-rs/core/src/client.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use crate::openai_tools::create_tools_json_for_responses_api;
4747
use crate::protocol::RateLimitSnapshot;
4848
use crate::protocol::RateLimitWindow;
4949
use crate::protocol::TokenUsage;
50+
use crate::state::TaskKind;
5051
use crate::token_data::PlanType;
5152
use crate::util::backoff;
5253
use codex_otel::otel_event_manager::OtelEventManager;
@@ -123,8 +124,16 @@ impl ModelClient {
123124
/// the provider config. Public callers always invoke `stream()` – the
124125
/// specialised helpers are private to avoid accidental misuse.
125126
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
127+
self.stream_with_task_kind(prompt, TaskKind::Regular).await
128+
}
129+
130+
pub(crate) async fn stream_with_task_kind(
131+
&self,
132+
prompt: &Prompt,
133+
task_kind: TaskKind,
134+
) -> Result<ResponseStream> {
126135
match self.provider.wire_api {
127-
WireApi::Responses => self.stream_responses(prompt).await,
136+
WireApi::Responses => self.stream_responses(prompt, task_kind).await,
128137
WireApi::Chat => {
129138
// Create the raw streaming connection first.
130139
let response_stream = stream_chat_completions(
@@ -165,7 +174,11 @@ impl ModelClient {
165174
}
166175

167176
/// Implementation for the OpenAI *Responses* experimental API.
168-
async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
177+
async fn stream_responses(
178+
&self,
179+
prompt: &Prompt,
180+
task_kind: TaskKind,
181+
) -> Result<ResponseStream> {
169182
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
170183
// short circuit for tests
171184
warn!(path, "Streaming from fixture");
@@ -244,7 +257,7 @@ impl ModelClient {
244257
let max_attempts = self.provider.request_max_retries();
245258
for attempt in 0..=max_attempts {
246259
match self
247-
.attempt_stream_responses(attempt, &payload_json, &auth_manager)
260+
.attempt_stream_responses(attempt, &payload_json, &auth_manager, task_kind)
248261
.await
249262
{
250263
Ok(stream) => {
@@ -272,6 +285,7 @@ impl ModelClient {
272285
attempt: u64,
273286
payload_json: &Value,
274287
auth_manager: &Option<Arc<AuthManager>>,
288+
task_kind: TaskKind,
275289
) -> std::result::Result<ResponseStream, StreamAttemptError> {
276290
// Always fetch the latest auth in case a prior attempt refreshed the token.
277291
let auth = auth_manager.as_ref().and_then(|m| m.auth());
@@ -294,6 +308,7 @@ impl ModelClient {
294308
.header("conversation_id", self.conversation_id.to_string())
295309
.header("session_id", self.conversation_id.to_string())
296310
.header(reqwest::header::ACCEPT, "text/event-stream")
311+
.header("Codex-Task-Type", task_kind.header_value())
297312
.json(payload_json);
298313

299314
if let Some(auth) = auth.as_ref()

codex-rs/core/src/codex.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ use crate::rollout::RolloutRecorderParams;
9999
use crate::shell;
100100
use crate::state::ActiveTurn;
101101
use crate::state::SessionServices;
102+
use crate::state::TaskKind;
102103
use crate::tasks::CompactTask;
103104
use crate::tasks::RegularTask;
104105
use crate::tasks::ReviewTask;
@@ -1634,6 +1635,7 @@ pub(crate) async fn run_task(
16341635
turn_context: Arc<TurnContext>,
16351636
sub_id: String,
16361637
input: Vec<InputItem>,
1638+
task_kind: TaskKind,
16371639
) -> Option<String> {
16381640
if input.is_empty() {
16391641
return None;
@@ -1717,6 +1719,7 @@ pub(crate) async fn run_task(
17171719
Arc::clone(&turn_diff_tracker),
17181720
sub_id.clone(),
17191721
turn_input,
1722+
task_kind,
17201723
)
17211724
.await
17221725
{
@@ -1942,6 +1945,7 @@ async fn run_turn(
19421945
turn_diff_tracker: SharedTurnDiffTracker,
19431946
sub_id: String,
19441947
input: Vec<ResponseItem>,
1948+
task_kind: TaskKind,
19451949
) -> CodexResult<TurnRunResult> {
19461950
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
19471951
let router = Arc::new(ToolRouter::from_config(
@@ -1971,6 +1975,7 @@ async fn run_turn(
19711975
Arc::clone(&turn_diff_tracker),
19721976
&sub_id,
19731977
&prompt,
1978+
task_kind,
19741979
)
19751980
.await
19761981
{
@@ -2044,6 +2049,7 @@ async fn try_run_turn(
20442049
turn_diff_tracker: SharedTurnDiffTracker,
20452050
sub_id: &str,
20462051
prompt: &Prompt,
2052+
task_kind: TaskKind,
20472053
) -> CodexResult<TurnRunResult> {
20482054
// call_ids that are part of this response.
20492055
let completed_call_ids = prompt
@@ -2109,7 +2115,11 @@ async fn try_run_turn(
21092115
summary: turn_context.client.get_reasoning_summary(),
21102116
});
21112117
sess.persist_rollout_items(&[rollout_item]).await;
2112-
let mut stream = turn_context.client.clone().stream(&prompt).await?;
2118+
let mut stream = turn_context
2119+
.client
2120+
.clone()
2121+
.stream_with_task_kind(prompt.as_ref(), task_kind)
2122+
.await?;
21132123

21142124
let tool_runtime = ToolCallRuntime::new(
21152125
Arc::clone(&router),

codex-rs/core/src/codex/compact.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::protocol::InputItem;
1616
use crate::protocol::InputMessageKind;
1717
use crate::protocol::TaskStartedEvent;
1818
use crate::protocol::TurnContextItem;
19+
use crate::state::TaskKind;
1920
use crate::truncate::truncate_middle;
2021
use crate::util::backoff;
2122
use askama::Template;
@@ -258,7 +259,11 @@ async fn drain_to_completed(
258259
sub_id: &str,
259260
prompt: &Prompt,
260261
) -> CodexResult<()> {
261-
let mut stream = turn_context.client.clone().stream(prompt).await?;
262+
let mut stream = turn_context
263+
.client
264+
.clone()
265+
.stream_with_task_kind(prompt, TaskKind::Compact)
266+
.await?;
262267
loop {
263268
let maybe_event = stream.next().await;
264269
let Some(event) = maybe_event else {

codex-rs/core/src/state/turn.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ pub(crate) enum TaskKind {
3434
Compact,
3535
}
3636

37+
impl TaskKind {
38+
pub(crate) fn header_value(self) -> &'static str {
39+
match self {
40+
TaskKind::Regular => "standard",
41+
TaskKind::Review => "review",
42+
TaskKind::Compact => "compact",
43+
}
44+
}
45+
}
46+
3747
#[derive(Clone)]
3848
pub(crate) struct RunningTask {
3949
pub(crate) handle: AbortHandle,
@@ -113,3 +123,15 @@ impl ActiveTurn {
113123
}
114124
}
115125
}
126+
127+
#[cfg(test)]
128+
mod tests {
129+
use super::TaskKind;
130+
131+
#[test]
132+
fn header_value_matches_expected_labels() {
133+
assert_eq!(TaskKind::Regular.header_value(), "standard");
134+
assert_eq!(TaskKind::Review.header_value(), "review");
135+
assert_eq!(TaskKind::Compact.header_value(), "compact");
136+
}
137+
}

codex-rs/core/src/tasks/regular.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ impl SessionTask for RegularTask {
2727
input: Vec<InputItem>,
2828
) -> Option<String> {
2929
let sess = session.clone_session();
30-
run_task(sess, ctx, sub_id, input).await
30+
run_task(sess, ctx, sub_id, input, TaskKind::Regular).await
3131
}
3232
}

codex-rs/core/src/tasks/review.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl SessionTask for ReviewTask {
2828
input: Vec<InputItem>,
2929
) -> Option<String> {
3030
let sess = session.clone_session();
31-
run_task(sess, ctx, sub_id, input).await
31+
run_task(sess, ctx, sub_id, input, TaskKind::Review).await
3232
}
3333

3434
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
use std::sync::Arc;
2+
3+
use codex_app_server_protocol::AuthMode;
4+
use codex_core::ContentItem;
5+
use codex_core::ModelClient;
6+
use codex_core::ModelProviderInfo;
7+
use codex_core::Prompt;
8+
use codex_core::ResponseEvent;
9+
use codex_core::ResponseItem;
10+
use codex_core::WireApi;
11+
use codex_otel::otel_event_manager::OtelEventManager;
12+
use codex_protocol::ConversationId;
13+
use core_test_support::load_default_config_for_test;
14+
use core_test_support::responses;
15+
use futures::StreamExt;
16+
use tempfile::TempDir;
17+
use wiremock::matchers::header;
18+
19+
#[tokio::test]
20+
async fn responses_stream_includes_task_type_header() {
21+
core_test_support::skip_if_no_network!();
22+
23+
let server = responses::start_mock_server().await;
24+
let response_body = responses::sse(vec![
25+
responses::ev_response_created("resp-1"),
26+
responses::ev_completed("resp-1"),
27+
]);
28+
29+
let request_recorder = responses::mount_sse_once_match(
30+
&server,
31+
header("Codex-Task-Type", "standard"),
32+
response_body,
33+
)
34+
.await;
35+
36+
let provider = ModelProviderInfo {
37+
name: "mock".into(),
38+
base_url: Some(format!("{}/v1", server.uri())),
39+
env_key: None,
40+
env_key_instructions: None,
41+
wire_api: WireApi::Responses,
42+
query_params: None,
43+
http_headers: None,
44+
env_http_headers: None,
45+
request_max_retries: Some(0),
46+
stream_max_retries: Some(0),
47+
stream_idle_timeout_ms: Some(5_000),
48+
requires_openai_auth: false,
49+
};
50+
51+
let codex_home = TempDir::new().expect("failed to create TempDir");
52+
let mut config = load_default_config_for_test(&codex_home);
53+
config.model_provider_id = provider.name.clone();
54+
config.model_provider = provider.clone();
55+
let effort = config.model_reasoning_effort;
56+
let summary = config.model_reasoning_summary;
57+
let config = Arc::new(config);
58+
59+
let conversation_id = ConversationId::new();
60+
61+
let otel_event_manager = OtelEventManager::new(
62+
conversation_id,
63+
config.model.as_str(),
64+
config.model_family.slug.as_str(),
65+
None,
66+
Some(AuthMode::ChatGPT),
67+
false,
68+
"test".to_string(),
69+
);
70+
71+
let client = ModelClient::new(
72+
Arc::clone(&config),
73+
None,
74+
otel_event_manager,
75+
provider,
76+
effort,
77+
summary,
78+
conversation_id,
79+
);
80+
81+
let mut prompt = Prompt::default();
82+
prompt.input = vec![ResponseItem::Message {
83+
id: None,
84+
role: "user".into(),
85+
content: vec![ContentItem::InputText {
86+
text: "hello".into(),
87+
}],
88+
}];
89+
90+
let mut stream = client.stream(&prompt).await.expect("stream failed");
91+
while let Some(event) = stream.next().await {
92+
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
93+
break;
94+
}
95+
}
96+
97+
let request = request_recorder.single_request();
98+
assert_eq!(
99+
request.header("Codex-Task-Type").as_deref(),
100+
Some("standard")
101+
);
102+
}

0 commit comments

Comments
 (0)