Skip to content

Commit 9659583

Browse files
authored
feat: add close tool implementation for collab (#9090)
Pretty straight forward. A known follow-up will be to drop it from the AgentControl
1 parent 623707a commit 9659583

File tree

2 files changed

+84
-16
lines changed

2 files changed

+84
-16
lines changed

codex-rs/core/src/agent/control.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ impl AgentControl {
6868
.await
6969
}
7070

71+
/// Submit a shutdown request to an existing agent thread.
72+
pub(crate) async fn shutdown_agent(&self, agent_id: ThreadId) -> CodexResult<String> {
73+
let state = self.upgrade()?;
74+
state.send_op(agent_id, Op::Shutdown {}).await
75+
}
76+
7177
#[allow(dead_code)] // Will be used for collab tools.
7278
/// Fetch the last known status for `agent_id`, returning `NotFound` when unavailable.
7379
pub(crate) async fn get_status(&self, agent_id: ThreadId) -> AgentStatus {

codex-rs/core/src/tools/handlers/collab.rs

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ impl ToolHandler for CollabHandler {
5656
"spawn_agent" => spawn::handle(session, turn, arguments).await,
5757
"send_input" => send_input::handle(session, arguments).await,
5858
"wait" => wait::handle(session, arguments).await,
59-
"close_agent" => close_agent::handle(arguments).await,
59+
"close_agent" => close_agent::handle(session, arguments).await,
6060
other => Err(FunctionCallError::RespondToModel(format!(
6161
"unsupported collab tool {other}"
6262
))),
@@ -240,14 +240,56 @@ mod wait {
240240

241241
pub mod close_agent {
242242
use super::*;
243+
use crate::codex::Session;
244+
use std::sync::Arc;
245+
246+
#[derive(Debug, Deserialize, Serialize)]
247+
pub(super) struct CloseAgentResult {
248+
pub(super) status: AgentStatus,
249+
}
243250

244-
pub async fn handle(arguments: String) -> Result<ToolOutput, FunctionCallError> {
251+
pub async fn handle(
252+
session: Arc<Session>,
253+
arguments: String,
254+
) -> Result<ToolOutput, FunctionCallError> {
245255
let args: CloseAgentArgs = parse_arguments(&arguments)?;
246-
let _agent_id = agent_id(&args.id)?;
247-
// TODO(jif): implement agent shutdown and return the final status.
248-
Err(FunctionCallError::Fatal(
249-
"close_agent not implemented".to_string(),
250-
))
256+
let agent_id = agent_id(&args.id)?;
257+
let mut status_rx = session
258+
.services
259+
.agent_control
260+
.subscribe_status(agent_id)
261+
.await
262+
.map_err(|err| match err {
263+
CodexErr::ThreadNotFound(id) => {
264+
FunctionCallError::RespondToModel(format!("agent with id {id} not found"))
265+
}
266+
err => FunctionCallError::Fatal(err.to_string()),
267+
})?;
268+
let status = status_rx.borrow_and_update().clone();
269+
270+
if !matches!(status, AgentStatus::Shutdown) {
271+
let _ = session
272+
.services
273+
.agent_control
274+
.shutdown_agent(agent_id)
275+
.await
276+
.map_err(|err| match err {
277+
CodexErr::ThreadNotFound(id) => {
278+
FunctionCallError::RespondToModel(format!("agent with id {id} not found"))
279+
}
280+
err => FunctionCallError::Fatal(err.to_string()),
281+
})?;
282+
}
283+
284+
let content = serde_json::to_string(&CloseAgentResult { status }).map_err(|err| {
285+
FunctionCallError::Fatal(format!("failed to serialize close_agent result: {err}"))
286+
})?;
287+
288+
Ok(ToolOutput::Function {
289+
content,
290+
success: Some(true),
291+
content_items: None,
292+
})
251293
}
252294
}
253295

@@ -587,21 +629,41 @@ mod tests {
587629
}
588630

589631
#[tokio::test]
590-
async fn close_agent_reports_not_implemented() {
591-
let (session, turn) = make_session_and_context().await;
632+
async fn close_agent_submits_shutdown_and_returns_status() {
633+
let (mut session, turn) = make_session_and_context().await;
634+
let manager = thread_manager();
635+
session.services.agent_control = manager.agent_control();
636+
let config = turn.client.config().as_ref().clone();
637+
let thread = manager.start_thread(config).await.expect("start thread");
638+
let agent_id = thread.thread_id;
639+
let status_before = manager.agent_control().get_status(agent_id).await;
640+
592641
let invocation = invocation(
593642
Arc::new(session),
594643
Arc::new(turn),
595644
"close_agent",
596-
function_payload(json!({"id": ThreadId::new().to_string()})),
645+
function_payload(json!({"id": agent_id.to_string()})),
597646
);
598-
let Err(err) = CollabHandler.handle(invocation).await else {
599-
panic!("close_agent should fail");
647+
let output = CollabHandler
648+
.handle(invocation)
649+
.await
650+
.expect("close_agent should succeed");
651+
let ToolOutput::Function {
652+
content, success, ..
653+
} = output
654+
else {
655+
panic!("expected function output");
600656
};
601-
assert_eq!(
602-
err,
603-
FunctionCallError::Fatal("close_agent not implemented".to_string())
604-
);
657+
let result: close_agent::CloseAgentResult =
658+
serde_json::from_str(&content).expect("close_agent result should be json");
659+
assert_eq!(result.status, status_before);
660+
assert_eq!(success, Some(true));
661+
662+
let ops = manager.captured_ops();
663+
let submitted_shutdown = ops
664+
.iter()
665+
.any(|(id, op)| *id == agent_id && matches!(op, Op::Shutdown));
666+
assert_eq!(submitted_shutdown, true);
605667
}
606668

607669
#[tokio::test]

0 commit comments

Comments
 (0)