Skip to content

Commit f255542

Browse files
authored
Simplify parallel (#4829)
Make tool processing return a future and then collect the futures. Handle cleanup on Drop.
1 parent 27f169b commit f255542

File tree

4 files changed

+58
-114
lines changed

4 files changed

+58
-114
lines changed

codex-rs/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

codex-rs/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ tokio = { workspace = true, features = [
6161
"rt-multi-thread",
6262
"signal",
6363
] }
64-
tokio-util = { workspace = true }
64+
tokio-util = { workspace = true, features = ["rt"] }
6565
toml = { workspace = true }
6666
toml_edit = { workspace = true }
6767
tracing = { workspace = true, features = ["log"] }

codex-rs/core/src/codex.rs

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ use codex_protocol::protocol::SessionSource;
2323
use codex_protocol::protocol::TaskStartedEvent;
2424
use codex_protocol::protocol::TurnAbortReason;
2525
use codex_protocol::protocol::TurnContextItem;
26+
use futures::future::BoxFuture;
2627
use futures::prelude::*;
28+
use futures::stream::FuturesOrdered;
2729
use mcp_types::CallToolResult;
2830
use serde_json;
2931
use serde_json::Value;
@@ -2101,39 +2103,33 @@ async fn try_run_turn(
21012103
sess.persist_rollout_items(&[rollout_item]).await;
21022104
let mut stream = turn_context.client.clone().stream(&prompt).await?;
21032105

2104-
let mut output = Vec::new();
2105-
let mut tool_runtime = ToolCallRuntime::new(
2106+
let tool_runtime = ToolCallRuntime::new(
21062107
Arc::clone(&router),
21072108
Arc::clone(&sess),
21082109
Arc::clone(&turn_context),
21092110
Arc::clone(&turn_diff_tracker),
21102111
sub_id.to_string(),
21112112
);
2113+
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
2114+
FuturesOrdered::new();
21122115

21132116
loop {
21142117
// Poll the next item from the model stream. We must inspect *both* Ok and Err
21152118
// cases so that transient stream failures (e.g., dropped SSE connection before
21162119
// `response.completed`) bubble up and trigger the caller's retry logic.
21172120
let event = stream.next().await;
21182121
let event = match event {
2119-
Some(event) => event,
2122+
Some(res) => res?,
21202123
None => {
2121-
tool_runtime.abort_all();
21222124
return Err(CodexErr::Stream(
21232125
"stream closed before response.completed".into(),
21242126
None,
21252127
));
21262128
}
21272129
};
21282130

2129-
let event = match event {
2130-
Ok(ev) => ev,
2131-
Err(e) => {
2132-
tool_runtime.abort_all();
2133-
// Propagate the underlying stream error to the caller (run_turn), which
2134-
// will apply the configured `stream_max_retries` policy.
2135-
return Err(e);
2136-
}
2131+
let add_completed = &mut |response_item: ProcessedResponseItem| {
2132+
output.push_back(future::ready(Ok(response_item)).boxed());
21372133
};
21382134

21392135
match event {
@@ -2143,14 +2139,18 @@ async fn try_run_turn(
21432139
Ok(Some(call)) => {
21442140
let payload_preview = call.payload.log_payload().into_owned();
21452141
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
2146-
let index = output.len();
2147-
output.push(ProcessedResponseItem {
2148-
item,
2149-
response: None,
2150-
});
2151-
tool_runtime
2152-
.handle_tool_call(call, index, output.as_mut_slice())
2153-
.await?;
2142+
2143+
let response = tool_runtime.handle_tool_call(call);
2144+
2145+
output.push_back(
2146+
async move {
2147+
Ok(ProcessedResponseItem {
2148+
item,
2149+
response: Some(response.await?),
2150+
})
2151+
}
2152+
.boxed(),
2153+
);
21542154
}
21552155
Ok(None) => {
21562156
let response = handle_non_tool_response_item(
@@ -2160,7 +2160,7 @@ async fn try_run_turn(
21602160
item.clone(),
21612161
)
21622162
.await?;
2163-
output.push(ProcessedResponseItem { item, response });
2163+
add_completed(ProcessedResponseItem { item, response });
21642164
}
21652165
Err(FunctionCallError::MissingLocalShellCallId) => {
21662166
let msg = "LocalShellCall without call_id or id";
@@ -2177,7 +2177,7 @@ async fn try_run_turn(
21772177
success: None,
21782178
},
21792179
};
2180-
output.push(ProcessedResponseItem {
2180+
add_completed(ProcessedResponseItem {
21812181
item,
21822182
response: Some(response),
21832183
});
@@ -2190,7 +2190,7 @@ async fn try_run_turn(
21902190
success: None,
21912191
},
21922192
};
2193-
output.push(ProcessedResponseItem {
2193+
add_completed(ProcessedResponseItem {
21942194
item,
21952195
response: Some(response),
21962196
});
@@ -2221,7 +2221,7 @@ async fn try_run_turn(
22212221
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
22222222
.await;
22232223

2224-
tool_runtime.resolve_pending(output.as_mut_slice()).await?;
2224+
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
22252225

22262226
let unified_diff = {
22272227
let mut tracker = turn_diff_tracker.lock().await;
@@ -2237,7 +2237,7 @@ async fn try_run_turn(
22372237
}
22382238

22392239
let result = TurnRunResult {
2240-
processed_items: output,
2240+
processed_items,
22412241
total_token_usage: token_usage.clone(),
22422242
};
22432243

Lines changed: 31 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::sync::Arc;
22

3-
use tokio::task::JoinHandle;
3+
use tokio::sync::RwLock;
4+
use tokio_util::either::Either;
5+
use tokio_util::task::AbortOnDropHandle;
46

57
use crate::codex::Session;
68
use crate::codex::TurnContext;
@@ -11,20 +13,13 @@ use crate::tools::router::ToolCall;
1113
use crate::tools::router::ToolRouter;
1214
use codex_protocol::models::ResponseInputItem;
1315

14-
use crate::codex::ProcessedResponseItem;
15-
16-
struct PendingToolCall {
17-
index: usize,
18-
handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
19-
}
20-
2116
pub(crate) struct ToolCallRuntime {
2217
router: Arc<ToolRouter>,
2318
session: Arc<Session>,
2419
turn_context: Arc<TurnContext>,
2520
tracker: SharedTurnDiffTracker,
2621
sub_id: String,
27-
pending_calls: Vec<PendingToolCall>,
22+
parallel_execution: Arc<RwLock<()>>,
2823
}
2924

3025
impl ToolCallRuntime {
@@ -41,97 +36,45 @@ impl ToolCallRuntime {
4136
turn_context,
4237
tracker,
4338
sub_id,
44-
pending_calls: Vec::new(),
39+
parallel_execution: Arc::new(RwLock::new(())),
4540
}
4641
}
4742

48-
pub(crate) async fn handle_tool_call(
49-
&mut self,
43+
pub(crate) fn handle_tool_call(
44+
&self,
5045
call: ToolCall,
51-
output_index: usize,
52-
output: &mut [ProcessedResponseItem],
53-
) -> Result<(), CodexErr> {
46+
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
5447
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
55-
if supports_parallel {
56-
self.spawn_parallel(call, output_index);
57-
} else {
58-
self.resolve_pending(output).await?;
59-
let response = self.dispatch_serial(call).await?;
60-
let slot = output.get_mut(output_index).ok_or_else(|| {
61-
CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
62-
})?;
63-
slot.response = Some(response);
64-
}
65-
66-
Ok(())
67-
}
68-
69-
pub(crate) fn abort_all(&mut self) {
70-
while let Some(pending) = self.pending_calls.pop() {
71-
pending.handle.abort();
72-
}
73-
}
74-
75-
pub(crate) async fn resolve_pending(
76-
&mut self,
77-
output: &mut [ProcessedResponseItem],
78-
) -> Result<(), CodexErr> {
79-
while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
80-
match handle.await {
81-
Ok(Ok(response)) => {
82-
if let Some(slot) = output.get_mut(index) {
83-
slot.response = Some(response);
84-
}
85-
}
86-
Ok(Err(FunctionCallError::Fatal(message))) => {
87-
self.abort_all();
88-
return Err(CodexErr::Fatal(message));
89-
}
90-
Ok(Err(other)) => {
91-
self.abort_all();
92-
return Err(CodexErr::Fatal(other.to_string()));
93-
}
94-
Err(join_err) => {
95-
self.abort_all();
96-
return Err(CodexErr::Fatal(format!(
97-
"tool task failed to join: {join_err}"
98-
)));
99-
}
100-
}
101-
}
102-
103-
Ok(())
104-
}
10548

106-
fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
10749
let router = Arc::clone(&self.router);
10850
let session = Arc::clone(&self.session);
10951
let turn = Arc::clone(&self.turn_context);
11052
let tracker = Arc::clone(&self.tracker);
11153
let sub_id = self.sub_id.clone();
112-
let handle = tokio::spawn(async move {
113-
router
114-
.dispatch_tool_call(session, turn, tracker, sub_id, call)
115-
.await
116-
});
117-
self.pending_calls.push(PendingToolCall { index, handle });
118-
}
54+
let lock = Arc::clone(&self.parallel_execution);
55+
56+
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
57+
AbortOnDropHandle::new(tokio::spawn(async move {
58+
let _guard = if supports_parallel {
59+
Either::Left(lock.read().await)
60+
} else {
61+
Either::Right(lock.write().await)
62+
};
63+
64+
router
65+
.dispatch_tool_call(session, turn, tracker, sub_id, call)
66+
.await
67+
}));
11968

120-
async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
121-
match self
122-
.router
123-
.dispatch_tool_call(
124-
Arc::clone(&self.session),
125-
Arc::clone(&self.turn_context),
126-
Arc::clone(&self.tracker),
127-
self.sub_id.clone(),
128-
call,
129-
)
130-
.await
131-
{
132-
Ok(response) => Ok(response),
133-
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
134-
Err(other) => Err(CodexErr::Fatal(other.to_string())),
69+
async move {
70+
match handle.await {
71+
Ok(Ok(response)) => Ok(response),
72+
Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)),
73+
Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())),
74+
Err(err) => Err(CodexErr::Fatal(format!(
75+
"tool task failed to receive: {err:?}"
76+
))),
77+
}
13578
}
13679
}
13780
}

0 commit comments

Comments
 (0)