Skip to content

Commit 50f11ae

Browse files
feat: extend active tools to lifecycle of turn
1 parent 4ca5bdf commit 50f11ae

File tree

6 files changed

+324
-21
lines changed

6 files changed

+324
-21
lines changed

codex-rs/core/src/codex.rs

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,14 +1107,19 @@ impl Session {
11071107
}
11081108
}
11091109

1110-
pub(crate) async fn set_next_mcp_tool_selection(&self, tool_names: Vec<String>) {
1110+
pub(crate) async fn merge_mcp_tool_selection(&self, tool_names: Vec<String>) -> Vec<String> {
11111111
let mut state = self.state.lock().await;
1112-
state.set_next_mcp_tool_selection(tool_names);
1112+
state.merge_mcp_tool_selection(tool_names)
11131113
}
11141114

1115-
pub(crate) async fn take_next_mcp_tool_selection(&self) -> Option<Vec<String>> {
1115+
pub(crate) async fn get_mcp_tool_selection(&self) -> Option<Vec<String>> {
1116+
let state = self.state.lock().await;
1117+
state.get_mcp_tool_selection()
1118+
}
1119+
1120+
pub(crate) async fn clear_mcp_tool_selection(&self) {
11161121
let mut state = self.state.lock().await;
1117-
state.take_next_mcp_tool_selection()
1122+
state.clear_mcp_tool_selection();
11181123
}
11191124

11201125
async fn record_initial_history(&self, conversation_history: InitialHistory) {
@@ -3842,7 +3847,7 @@ async fn run_sampling_request(
38423847
.await?;
38433848
let search_tool_enabled = turn_context.config.features.enabled(Feature::SearchTool);
38443849
if search_tool_enabled {
3845-
if let Some(selected_tools) = sess.take_next_mcp_tool_selection().await {
3850+
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
38463851
mcp_tools = filter_mcp_tools_by_name(mcp_tools, &selected_tools);
38473852
} else {
38483853
mcp_tools.clear();
@@ -4676,6 +4681,8 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
46764681
pub(crate) use tests::make_session_and_context;
46774682
#[cfg(test)]
46784683
pub(crate) use tests::make_session_and_context_with_rx;
4684+
#[cfg(test)]
4685+
pub(crate) use tests::make_session_configuration_for_tests;
46794686

46804687
#[cfg(test)]
46814688
mod tests {
@@ -5484,6 +5491,46 @@ mod tests {
54845491
)
54855492
}
54865493

5494+
pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration {
5495+
let codex_home = tempfile::tempdir().expect("create temp dir");
5496+
let config = build_test_config(codex_home.path()).await;
5497+
let config = Arc::new(config);
5498+
let model = ModelsManager::get_model_offline(config.model.as_deref());
5499+
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
5500+
let reasoning_effort = config.model_reasoning_effort;
5501+
let collaboration_mode = CollaborationMode {
5502+
mode: ModeKind::Default,
5503+
settings: Settings {
5504+
model,
5505+
reasoning_effort,
5506+
developer_instructions: None,
5507+
},
5508+
};
5509+
5510+
SessionConfiguration {
5511+
provider: config.model_provider.clone(),
5512+
collaboration_mode,
5513+
model_reasoning_summary: config.model_reasoning_summary,
5514+
developer_instructions: config.developer_instructions.clone(),
5515+
user_instructions: config.user_instructions.clone(),
5516+
personality: config.personality,
5517+
base_instructions: config
5518+
.base_instructions
5519+
.clone()
5520+
.unwrap_or_else(|| model_info.get_model_instructions(config.personality)),
5521+
compact_prompt: config.compact_prompt.clone(),
5522+
approval_policy: config.approval_policy.clone(),
5523+
sandbox_policy: config.sandbox_policy.clone(),
5524+
windows_sandbox_level: WindowsSandboxLevel::from_config(&config),
5525+
cwd: config.cwd.clone(),
5526+
codex_home: config.codex_home.clone(),
5527+
thread_name: None,
5528+
original_config_do_not_use: Arc::clone(&config),
5529+
session_source: SessionSource::Exec,
5530+
dynamic_tools: Vec::new(),
5531+
}
5532+
}
5533+
54875534
pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
54885535
let (tx_event, _rx_event) = async_channel::unbounded();
54895536
let codex_home = tempfile::tempdir().expect("create temp dir");

codex-rs/core/src/state/session.rs

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub(crate) struct SessionState {
2424
/// TODO(owen): This is a temporary solution to avoid updating a thread's updated_at
2525
/// timestamp when resuming a session. Remove this once SQLite is in place.
2626
pub(crate) initial_context_seeded: bool,
27-
pub(crate) next_mcp_tool_selection: Option<Vec<String>>,
27+
pub(crate) active_mcp_tool_selection: Option<Vec<String>>,
2828
}
2929

3030
impl SessionState {
@@ -39,7 +39,7 @@ impl SessionState {
3939
dependency_env: HashMap::new(),
4040
mcp_dependency_prompted: HashSet::new(),
4141
initial_context_seeded: false,
42-
next_mcp_tool_selection: None,
42+
active_mcp_tool_selection: None,
4343
}
4444
}
4545

@@ -128,12 +128,30 @@ impl SessionState {
128128
self.dependency_env.clone()
129129
}
130130

131-
pub(crate) fn set_next_mcp_tool_selection(&mut self, tool_names: Vec<String>) {
132-
self.next_mcp_tool_selection = Some(tool_names);
131+
pub(crate) fn merge_mcp_tool_selection(&mut self, tool_names: Vec<String>) -> Vec<String> {
132+
if tool_names.is_empty() {
133+
return self.active_mcp_tool_selection.clone().unwrap_or_default();
134+
}
135+
136+
let mut merged = self.active_mcp_tool_selection.take().unwrap_or_default();
137+
let mut seen: HashSet<String> = merged.iter().cloned().collect();
138+
139+
for tool_name in tool_names {
140+
if seen.insert(tool_name.clone()) {
141+
merged.push(tool_name);
142+
}
143+
}
144+
145+
self.active_mcp_tool_selection = Some(merged.clone());
146+
merged
133147
}
134148

135-
pub(crate) fn take_next_mcp_tool_selection(&mut self) -> Option<Vec<String>> {
136-
self.next_mcp_tool_selection.take()
149+
pub(crate) fn get_mcp_tool_selection(&self) -> Option<Vec<String>> {
150+
self.active_mcp_tool_selection.clone()
151+
}
152+
153+
pub(crate) fn clear_mcp_tool_selection(&mut self) {
154+
self.active_mcp_tool_selection = None;
137155
}
138156
}
139157

@@ -150,3 +168,79 @@ fn merge_rate_limit_fields(
150168
}
151169
snapshot
152170
}
171+
172+
#[cfg(test)]
173+
mod tests {
174+
use super::*;
175+
use crate::codex::make_session_configuration_for_tests;
176+
use pretty_assertions::assert_eq;
177+
178+
#[tokio::test]
179+
async fn merge_mcp_tool_selection_deduplicates_and_preserves_order() {
180+
let session_configuration = make_session_configuration_for_tests().await;
181+
let mut state = SessionState::new(session_configuration);
182+
183+
let merged = state.merge_mcp_tool_selection(vec![
184+
"mcp__rmcp__echo".to_string(),
185+
"mcp__rmcp__image".to_string(),
186+
"mcp__rmcp__echo".to_string(),
187+
]);
188+
assert_eq!(
189+
merged,
190+
vec![
191+
"mcp__rmcp__echo".to_string(),
192+
"mcp__rmcp__image".to_string(),
193+
]
194+
);
195+
196+
let merged = state.merge_mcp_tool_selection(vec![
197+
"mcp__rmcp__image".to_string(),
198+
"mcp__rmcp__search".to_string(),
199+
]);
200+
assert_eq!(
201+
merged,
202+
vec![
203+
"mcp__rmcp__echo".to_string(),
204+
"mcp__rmcp__image".to_string(),
205+
"mcp__rmcp__search".to_string(),
206+
]
207+
);
208+
}
209+
210+
#[tokio::test]
211+
async fn merge_mcp_tool_selection_empty_input_is_noop() {
212+
let session_configuration = make_session_configuration_for_tests().await;
213+
let mut state = SessionState::new(session_configuration);
214+
state.merge_mcp_tool_selection(vec![
215+
"mcp__rmcp__echo".to_string(),
216+
"mcp__rmcp__image".to_string(),
217+
]);
218+
219+
let merged = state.merge_mcp_tool_selection(Vec::new());
220+
assert_eq!(
221+
merged,
222+
vec![
223+
"mcp__rmcp__echo".to_string(),
224+
"mcp__rmcp__image".to_string(),
225+
]
226+
);
227+
assert_eq!(
228+
state.get_mcp_tool_selection(),
229+
Some(vec![
230+
"mcp__rmcp__echo".to_string(),
231+
"mcp__rmcp__image".to_string(),
232+
])
233+
);
234+
}
235+
236+
#[tokio::test]
237+
async fn clear_mcp_tool_selection_removes_selection() {
238+
let session_configuration = make_session_configuration_for_tests().await;
239+
let mut state = SessionState::new(session_configuration);
240+
state.merge_mcp_tool_selection(vec!["mcp__rmcp__echo".to_string()]);
241+
242+
state.clear_mcp_tool_selection();
243+
244+
assert_eq!(state.get_mcp_tool_selection(), None);
245+
}
246+
}

codex-rs/core/src/tasks/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ impl Session {
117117
task: T,
118118
) {
119119
self.abort_all_tasks(TurnAbortReason::Replaced).await;
120+
self.clear_mcp_tool_selection().await;
120121
self.seed_initial_context_if_needed(turn_context.as_ref())
121122
.await;
122123

codex-rs/core/src/tools/handlers/search_tool_bm25.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ impl ToolHandler for SearchToolBm25Handler {
114114
entries.sort_by(|a, b| a.name.cmp(&b.name));
115115

116116
if entries.is_empty() {
117-
session.set_next_mcp_tool_selection(Vec::new()).await;
117+
let active_selected_tools = session.get_mcp_tool_selection().await.unwrap_or_default();
118118
let content = json!({
119119
"query": query,
120120
"total_tools": 0,
121-
"selected_tools": [],
121+
"active_selected_tools": active_selected_tools,
122122
"tools": [],
123123
})
124124
.to_string();
@@ -157,14 +157,12 @@ impl ToolHandler for SearchToolBm25Handler {
157157
}));
158158
}
159159

160-
session
161-
.set_next_mcp_tool_selection(selected_tools.clone())
162-
.await;
160+
let active_selected_tools = session.merge_mcp_tool_selection(selected_tools).await;
163161

164162
let content = json!({
165163
"query": query,
166164
"total_tools": entries.len(),
167-
"selected_tools": selected_tools,
165+
"active_selected_tools": active_selected_tools,
168166
"tools": result_payloads,
169167
})
170168
.to_string();

codex-rs/core/templates/search_tool/developer_instructions.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ Follow this workflow:
88
- `query` (required): focused terms that describe the capability you need.
99
- `limit` (optional): maximum number of tools to return (default `8`, max `50`).
1010
2. Use the returned `tools` list to decide which MCP tools are relevant.
11-
3. On the next request, only the returned `selected_tools` will be available. Invoke the MCP tool(s) you need there.
12-
4. MCP tool selections are consumed after that next request. Search again if you need different MCP tools.
11+
3. Matching tools are added to `active_selected_tools`. Only tools in `active_selected_tools` are available for the remainder of the current turn.
12+
4. Repeated searches in the same turn are additive: new matches are unioned into `active_selected_tools`.
13+
5. `active_selected_tools` resets at the start of the next turn.
1314

1415
Notes:
1516
- Core tools remain available without searching.

0 commit comments

Comments
 (0)