Skip to content

Commit c283f9f

Browse files
authored
Add an operation to override current task context (#2431)
- Added an operation to override current task context - Added a test to check that cache stays the same
1 parent c9963b5 commit c283f9f

File tree

4 files changed

+263
-2
lines changed

4 files changed

+263
-2
lines changed

codex-rs/core/src/client.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::error::CodexErr;
3333
use crate::error::Result;
3434
use crate::error::UsageLimitReachedError;
3535
use crate::flags::CODEX_RS_SSE_FIXTURE;
36+
use crate::model_family::ModelFamily;
3637
use crate::model_provider_info::ModelProviderInfo;
3738
use crate::model_provider_info::WireApi;
3839
use crate::models::ResponseItem;
@@ -311,6 +312,30 @@ impl ModelClient {
311312
pub fn get_provider(&self) -> ModelProviderInfo {
312313
self.provider.clone()
313314
}
315+
316+
/// Returns the currently configured model slug.
317+
pub fn get_model(&self) -> String {
318+
self.config.model.clone()
319+
}
320+
321+
/// Returns the currently configured model family.
322+
pub fn get_model_family(&self) -> ModelFamily {
323+
self.config.model_family.clone()
324+
}
325+
326+
/// Returns the current reasoning effort setting.
327+
pub fn get_reasoning_effort(&self) -> ReasoningEffortConfig {
328+
self.effort
329+
}
330+
331+
/// Returns the current reasoning summary setting.
332+
pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig {
333+
self.summary
334+
}
335+
336+
pub fn get_auth(&self) -> Option<CodexAuth> {
337+
self.auth.clone()
338+
}
314339
}
315340

316341
#[derive(Debug, Deserialize, Serialize)]

codex-rs/core/src/codex.rs

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -989,14 +989,91 @@ async fn submission_loop(
989989
rx_sub: Receiver<Submission>,
990990
) {
991991
// Wrap once to avoid cloning TurnContext for each task.
992-
let turn_context = Arc::new(turn_context);
992+
let mut turn_context = Arc::new(turn_context);
993993
// To break out of this loop, send Op::Shutdown.
994994
while let Ok(sub) = rx_sub.recv().await {
995995
debug!(?sub, "Submission");
996996
match sub.op {
997997
Op::Interrupt => {
998998
sess.interrupt_task();
999999
}
1000+
Op::OverrideTurnContext {
1001+
cwd,
1002+
approval_policy,
1003+
sandbox_policy,
1004+
model,
1005+
effort,
1006+
summary,
1007+
} => {
1008+
// Recalculate the persistent turn context with provided overrides.
1009+
let prev = Arc::clone(&turn_context);
1010+
let provider = prev.client.get_provider();
1011+
1012+
// Effective model + family
1013+
let (effective_model, effective_family) = if let Some(m) = model {
1014+
let fam =
1015+
find_family_for_model(&m).unwrap_or_else(|| config.model_family.clone());
1016+
(m, fam)
1017+
} else {
1018+
(prev.client.get_model(), prev.client.get_model_family())
1019+
};
1020+
1021+
// Effective reasoning settings
1022+
let effective_effort = effort.unwrap_or(prev.client.get_reasoning_effort());
1023+
let effective_summary = summary.unwrap_or(prev.client.get_reasoning_summary());
1024+
1025+
let auth = prev.client.get_auth();
1026+
// Build updated config for the client
1027+
let mut updated_config = (*config).clone();
1028+
updated_config.model = effective_model.clone();
1029+
updated_config.model_family = effective_family.clone();
1030+
1031+
let client = ModelClient::new(
1032+
Arc::new(updated_config),
1033+
auth,
1034+
provider,
1035+
effective_effort,
1036+
effective_summary,
1037+
sess.session_id,
1038+
);
1039+
1040+
let new_approval_policy = approval_policy.unwrap_or(prev.approval_policy);
1041+
let new_sandbox_policy = sandbox_policy
1042+
.clone()
1043+
.unwrap_or(prev.sandbox_policy.clone());
1044+
let new_cwd = cwd.clone().unwrap_or_else(|| prev.cwd.clone());
1045+
1046+
let tools_config = ToolsConfig::new(
1047+
&effective_family,
1048+
new_approval_policy,
1049+
new_sandbox_policy.clone(),
1050+
config.include_plan_tool,
1051+
config.include_apply_patch_tool,
1052+
);
1053+
1054+
let new_turn_context = TurnContext {
1055+
client,
1056+
tools_config,
1057+
user_instructions: prev.user_instructions.clone(),
1058+
base_instructions: prev.base_instructions.clone(),
1059+
approval_policy: new_approval_policy,
1060+
sandbox_policy: new_sandbox_policy.clone(),
1061+
shell_environment_policy: prev.shell_environment_policy.clone(),
1062+
cwd: new_cwd.clone(),
1063+
disable_response_storage: prev.disable_response_storage,
1064+
};
1065+
1066+
// Install the new persistent context for subsequent tasks/turns.
1067+
turn_context = Arc::new(new_turn_context);
1068+
if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() {
1069+
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
1070+
new_cwd,
1071+
new_approval_policy,
1072+
new_sandbox_policy,
1073+
))])
1074+
.await;
1075+
}
1076+
}
10001077
Op::UserInput { items } => {
10011078
// attempt to inject input into current task
10021079
if let Err(items) = sess.inject_input(items) {
@@ -1057,7 +1134,7 @@ async fn submission_loop(
10571134
cwd,
10581135
disable_response_storage: turn_context.disable_response_storage,
10591136
};
1060-
1137+
// TODO: record the new environment context in the conversation history
10611138
// no current task, spawn a new one with the per‑turn context
10621139
let task =
10631140
AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);

codex-rs/core/tests/prompt_caching.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
use codex_core::ConversationManager;
22
use codex_core::ModelProviderInfo;
33
use codex_core::built_in_model_providers;
4+
use codex_core::protocol::AskForApproval;
45
use codex_core::protocol::EventMsg;
56
use codex_core::protocol::InputItem;
67
use codex_core::protocol::Op;
8+
use codex_core::protocol::SandboxPolicy;
9+
use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig;
10+
use codex_core::protocol_config_types::ReasoningSummary as ReasoningSummaryConfig;
711
use codex_login::CodexAuth;
812
use core_test_support::load_default_config_for_test;
913
use core_test_support::load_sse_fixture_with_id;
@@ -129,3 +133,126 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
129133
);
130134
assert_eq!(body2["input"], expected_body2);
131135
}
136+
137+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
138+
async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
139+
use pretty_assertions::assert_eq;
140+
141+
let server = MockServer::start().await;
142+
143+
let sse = sse_completed("resp");
144+
let template = ResponseTemplate::new(200)
145+
.insert_header("content-type", "text/event-stream")
146+
.set_body_raw(sse, "text/event-stream");
147+
148+
// Expect two POSTs to /v1/responses
149+
Mock::given(method("POST"))
150+
.and(path("/v1/responses"))
151+
.respond_with(template)
152+
.expect(2)
153+
.mount(&server)
154+
.await;
155+
156+
let model_provider = ModelProviderInfo {
157+
base_url: Some(format!("{}/v1", server.uri())),
158+
..built_in_model_providers()["openai"].clone()
159+
};
160+
161+
let cwd = TempDir::new().unwrap();
162+
let codex_home = TempDir::new().unwrap();
163+
let mut config = load_default_config_for_test(&codex_home);
164+
config.cwd = cwd.path().to_path_buf();
165+
config.model_provider = model_provider;
166+
config.user_instructions = Some("be consistent and helpful".to_string());
167+
168+
let conversation_manager = ConversationManager::default();
169+
let codex = conversation_manager
170+
.new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key")))
171+
.await
172+
.expect("create new conversation")
173+
.conversation;
174+
175+
// First turn
176+
codex
177+
.submit(Op::UserInput {
178+
items: vec![InputItem::Text {
179+
text: "hello 1".into(),
180+
}],
181+
})
182+
.await
183+
.unwrap();
184+
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
185+
186+
// Change everything about the turn context.
187+
let new_cwd = TempDir::new().unwrap();
188+
let writable = TempDir::new().unwrap();
189+
codex
190+
.submit(Op::OverrideTurnContext {
191+
cwd: Some(new_cwd.path().to_path_buf()),
192+
approval_policy: Some(AskForApproval::Never),
193+
sandbox_policy: Some(SandboxPolicy::WorkspaceWrite {
194+
writable_roots: vec![writable.path().to_path_buf()],
195+
network_access: true,
196+
exclude_tmpdir_env_var: true,
197+
exclude_slash_tmp: true,
198+
}),
199+
model: Some("o3".to_string()),
200+
effort: Some(ReasoningEffortConfig::High),
201+
summary: Some(ReasoningSummaryConfig::Detailed),
202+
})
203+
.await
204+
.unwrap();
205+
206+
// Second turn after overrides
207+
codex
208+
.submit(Op::UserInput {
209+
items: vec![InputItem::Text {
210+
text: "hello 2".into(),
211+
}],
212+
})
213+
.await
214+
.unwrap();
215+
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
216+
217+
// Verify we issued exactly two requests, and the cached prefix stayed identical.
218+
let requests = server.received_requests().await.unwrap();
219+
assert_eq!(requests.len(), 2, "expected two POST requests");
220+
221+
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
222+
let body2 = requests[1].body_json::<serde_json::Value>().unwrap();
223+
224+
// prompt_cache_key should remain constant across overrides
225+
assert_eq!(
226+
body1["prompt_cache_key"], body2["prompt_cache_key"],
227+
"prompt_cache_key should not change across overrides"
228+
);
229+
230+
// The entire prefix from the first request should be identical and reused
231+
// as the prefix of the second request, ensuring cache hit potential.
232+
let expected_user_message_2 = serde_json::json!({
233+
"type": "message",
234+
"id": serde_json::Value::Null,
235+
"role": "user",
236+
"content": [ { "type": "input_text", "text": "hello 2" } ]
237+
});
238+
// After overriding the turn context, the environment context should be emitted again
239+
// reflecting the new cwd, approval policy and sandbox settings.
240+
let expected_env_text_2 = format!(
241+
"<environment_context>\nCurrent working directory: {}\nApproval policy: never\nSandbox mode: workspace-write\nNetwork access: enabled\n</environment_context>",
242+
new_cwd.path().to_string_lossy()
243+
);
244+
let expected_env_msg_2 = serde_json::json!({
245+
"type": "message",
246+
"id": serde_json::Value::Null,
247+
"role": "user",
248+
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
249+
});
250+
let expected_body2 = serde_json::json!(
251+
[
252+
body1["input"].as_array().unwrap().as_slice(),
253+
[expected_env_msg_2, expected_user_message_2].as_slice(),
254+
]
255+
.concat()
256+
);
257+
assert_eq!(body2["input"], expected_body2);
258+
}

codex-rs/protocol/src/protocol.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ pub enum Op {
7575
summary: ReasoningSummaryConfig,
7676
},
7777

78+
/// Override parts of the persistent turn context for subsequent turns.
79+
///
80+
/// All fields are optional; when omitted, the existing value is preserved.
81+
/// This does not enqueue any input – it only updates defaults used for
82+
/// future `UserInput` turns.
83+
OverrideTurnContext {
84+
/// Updated `cwd` for sandbox/tool calls.
85+
#[serde(skip_serializing_if = "Option::is_none")]
86+
cwd: Option<PathBuf>,
87+
88+
/// Updated command approval policy.
89+
#[serde(skip_serializing_if = "Option::is_none")]
90+
approval_policy: Option<AskForApproval>,
91+
92+
/// Updated sandbox policy for tool calls.
93+
#[serde(skip_serializing_if = "Option::is_none")]
94+
sandbox_policy: Option<SandboxPolicy>,
95+
96+
/// Updated model slug. When set, the model family is derived
97+
/// automatically.
98+
#[serde(skip_serializing_if = "Option::is_none")]
99+
model: Option<String>,
100+
101+
/// Updated reasoning effort (honored only for reasoning-capable models).
102+
#[serde(skip_serializing_if = "Option::is_none")]
103+
effort: Option<ReasoningEffortConfig>,
104+
105+
/// Updated reasoning summary preference (honored only for reasoning-capable models).
106+
#[serde(skip_serializing_if = "Option::is_none")]
107+
summary: Option<ReasoningSummaryConfig>,
108+
},
109+
78110
/// Approve a command execution
79111
ExecApproval {
80112
/// The id of the submission we are approving

0 commit comments

Comments
 (0)