Skip to content

Commit 041bf30

Browse files
committed
async: add shutdown module.
It is used for server-side graceful shutdown. Signed-off-by: wanglei01 <[email protected]>
1 parent f4b3d90 commit 041bf30

File tree

2 files changed

+312
-0
lines changed

2 files changed

+312
-0
lines changed

src/asynchronous/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod stream;
1212
#[doc(hidden)]
1313
mod utils;
1414
mod unix_incoming;
15+
pub mod shutdown;
1516

1617
#[doc(inline)]
1718
pub use crate::r#async::client::Client;

src/asynchronous/shutdown.rs

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
// Copyright 2022 Alibaba Cloud. All rights reserved.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
6+
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7+
use std::sync::Arc;
8+
9+
use tokio::sync::Notify;
10+
use tokio::time::{error::Elapsed, timeout, Duration};
11+
12+
#[derive(Debug)]
13+
struct Shared {
14+
shutdown: AtomicBool,
15+
notify_shutdown: Notify,
16+
17+
waiters: AtomicUsize,
18+
notify_exit: Notify,
19+
}
20+
21+
impl Shared {
22+
fn is_shutdown(&self) -> bool {
23+
self.shutdown.load(Ordering::Relaxed)
24+
}
25+
}
26+
27+
/// Wait for the shutdown notification.
28+
#[derive(Debug)]
29+
pub struct Waiter {
30+
shared: Arc<Shared>,
31+
}
32+
33+
/// Used to Notify all [`Waiter`s](Waiter) shutdown.
34+
///
35+
/// No `Clone` is provided. If you want multiple instances, you can use Arc<Notifier>.
36+
/// Notifier will automatically call shutdown when dropping.
37+
#[derive(Debug)]
38+
pub struct Notifier {
39+
shared: Arc<Shared>,
40+
wait_time: Option<Duration>,
41+
}
42+
43+
/// Create a new shutdown pair([`Notifier`], [`Waiter`]) without timeout.
44+
///
45+
/// The [`Notifier`]
46+
pub fn new() -> (Notifier, Waiter) {
47+
_with_timeout(None)
48+
}
49+
50+
/// Create a new shutdown pair with the specified [`Duration`].
51+
///
52+
/// The [`Duration`] is used to specify the timeout of the [`Notifier::wait_all_exit()`].
53+
///
54+
/// [`Duration`]: tokio::time::Duration
55+
pub fn with_timeout(wait_time: Duration) -> (Notifier, Waiter) {
56+
_with_timeout(Some(wait_time))
57+
}
58+
59+
fn _with_timeout(wait_time: Option<Duration>) -> (Notifier, Waiter) {
60+
let shared = Arc::new(Shared {
61+
shutdown: AtomicBool::new(false),
62+
waiters: AtomicUsize::new(1),
63+
notify_shutdown: Notify::new(),
64+
notify_exit: Notify::new(),
65+
});
66+
67+
let notifier = Notifier {
68+
shared: shared.clone(),
69+
wait_time,
70+
};
71+
72+
let waiter = Waiter { shared };
73+
74+
(notifier, waiter)
75+
}
76+
77+
impl Waiter {
78+
/// Return `true` if the [`Notifier::shutdown()`] has been called.
79+
///
80+
/// [`Notifier::shutdown()`]: Notifier::shutdown()
81+
pub fn is_shutdown(&self) -> bool {
82+
self.shared.is_shutdown()
83+
}
84+
85+
/// Waiting for the [`Notifier::shutdown()`] to be called.
86+
pub async fn wait_shutdown(&self) {
87+
while !self.is_shutdown() {
88+
let shutdown = self.shared.notify_shutdown.notified();
89+
if self.is_shutdown() {
90+
return;
91+
}
92+
shutdown.await;
93+
}
94+
}
95+
96+
fn from_shared(shared: Arc<Shared>) -> Self {
97+
shared.waiters.fetch_add(1, Ordering::Relaxed);
98+
Self { shared }
99+
}
100+
}
101+
102+
impl Clone for Waiter {
103+
fn clone(&self) -> Self {
104+
Self::from_shared(self.shared.clone())
105+
}
106+
}
107+
108+
impl Drop for Waiter {
109+
fn drop(&mut self) {
110+
if 1 == self.shared.waiters.fetch_sub(1, Ordering::Relaxed) {
111+
self.shared.notify_exit.notify_waiters();
112+
}
113+
}
114+
}
115+
116+
impl Notifier {
117+
/// Return `true` if the [`Notifier::shutdown()`] has been called.
118+
///
119+
/// [`Notifier::shutdown()`]: Notifier::shutdown()
120+
pub fn is_shutdown(&self) -> bool {
121+
self.shared.is_shutdown()
122+
}
123+
124+
/// Notify all [`Waiter`s](Waiter) shutdown.
125+
///
126+
/// It will cause all calls blocking at `Waiter::wait_shutdown().await` to return.
127+
pub fn shutdown(&self) {
128+
let is_shutdown = self.shared.shutdown.swap(true, Ordering::Relaxed);
129+
if !is_shutdown {
130+
self.shared.notify_shutdown.notify_waiters();
131+
}
132+
}
133+
134+
/// Return the num of all [`Waiter`]s.
135+
pub fn waiters(&self) -> usize {
136+
self.shared.waiters.load(Ordering::Relaxed)
137+
}
138+
139+
/// Create a new [`Waiter`].
140+
pub fn subscribe(&self) -> Waiter {
141+
Waiter::from_shared(self.shared.clone())
142+
}
143+
144+
/// Wait for all [`Waiter`]s to drop.
145+
pub async fn wait_all_exit(&self) -> Result<(), Elapsed> {
146+
//debug_assert!(self.shared.is_shutdown());
147+
if self.waiters() == 0 {
148+
return Ok(());
149+
}
150+
let wait = self.wait();
151+
if self.waiters() == 0 {
152+
return Ok(());
153+
}
154+
wait.await
155+
}
156+
157+
async fn wait(&self) -> Result<(), Elapsed> {
158+
if let Some(tm) = self.wait_time {
159+
timeout(tm, self.shared.notify_exit.notified()).await
160+
} else {
161+
self.shared.notify_exit.notified().await;
162+
Ok(())
163+
}
164+
}
165+
}
166+
167+
impl Drop for Notifier {
168+
fn drop(&mut self) {
169+
self.shutdown()
170+
}
171+
}
172+
173+
#[cfg(test)]
174+
mod test {
175+
use super::*;
176+
177+
#[tokio::test]
178+
async fn it_work() {
179+
let (notifier, waiter) = new();
180+
181+
let task = tokio::spawn(async move {
182+
waiter.wait_shutdown().await;
183+
});
184+
185+
assert_eq!(notifier.waiters(), 1);
186+
notifier.shutdown();
187+
task.await.unwrap();
188+
assert_eq!(notifier.waiters(), 0);
189+
}
190+
191+
#[tokio::test]
192+
async fn notifier_drop() {
193+
let (notifier, waiter) = new();
194+
assert_eq!(notifier.waiters(), 1);
195+
assert!(!waiter.is_shutdown());
196+
drop(notifier);
197+
assert!(waiter.is_shutdown());
198+
assert_eq!(waiter.shared.waiters.load(Ordering::Relaxed), 1);
199+
}
200+
201+
#[tokio::test]
202+
async fn waiter_clone() {
203+
let (notifier, waiter1) = new();
204+
assert_eq!(notifier.waiters(), 1);
205+
206+
let waiter2 = waiter1.clone();
207+
assert_eq!(notifier.waiters(), 2);
208+
209+
let waiter3 = notifier.subscribe();
210+
assert_eq!(notifier.waiters(), 3);
211+
212+
drop(waiter2);
213+
assert_eq!(notifier.waiters(), 2);
214+
215+
let task = tokio::spawn(async move {
216+
waiter3.wait_shutdown().await;
217+
assert!(waiter3.is_shutdown());
218+
});
219+
220+
assert!(!waiter1.is_shutdown());
221+
notifier.shutdown();
222+
assert!(waiter1.is_shutdown());
223+
224+
task.await.unwrap();
225+
226+
assert_eq!(notifier.waiters(), 1);
227+
}
228+
229+
#[tokio::test]
230+
async fn concurrency_notifier_shutdown() {
231+
let (notifier, waiter) = new();
232+
let arc_notifier = Arc::new(notifier);
233+
let notifier1 = arc_notifier.clone();
234+
let notifier2 = notifier1.clone();
235+
236+
let task1 = tokio::spawn(async move {
237+
assert_eq!(notifier1.waiters(), 1);
238+
239+
let waiter = notifier1.subscribe();
240+
assert_eq!(notifier1.waiters(), 2);
241+
242+
notifier1.shutdown();
243+
waiter.wait_shutdown().await;
244+
});
245+
246+
let task2 = tokio::spawn(async move {
247+
assert_eq!(notifier2.waiters(), 1);
248+
notifier2.shutdown();
249+
});
250+
waiter.wait_shutdown().await;
251+
assert!(arc_notifier.is_shutdown());
252+
task1.await.unwrap();
253+
task2.await.unwrap();
254+
}
255+
256+
#[tokio::test]
257+
async fn concurrency_notifier_wait() {
258+
let (notifier, waiter) = new();
259+
let arc_notifier = Arc::new(notifier);
260+
let notifier1 = arc_notifier.clone();
261+
let notifier2 = notifier1.clone();
262+
263+
let task1 = tokio::spawn(async move {
264+
notifier1.shutdown();
265+
notifier1.wait_all_exit().await.unwrap();
266+
});
267+
268+
let task2 = tokio::spawn(async move {
269+
notifier2.shutdown();
270+
notifier2.wait_all_exit().await.unwrap();
271+
});
272+
273+
waiter.wait_shutdown().await;
274+
drop(waiter);
275+
task1.await.unwrap();
276+
task2.await.unwrap();
277+
}
278+
279+
#[tokio::test]
280+
async fn wait_all_exit() {
281+
let (notifier, waiter) = new();
282+
let mut tasks = Vec::with_capacity(100);
283+
for i in 0..100 {
284+
assert_eq!(notifier.waiters(), 1 + i);
285+
let waiter1 = waiter.clone();
286+
tasks.push(tokio::spawn(async move {
287+
waiter1.wait_shutdown().await;
288+
}));
289+
}
290+
drop(waiter);
291+
assert_eq!(notifier.waiters(), 100);
292+
notifier.shutdown();
293+
notifier.wait_all_exit().await.unwrap();
294+
for t in tasks {
295+
t.await.unwrap();
296+
}
297+
}
298+
299+
#[tokio::test]
300+
async fn wait_timeout() {
301+
let (notifier, waiter) = with_timeout(Duration::from_millis(100));
302+
let task = tokio::spawn(async move {
303+
waiter.wait_shutdown().await;
304+
tokio::time::sleep(Duration::from_millis(200)).await;
305+
});
306+
notifier.shutdown();
307+
// Elapsed
308+
assert!(matches!(notifier.wait_all_exit().await, Err(_)));
309+
task.await.unwrap();
310+
}
311+
}

0 commit comments

Comments
 (0)