Skip to content

Commit 2ca9f98

Browse files
author
rusty
committed
Add join handles to wait tasks
1 parent 322fb2a commit 2ca9f98

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

src/listener/tcp_listener.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use async_std::prelude::*;
1010
use async_std::{io, task};
1111

1212
use futures_util::future::Either;
13+
use futures_util::stream::FuturesUnordered;
1314

1415
/// This represents a tide [Listener](crate::listener::Listener) that
1516
/// wraps an [async_std::net::TcpListener]. It is implemented as an
@@ -24,6 +25,7 @@ pub struct TcpListener<State> {
2425
listener: Option<net::TcpListener>,
2526
server: Option<Server<State>>,
2627
info: Option<ListenInfo>,
28+
join_handles: Vec<task::JoinHandle<()>>,
2729
}
2830

2931
impl<State> TcpListener<State> {
@@ -33,6 +35,7 @@ impl<State> TcpListener<State> {
3335
listener: None,
3436
server: None,
3537
info: None,
38+
join_handles: Vec::new(),
3639
}
3740
}
3841

@@ -42,11 +45,15 @@ impl<State> TcpListener<State> {
4245
listener: Some(tcp_listener.into()),
4346
server: None,
4447
info: None,
48+
join_handles: Vec::new(),
4549
}
4650
}
4751
}
4852

49-
fn handle_tcp<State: Clone + Send + Sync + 'static>(app: Server<State>, stream: TcpStream) {
53+
fn handle_tcp<State: Clone + Send + Sync + 'static>(
54+
app: Server<State>,
55+
stream: TcpStream,
56+
) -> task::JoinHandle<()> {
5057
task::spawn(async move {
5158
let local_addr = stream.local_addr().ok();
5259
let peer_addr = stream.peer_addr().ok();
@@ -60,7 +67,7 @@ fn handle_tcp<State: Clone + Send + Sync + 'static>(app: Server<State>, stream:
6067
if let Err(error) = fut.await {
6168
log::error!("async-h1 error", { error: error.to_string() });
6269
}
63-
});
70+
})
6471
}
6572

6673
#[async_trait::async_trait]
@@ -118,10 +125,19 @@ where
118125
}
119126

120127
Ok(stream) => {
121-
handle_tcp(server.clone(), stream);
128+
let handle = handle_tcp(server.clone(), stream);
129+
self.join_handles.push(handle);
122130
}
123131
};
124132
}
133+
134+
let join_handles = std::mem::take(&mut self.join_handles);
135+
join_handles
136+
.into_iter()
137+
.collect::<FuturesUnordered<task::JoinHandle<()>>>()
138+
.collect::<()>()
139+
.await;
140+
125141
Ok(())
126142
}
127143

src/listener/unix_listener.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use async_std::prelude::*;
1111
use async_std::{io, task};
1212

1313
use futures_util::future::Either;
14+
use futures_util::stream::FuturesUnordered;
1415

1516
/// This represents a tide [Listener](crate::listener::Listener) that
1617
/// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an
@@ -25,6 +26,7 @@ pub struct UnixListener<State> {
2526
listener: Option<net::UnixListener>,
2627
server: Option<Server<State>>,
2728
info: Option<ListenInfo>,
29+
join_handles: Vec<task::JoinHandle<()>>,
2830
}
2931

3032
impl<State> UnixListener<State> {
@@ -34,6 +36,7 @@ impl<State> UnixListener<State> {
3436
listener: None,
3537
server: None,
3638
info: None,
39+
join_handles: Vec::new(),
3740
}
3841
}
3942

@@ -43,11 +46,15 @@ impl<State> UnixListener<State> {
4346
listener: Some(unix_listener.into()),
4447
server: None,
4548
info: None,
49+
join_handles: Vec::new(),
4650
}
4751
}
4852
}
4953

50-
fn handle_unix<State: Clone + Send + Sync + 'static>(app: Server<State>, stream: UnixStream) {
54+
fn handle_unix<State: Clone + Send + Sync + 'static>(
55+
app: Server<State>,
56+
stream: UnixStream,
57+
) -> task::JoinHandle<()> {
5158
task::spawn(async move {
5259
let local_addr = unix_socket_addr_to_string(stream.local_addr());
5360
let peer_addr = unix_socket_addr_to_string(stream.peer_addr());
@@ -61,7 +68,7 @@ fn handle_unix<State: Clone + Send + Sync + 'static>(app: Server<State>, stream:
6168
if let Err(error) = fut.await {
6269
log::error!("async-h1 error", { error: error.to_string() });
6370
}
64-
});
71+
})
6572
}
6673

6774
#[async_trait::async_trait]
@@ -116,10 +123,19 @@ where
116123
}
117124

118125
Ok(stream) => {
119-
handle_unix(server.clone(), stream);
126+
let handle = handle_unix(server.clone(), stream);
127+
self.join_handles.push(handle);
120128
}
121129
};
122130
}
131+
132+
let join_handles = std::mem::take(&mut self.join_handles);
133+
join_handles
134+
.into_iter()
135+
.collect::<FuturesUnordered<task::JoinHandle<()>>>()
136+
.collect::<()>()
137+
.await;
138+
123139
Ok(())
124140
}
125141

0 commit comments

Comments
 (0)