Skip to content

Commit f4eec12

Browse files
feat: extend active tools to lifecycle of turn
1 parent 4ca5bdf commit f4eec12

File tree

6 files changed

+488
-39
lines changed

6 files changed

+488
-39
lines changed

codex-rs/core/src/codex.rs

Lines changed: 216 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,14 +1107,19 @@ impl Session {
11071107
}
11081108
}
11091109

1110-
pub(crate) async fn set_next_mcp_tool_selection(&self, tool_names: Vec<String>) {
1110+
pub(crate) async fn merge_mcp_tool_selection(&self, tool_names: Vec<String>) -> Vec<String> {
11111111
let mut state = self.state.lock().await;
1112-
state.set_next_mcp_tool_selection(tool_names);
1112+
state.merge_mcp_tool_selection(tool_names)
11131113
}
11141114

1115-
pub(crate) async fn take_next_mcp_tool_selection(&self) -> Option<Vec<String>> {
1115+
pub(crate) async fn get_mcp_tool_selection(&self) -> Option<Vec<String>> {
1116+
let state = self.state.lock().await;
1117+
state.get_mcp_tool_selection()
1118+
}
1119+
1120+
pub(crate) async fn clear_mcp_tool_selection(&self) {
11161121
let mut state = self.state.lock().await;
1117-
state.take_next_mcp_tool_selection()
1122+
state.clear_mcp_tool_selection();
11181123
}
11191124

11201125
async fn record_initial_history(&self, conversation_history: InitialHistory) {
@@ -3797,6 +3802,28 @@ fn filter_codex_apps_mcp_tools(
37973802
mcp_tools
37983803
}
37993804

3805+
fn filter_codex_apps_mcp_tools_only(
3806+
mut mcp_tools: HashMap<String, crate::mcp_connection_manager::ToolInfo>,
3807+
connectors: &[connectors::AppInfo],
3808+
) -> HashMap<String, crate::mcp_connection_manager::ToolInfo> {
3809+
let allowed: HashSet<&str> = connectors
3810+
.iter()
3811+
.map(|connector| connector.id.as_str())
3812+
.collect();
3813+
3814+
mcp_tools.retain(|_, tool| {
3815+
if tool.server_name != CODEX_APPS_MCP_SERVER_NAME {
3816+
return false;
3817+
}
3818+
let Some(connector_id) = codex_apps_connector_id(tool) else {
3819+
return false;
3820+
};
3821+
allowed.contains(connector_id)
3822+
});
3823+
3824+
mcp_tools
3825+
}
3826+
38003827
fn filter_mcp_tools_by_name(
38013828
mut mcp_tools: HashMap<String, crate::mcp_connection_manager::ToolInfo>,
38023829
selected_tools: &[String],
@@ -3840,29 +3867,39 @@ async fn run_sampling_request(
38403867
.list_all_tools()
38413868
.or_cancel(&cancellation_token)
38423869
.await?;
3870+
38433871
let search_tool_enabled = turn_context.config.features.enabled(Feature::SearchTool);
3844-
if search_tool_enabled {
3845-
if let Some(selected_tools) = sess.take_next_mcp_tool_selection().await {
3846-
mcp_tools = filter_mcp_tools_by_name(mcp_tools, &selected_tools);
3847-
} else {
3848-
mcp_tools.clear();
3849-
}
3872+
let apps_enabled = turn_context.config.features.enabled(Feature::Apps);
3873+
3874+
let apps_connectors = if apps_enabled {
3875+
Some(filter_connectors_for_input(
3876+
connectors::accessible_connectors_from_mcp_tools(&mcp_tools),
3877+
&input,
3878+
tool_selection.explicit_app_paths,
3879+
tool_selection.skill_name_counts_lower,
3880+
))
38503881
} else {
3851-
let connectors_for_tools = if turn_context.config.features.enabled(Feature::Apps) {
3852-
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
3853-
Some(filter_connectors_for_input(
3854-
connectors,
3855-
&input,
3856-
tool_selection.explicit_app_paths,
3857-
tool_selection.skill_name_counts_lower,
3858-
))
3859-
} else {
3860-
None
3861-
};
3862-
if let Some(connectors) = connectors_for_tools.as_ref() {
3863-
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
3882+
None
3883+
};
3884+
3885+
if search_tool_enabled {
3886+
let mut selected_mcp_tools =
3887+
if let Some(selected_tools) = sess.get_mcp_tool_selection().await {
3888+
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tools)
3889+
} else {
3890+
HashMap::new()
3891+
};
3892+
3893+
if let Some(connectors) = apps_connectors.as_ref() {
3894+
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, connectors);
3895+
selected_mcp_tools.extend(apps_mcp_tools);
38643896
}
3897+
3898+
mcp_tools = selected_mcp_tools;
3899+
} else if let Some(connectors) = apps_connectors.as_ref() {
3900+
mcp_tools = filter_codex_apps_mcp_tools(mcp_tools, connectors);
38653901
}
3902+
38663903
let router = Arc::new(ToolRouter::from_config(
38673904
&turn_context.tools_config,
38683905
Some(
@@ -4676,6 +4713,8 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
46764713
pub(crate) use tests::make_session_and_context;
46774714
#[cfg(test)]
46784715
pub(crate) use tests::make_session_and_context_with_rx;
4716+
#[cfg(test)]
4717+
pub(crate) use tests::make_session_configuration_for_tests;
46794718

46804719
#[cfg(test)]
46814720
mod tests {
@@ -4685,6 +4724,7 @@ mod tests {
46854724
use crate::config::test_config;
46864725
use crate::exec::ExecToolCallOutput;
46874726
use crate::function_tool::FunctionCallError;
4727+
use crate::mcp_connection_manager::ToolInfo;
46884728
use crate::shell::default_user_shell;
46894729
use crate::tools::format_exec_output_str;
46904730

@@ -4721,6 +4761,8 @@ mod tests {
47214761

47224762
use codex_protocol::mcp::CallToolResult as McpCallToolResult;
47234763
use pretty_assertions::assert_eq;
4764+
use rmcp::model::JsonObject;
4765+
use rmcp::model::Tool;
47244766
use serde::Deserialize;
47254767
use serde_json::json;
47264768
use std::path::PathBuf;
@@ -4757,6 +4799,30 @@ mod tests {
47574799
}
47584800
}
47594801

4802+
fn make_mcp_tool(
4803+
server_name: &str,
4804+
tool_name: &str,
4805+
connector_id: Option<&str>,
4806+
connector_name: Option<&str>,
4807+
) -> ToolInfo {
4808+
ToolInfo {
4809+
server_name: server_name.to_string(),
4810+
tool_name: tool_name.to_string(),
4811+
tool: Tool {
4812+
name: tool_name.to_string().into(),
4813+
title: None,
4814+
description: Some(format!("Test tool: {tool_name}").into()),
4815+
input_schema: Arc::new(JsonObject::default()),
4816+
output_schema: None,
4817+
annotations: None,
4818+
icons: None,
4819+
meta: None,
4820+
},
4821+
connector_id: connector_id.map(str::to_string),
4822+
connector_name: connector_name.map(str::to_string),
4823+
}
4824+
}
4825+
47604826
#[tokio::test]
47614827
async fn get_base_instructions_no_user_content() {
47624828
let prompt_with_apply_patch_instructions =
@@ -4860,6 +4926,93 @@ mod tests {
48604926
assert_eq!(selected, Vec::new());
48614927
}
48624928

4929+
#[test]
4930+
fn search_tool_selection_keeps_codex_apps_tools_without_mentions() {
4931+
let selected_tool_names = vec![
4932+
"mcp__codex_apps__calendar_create_event".to_string(),
4933+
"mcp__rmcp__echo".to_string(),
4934+
];
4935+
let mcp_tools = HashMap::from([
4936+
(
4937+
"mcp__codex_apps__calendar_create_event".to_string(),
4938+
make_mcp_tool(
4939+
CODEX_APPS_MCP_SERVER_NAME,
4940+
"calendar_create_event",
4941+
Some("calendar"),
4942+
Some("Calendar"),
4943+
),
4944+
),
4945+
(
4946+
"mcp__rmcp__echo".to_string(),
4947+
make_mcp_tool("rmcp", "echo", None, None),
4948+
),
4949+
]);
4950+
4951+
let mut selected_mcp_tools =
4952+
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
4953+
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
4954+
let connectors = filter_connectors_for_input(
4955+
connectors,
4956+
&[user_message("run the selected tools")],
4957+
&[],
4958+
&HashMap::new(),
4959+
);
4960+
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
4961+
selected_mcp_tools.extend(apps_mcp_tools);
4962+
4963+
let mut tool_names: Vec<String> = selected_mcp_tools.into_keys().collect();
4964+
tool_names.sort();
4965+
assert_eq!(
4966+
tool_names,
4967+
vec![
4968+
"mcp__codex_apps__calendar_create_event".to_string(),
4969+
"mcp__rmcp__echo".to_string(),
4970+
]
4971+
);
4972+
}
4973+
4974+
#[test]
4975+
fn apps_mentions_add_codex_apps_tools_to_search_selected_set() {
4976+
let selected_tool_names = vec!["mcp__rmcp__echo".to_string()];
4977+
let mcp_tools = HashMap::from([
4978+
(
4979+
"mcp__codex_apps__calendar_create_event".to_string(),
4980+
make_mcp_tool(
4981+
CODEX_APPS_MCP_SERVER_NAME,
4982+
"calendar_create_event",
4983+
Some("calendar"),
4984+
Some("Calendar"),
4985+
),
4986+
),
4987+
(
4988+
"mcp__rmcp__echo".to_string(),
4989+
make_mcp_tool("rmcp", "echo", None, None),
4990+
),
4991+
]);
4992+
4993+
let mut selected_mcp_tools =
4994+
filter_mcp_tools_by_name(mcp_tools.clone(), &selected_tool_names);
4995+
let connectors = connectors::accessible_connectors_from_mcp_tools(&mcp_tools);
4996+
let connectors = filter_connectors_for_input(
4997+
connectors,
4998+
&[user_message("use $calendar and then echo the response")],
4999+
&[],
5000+
&HashMap::new(),
5001+
);
5002+
let apps_mcp_tools = filter_codex_apps_mcp_tools_only(mcp_tools, &connectors);
5003+
selected_mcp_tools.extend(apps_mcp_tools);
5004+
5005+
let mut tool_names: Vec<String> = selected_mcp_tools.into_keys().collect();
5006+
tool_names.sort();
5007+
assert_eq!(
5008+
tool_names,
5009+
vec![
5010+
"mcp__codex_apps__calendar_create_event".to_string(),
5011+
"mcp__rmcp__echo".to_string(),
5012+
]
5013+
);
5014+
}
5015+
48635016
#[tokio::test]
48645017
async fn reconstruct_history_matches_live_compactions() {
48655018
let (session, turn_context) = make_session_and_context().await;
@@ -5484,6 +5637,46 @@ mod tests {
54845637
)
54855638
}
54865639

5640+
pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration {
5641+
let codex_home = tempfile::tempdir().expect("create temp dir");
5642+
let config = build_test_config(codex_home.path()).await;
5643+
let config = Arc::new(config);
5644+
let model = ModelsManager::get_model_offline(config.model.as_deref());
5645+
let model_info = ModelsManager::construct_model_info_offline(model.as_str(), &config);
5646+
let reasoning_effort = config.model_reasoning_effort;
5647+
let collaboration_mode = CollaborationMode {
5648+
mode: ModeKind::Default,
5649+
settings: Settings {
5650+
model,
5651+
reasoning_effort,
5652+
developer_instructions: None,
5653+
},
5654+
};
5655+
5656+
SessionConfiguration {
5657+
provider: config.model_provider.clone(),
5658+
collaboration_mode,
5659+
model_reasoning_summary: config.model_reasoning_summary,
5660+
developer_instructions: config.developer_instructions.clone(),
5661+
user_instructions: config.user_instructions.clone(),
5662+
personality: config.personality,
5663+
base_instructions: config
5664+
.base_instructions
5665+
.clone()
5666+
.unwrap_or_else(|| model_info.get_model_instructions(config.personality)),
5667+
compact_prompt: config.compact_prompt.clone(),
5668+
approval_policy: config.approval_policy.clone(),
5669+
sandbox_policy: config.sandbox_policy.clone(),
5670+
windows_sandbox_level: WindowsSandboxLevel::from_config(&config),
5671+
cwd: config.cwd.clone(),
5672+
codex_home: config.codex_home.clone(),
5673+
thread_name: None,
5674+
original_config_do_not_use: Arc::clone(&config),
5675+
session_source: SessionSource::Exec,
5676+
dynamic_tools: Vec::new(),
5677+
}
5678+
}
5679+
54875680
pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
54885681
let (tx_event, _rx_event) = async_channel::unbounded();
54895682
let codex_home = tempfile::tempdir().expect("create temp dir");

0 commit comments

Comments
 (0)