Skip to content

Commit ec0c5e5

Browse files
agu-zConradIrwin
andcommitted
Do not return handler task
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
1 parent 09263a9 commit ec0c5e5

File tree

2 files changed

+65
-73
lines changed

2 files changed

+65
-73
lines changed

rust/acp.rs

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use serde_json::value::RawValue;
2121
use std::{
2222
collections::HashMap,
2323
fmt::Display,
24+
rc::Rc,
2425
sync::{
2526
Arc,
2627
atomic::{AtomicI32, Ordering::SeqCst},
@@ -40,14 +41,10 @@ impl AgentConnection {
4041
handler: H,
4142
outgoing_bytes: impl Unpin + AsyncWrite,
4243
incoming_bytes: impl Unpin + AsyncRead,
43-
spawn: impl Fn(LocalBoxFuture<'static, ()>),
44-
) -> (
45-
Self,
46-
impl Future<Output = ()>,
47-
impl Future<Output = Result<()>>,
48-
) {
44+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
45+
) -> (Self, impl Future<Output = Result<()>>) {
4946
let handler = Arc::new(handler);
50-
let (connection, handler_task, io_task) = Connection::new(
47+
let (connection, io_task) = Connection::new(
5148
Box::new(move |request| {
5249
let handler = handler.clone();
5350
async move { handler.call(request).await }.boxed_local()
@@ -56,7 +53,7 @@ impl AgentConnection {
5653
incoming_bytes,
5754
spawn,
5855
);
59-
(Self(connection), handler_task, io_task)
56+
(Self(connection), io_task)
6057
}
6158

6259
/// Send a request to the agent and wait for a response.
@@ -83,14 +80,10 @@ impl ClientConnection {
8380
handler: H,
8481
outgoing_bytes: impl Unpin + AsyncWrite,
8582
incoming_bytes: impl Unpin + AsyncRead,
86-
spawn: impl Fn(LocalBoxFuture<'static, ()>),
87-
) -> (
88-
Self,
89-
impl Future<Output = ()>,
90-
impl Future<Output = Result<()>>,
91-
) {
83+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
84+
) -> (Self, impl Future<Output = Result<()>>) {
9285
let handler = Arc::new(handler);
93-
let (connection, handler_task, io_task) = Connection::new(
86+
let (connection, io_task) = Connection::new(
9487
Box::new(move |request| {
9588
let handler = handler.clone();
9689
async move { handler.call(request).await }.boxed_local()
@@ -99,7 +92,7 @@ impl ClientConnection {
9992
incoming_bytes,
10093
spawn,
10194
);
102-
(Self(connection), handler_task, io_task)
95+
(Self(connection), io_task)
10396
}
10497

10598
pub fn request<R: ClientRequest>(
@@ -193,28 +186,24 @@ where
193186
request_handler: Box<dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>>,
194187
outgoing_bytes: impl Unpin + AsyncWrite,
195188
incoming_bytes: impl Unpin + AsyncRead,
196-
spawn: impl Fn(LocalBoxFuture<'static, ()>),
197-
) -> (
198-
Self,
199-
impl Future<Output = ()>,
200-
impl Future<Output = Result<()>>,
201-
) {
189+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
190+
) -> (Self, impl Future<Output = Result<()>>) {
202191
let (outgoing_tx, outgoing_rx) = mpsc::unbounded();
203192
let (incoming_tx, incoming_rx) = mpsc::unbounded();
204193
let this = Self {
205194
response_senders: ResponseSenders::default(),
206195
outgoing_tx: outgoing_tx.clone(),
207196
next_id: AtomicI32::new(0),
208197
};
209-
let handler_task = Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
198+
Self::handle_incoming(outgoing_tx, incoming_rx, request_handler, spawn);
210199
let io_task = Self::handle_io(
211200
outgoing_rx,
212201
incoming_tx,
213202
this.response_senders.clone(),
214203
outgoing_bytes,
215204
incoming_bytes,
216205
);
217-
(this, handler_task, io_task)
206+
(this, io_task)
218207
}
219208

220209
fn request(
@@ -308,41 +297,48 @@ where
308297
Ok(())
309298
}
310299

311-
async fn handle_incoming(
300+
fn handle_incoming(
312301
outgoing_tx: UnboundedSender<OutgoingMessage<Out, In::Response>>,
313302
mut incoming_rx: UnboundedReceiver<(i32, In)>,
314303
incoming_handler: Box<
315304
dyn 'static + Fn(In) -> LocalBoxFuture<'static, Result<In::Response>>,
316305
>,
317-
spawn: impl Fn(LocalBoxFuture<'static, ()>),
306+
spawn: impl Fn(LocalBoxFuture<'static, ()>) + 'static,
318307
) {
319-
while let Some((id, params)) = incoming_rx.next().await {
320-
let result = incoming_handler(params);
321-
let outgoing_tx = outgoing_tx.clone();
322-
spawn(
323-
async move {
324-
let result = result.await;
325-
match result {
326-
Ok(result) => {
327-
outgoing_tx
328-
.unbounded_send(OutgoingMessage::OkResponse { id, result })
329-
.ok();
330-
}
331-
Err(error) => {
332-
outgoing_tx
333-
.unbounded_send(OutgoingMessage::ErrorResponse {
334-
id,
335-
error: Error {
336-
code: -32603,
337-
message: error.to_string(),
338-
},
339-
})
340-
.ok();
308+
let spawn = Rc::new(spawn);
309+
let spawn2 = spawn.clone();
310+
spawn(
311+
async move {
312+
while let Some((id, params)) = incoming_rx.next().await {
313+
let result = incoming_handler(params);
314+
let outgoing_tx = outgoing_tx.clone();
315+
spawn2(
316+
async move {
317+
let result = result.await;
318+
match result {
319+
Ok(result) => {
320+
outgoing_tx
321+
.unbounded_send(OutgoingMessage::OkResponse { id, result })
322+
.ok();
323+
}
324+
Err(error) => {
325+
outgoing_tx
326+
.unbounded_send(OutgoingMessage::ErrorResponse {
327+
id,
328+
error: Error {
329+
code: -32603,
330+
message: error.to_string(),
331+
},
332+
})
333+
.ok();
334+
}
335+
}
341336
}
342-
}
337+
.boxed_local(),
338+
)
343339
}
344-
.boxed_local(),
345-
)
346-
}
340+
}
341+
.boxed_local(),
342+
)
347343
}
348344
}

rust/acp_tests.rs

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,27 +66,23 @@ async fn test_client_agent_communication() {
6666
let (client_to_agent_tx, client_to_agent_rx) = async_pipe::pipe();
6767
let (agent_to_client_tx, agent_to_client_rx) = async_pipe::pipe();
6868

69-
let (client_connection, client_handle_task, client_io_task) =
70-
AgentConnection::connect_to_agent(
71-
client,
72-
client_to_agent_tx,
73-
agent_to_client_rx,
74-
|fut| {
75-
tokio::task::spawn_local(fut);
76-
},
77-
);
78-
let (agent_connection, agent_handle_task, agent_io_task) =
79-
ClientConnection::connect_to_client(
80-
agent,
81-
agent_to_client_tx,
82-
client_to_agent_rx,
83-
|fut| {
84-
tokio::task::spawn_local(fut);
85-
},
86-
);
87-
88-
let _task = tokio::task::spawn_local(client_handle_task);
89-
let _task = tokio::task::spawn_local(agent_handle_task);
69+
let (client_connection, client_io_task) = AgentConnection::connect_to_agent(
70+
client,
71+
client_to_agent_tx,
72+
agent_to_client_rx,
73+
|fut| {
74+
tokio::task::spawn_local(fut);
75+
},
76+
);
77+
let (agent_connection, agent_io_task) = ClientConnection::connect_to_client(
78+
agent,
79+
agent_to_client_tx,
80+
client_to_agent_rx,
81+
|fut| {
82+
tokio::task::spawn_local(fut);
83+
},
84+
);
85+
9086
let _task = tokio::spawn(client_io_task);
9187
let _task = tokio::spawn(agent_io_task);
9288

0 commit comments

Comments
 (0)