Skip to content

Commit bc395cb

Browse files
committed
async: graceful shutdown and server restart supported
- Support graceful shutdown. - Support restarting the server via stop_listen() followed by start(). Hot upgrade needs this restart feature: first call stop_listen() to stop accepting new connections, then call disconnect() to wait for all existing requests to finish. If stop_listen() or disconnect() fails, simply call start() again to roll back. Signed-off-by: Tim Zhang <[email protected]>
1 parent d859795 commit bc395cb

File tree

1 file changed

+134
-37
lines changed

1 file changed

+134
-37
lines changed

src/asynchronous/server.rs

Lines changed: 134 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// SPDX-License-Identifier: Apache-2.0
44
//
55

6+
use nix::unistd;
67
use protobuf::{CodedInputStream, Message};
78
use std::collections::HashMap;
89
use std::os::unix::io::RawFd;
@@ -25,6 +26,7 @@ use tokio::{
2526
prelude::*,
2627
stream::Stream,
2728
sync::mpsc::{channel, Receiver, Sender},
29+
sync::watch,
2830
};
2931
use tokio_vsock::VsockListener;
3032

@@ -33,6 +35,9 @@ pub struct Server {
3335
listeners: Vec<RawFd>,
3436
methods: Arc<HashMap<String, Box<dyn MethodHandler + Send + Sync>>>,
3537
domain: Option<Domain>,
38+
disconnect_tx: Option<watch::Sender<i32>>,
39+
all_conn_done_rx: Option<Receiver<i32>>,
40+
stop_listen_tx: Option<Sender<Sender<RawFd>>>,
3641
}
3742

3843
impl Default for Server {
@@ -41,6 +46,9 @@ impl Default for Server {
4146
listeners: Vec::with_capacity(1),
4247
methods: Arc::new(HashMap::new()),
4348
domain: None,
49+
disconnect_tx: None,
50+
all_conn_done_rx: None,
51+
stop_listen_tx: None,
4452
}
4553
}
4654
}
@@ -60,6 +68,7 @@ impl Server {
6068
let (fd, domain) = common::do_bind(host)?;
6169
self.domain = Some(domain);
6270

71+
common::do_listen(fd)?;
6372
self.listeners.push(fd);
6473
Ok(self)
6574
}
@@ -79,30 +88,27 @@ impl Server {
7988
self
8089
}
8190

82-
fn listen(&self) -> Result<RawFd> {
91+
fn get_listenfd(&self) -> Result<RawFd> {
8392
if self.listeners.is_empty() {
8493
return Err(Error::Others("ttrpc-rust not bind".to_string()));
8594
}
8695

87-
let listenfd = self.listeners[0];
88-
common::do_listen(listenfd)?;
89-
96+
let listenfd = self.listeners[self.listeners.len() - 1];
9097
Ok(listenfd)
9198
}
9299

93-
pub async fn start(&self) -> Result<()> {
94-
let listenfd = self.listen()?;
100+
pub async fn start(&mut self) -> Result<()> {
101+
let listenfd = self.get_listenfd()?;
95102

96103
match self.domain.as_ref().unwrap() {
97104
Domain::Unix => {
98105
let sys_unix_listener;
99106
unsafe {
100107
sys_unix_listener = SysUnixListener::from_raw_fd(listenfd);
101108
}
102-
let mut unix_listener = UnixListener::from_std(sys_unix_listener).unwrap();
103-
let incoming = unix_listener.incoming();
109+
let unix_listener = UnixListener::from_std(sys_unix_listener).unwrap();
104110

105-
self.do_start(listenfd, incoming).await
111+
self.do_start(listenfd, unix_listener).await
106112
}
107113
Domain::Vsock => {
108114
let incoming;
@@ -115,52 +121,143 @@ impl Server {
115121
}
116122
}
117123

118-
pub async fn do_start<I, S>(&self, listenfd: RawFd, mut incoming: I) -> Result<()>
124+
pub async fn do_start<I, S>(&mut self, listenfd: RawFd, mut incoming: I) -> Result<()>
119125
where
120-
I: Stream<Item = std::io::Result<S>> + Unpin,
126+
I: Stream<Item = std::io::Result<S>> + Unpin + Send + 'static + AsRawFd,
121127
S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static,
122128
{
123-
while let Some(result) = incoming.next().await {
124-
match result {
125-
Ok(stream) => {
126-
common::set_fd_close_exec(stream.as_raw_fd())?;
127-
let methods = self.methods.clone();
128-
tokio::spawn(async move {
129-
let (mut reader, mut writer) = split(stream);
130-
let (tx, mut rx): (Sender<Vec<u8>>, Receiver<Vec<u8>>) = channel(100);
131-
132-
tokio::spawn(async move {
133-
while let Some(buf) = rx.recv().await {
134-
if let Err(e) = writer.write_all(&buf).await {
135-
error!("write_message got error: {:?}", e);
136-
}
137-
}
138-
});
129+
let methods = self.methods.clone();
130+
let (disconnect_tx, close_conn_rx) = watch::channel(0);
131+
self.disconnect_tx = Some(disconnect_tx);
139132

140-
loop {
141-
let tx = tx.clone();
133+
let (conn_done_tx, all_conn_done_rx) = channel::<i32>(1);
134+
135+
self.all_conn_done_rx = Some(all_conn_done_rx);
136+
let (stop_listen_tx, mut stop_listen_rx) = channel(1);
137+
self.stop_listen_tx = Some(stop_listen_tx);
138+
139+
tokio::spawn(async move {
140+
loop {
141+
tokio::select! {
142+
conn = incoming.next() => {
143+
if let Some(conn) = conn {
144+
// Accept a new connection
142145
let methods = methods.clone();
146+
match conn {
147+
Ok(stream) => {
148+
let fd = stream.as_raw_fd();
149+
if let Err(e) = common::set_fd_close_exec(fd) {
150+
error!("{:?}", e);
151+
continue;
152+
}
153+
154+
let mut close_conn_rx = close_conn_rx.clone();
143155

144-
match receive(&mut reader).await {
145-
Ok(message) => {
156+
let (req_done_tx, mut all_req_done_rx) = channel::<i32>(1);
157+
let conn_done_tx2 = conn_done_tx.clone();
158+
159+
// The connection handler
146160
tokio::spawn(async move {
147-
handle_request(tx, listenfd, methods, message).await;
161+
let (mut reader, mut writer) = split(stream);
162+
let (tx, mut rx): (Sender<Vec<u8>>, Receiver<Vec<u8>>) = channel(100);
163+
164+
tokio::spawn(async move {
165+
while let Some(buf) = rx.recv().await {
166+
if let Err(e) = writer.write_all(&buf).await {
167+
error!("write_message got error: {:?}", e);
168+
}
169+
}
170+
});
171+
172+
loop {
173+
let tx = tx.clone();
174+
let methods = methods.clone();
175+
let req_done_tx2 = req_done_tx.clone();
176+
177+
tokio::select! {
178+
resp = receive(&mut reader) => {
179+
match resp {
180+
Ok(message) => {
181+
tokio::spawn(async move {
182+
handle_request(tx, listenfd, methods, message).await;
183+
drop(req_done_tx2);
184+
});
185+
}
186+
Err(e) => {
187+
trace!("error {:?}", e);
188+
break;
189+
}
190+
}
191+
}
192+
v = close_conn_rx.recv() => {
193+
// 0 is the init value of this watch, not a valid signal
194+
// is_none means the tx was dropped.
195+
if v.is_none() || v.unwrap() != 0 {
196+
info!("Stop accepting new connections.");
197+
break;
198+
}
199+
}
200+
}
201+
}
202+
203+
drop(req_done_tx);
204+
all_req_done_rx.recv().await;
205+
drop(conn_done_tx2);
148206
});
149207
}
150208
Err(e) => {
151-
trace!("error {:?}", e);
152-
break;
209+
error!("{:?}", e)
153210
}
154211
}
212+
213+
} else {
214+
break;
215+
}
216+
}
217+
fd_tx = stop_listen_rx.recv() => {
218+
if let Some(mut fd_tx) = fd_tx {
219+
let dup_fd = unistd::dup(incoming.as_raw_fd()).unwrap();
220+
common::set_fd_close_exec(dup_fd).unwrap();
221+
drop(incoming);
222+
223+
fd_tx.send(dup_fd).await.unwrap();
224+
break;
155225
}
156-
});
226+
}
157227
}
158-
Err(e) => error!("{:?}", e),
159228
}
160-
}
229+
drop(conn_done_tx);
230+
});
231+
Ok(())
232+
}
233+
234+
pub async fn shutdown(&mut self) -> Result<()> {
235+
self.stop_listen().await;
236+
self.disconnect().await;
161237

162238
Ok(())
163239
}
240+
241+
pub async fn disconnect(&mut self) {
242+
if let Some(tx) = self.disconnect_tx.take() {
243+
tx.broadcast(1).ok();
244+
}
245+
246+
if let Some(mut rx) = self.all_conn_done_rx.take() {
247+
rx.recv().await;
248+
}
249+
}
250+
251+
pub async fn stop_listen(&mut self) {
252+
if let Some(mut tx) = self.stop_listen_tx.take() {
253+
let (fd_tx, mut fd_rx) = channel(1);
254+
tx.send(fd_tx).await.unwrap();
255+
256+
let fd = fd_rx.recv().await.unwrap();
257+
self.listeners.clear();
258+
self.listeners.push(fd);
259+
}
260+
}
164261
}
165262

166263
async fn handle_request(

0 commit comments

Comments
 (0)