From 69f9577719bfb63aaa94ec0daf8a0ad067d03f44 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 12:51:12 +0900 Subject: [PATCH 01/12] Add StopSource and StopToken cancellation types --- Cargo.toml | 1 + src/cancellation.rs | 218 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 3 files changed, 220 insertions(+) create mode 100644 src/cancellation.rs diff --git a/Cargo.toml b/Cargo.toml index 8f5d6f19..c3a23356 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ async-session = { version = "3.0", optional = true } async-sse = "4.0.1" async-std = { version = "1.6.5", features = ["unstable"] } async-trait = "0.1.41" +event-listener = "2.5.1" femme = { version = "2.1.1", optional = true } futures-util = "0.3.6" http-client = { version = "6.1.0", default-features = false } diff --git a/src/cancellation.rs b/src/cancellation.rs new file mode 100644 index 00000000..edc06f5d --- /dev/null +++ b/src/cancellation.rs @@ -0,0 +1,218 @@ +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; + +use async_std::future::Future; +use async_std::stream::Stream; +use async_std::sync::Arc; +use async_std::task::{Context, Poll}; + +use event_listener::{Event, EventListener}; +use pin_project_lite::pin_project; + +#[derive(Debug)] +pub struct StopSource { + stopped: Arc, + event: Arc, +} + +impl StopSource { + pub fn new() -> Self { + Self { + stopped: Arc::new(AtomicBool::new(false)), + event: Arc::new(Event::new()), + } + } + + pub fn token(&self) -> StopToken { + StopToken { + stopped: self.stopped.clone(), + event_listener: self.event.listen(), + event: self.event.clone(), + } + } +} + +impl Drop for StopSource { + fn drop(&mut self) { + self.stopped.store(true, Ordering::SeqCst); + self.event.notify(usize::MAX); + } +} + +pin_project! { + #[derive(Debug)] + pub struct StopToken { + #[pin] + stopped: Arc, + #[pin] + event_listener: EventListener, + event: Arc, + } +} + +impl StopToken { + pub fn never() -> Self { + let event = Event::new(); + Self { + stopped: Arc::new(AtomicBool::new(false)), + event_listener: event.listen(), + event: Arc::new(event), + } + } +} + +impl Clone for StopToken { + fn clone(&self) -> Self { + Self { + stopped: self.stopped.clone(), + event_listener: self.event.listen(), + event: self.event.clone(), + } + } +} + +impl Future for StopToken { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + let _ = Future::poll(Pin::new(&mut this.event_listener), cx); + if this.stopped.load(Ordering::Relaxed) { + Poll::Ready(()) + } else { + Poll::Pending + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct StopStream { + #[pin] + stream: S, + #[pin] + stop_token: StopToken, + } +} + +impl StopStream { + pub fn new(stream: S, stop_token: StopToken) -> Self { + Self { stream, stop_token } + } +} + +impl Stream for StopStream +where + S: Stream, +{ + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if Future::poll(Pin::new(&mut this.stop_token), cx).is_ready() { + Poll::Ready(None) + } else { + this.stream.poll_next(cx) + } + } +} + +pub trait StopStreamExt: Sized { + fn stop_on(self, stop_token: StopToken) -> StopStream { + StopStream::new(self, stop_token) + } +} + +impl StopStreamExt for S where S: Stream {} + +pin_project! { + #[derive(Debug)] + pub struct StopFuture { + #[pin] + future: F, + #[pin] + stop_token: StopToken, + } +} + +impl StopFuture { + pub fn new(future: F, stop_token: StopToken) -> Self { + Self { future, stop_token } + } +} + +impl Future for StopFuture +where + F: Future, +{ + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + if Future::poll(Pin::new(&mut this.stop_token), cx).is_ready() { + Poll::Ready(None) + } else { + match this.future.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(output) => Poll::Ready(Some(output)), + } + } + } +} + +pub trait StopFutureExt: Sized { + fn stop_on(self, stop_token: StopToken) -> StopFuture { + StopFuture::new(self, stop_token) + } +} + +impl StopFutureExt for F where F: Future {} + +#[cfg(test)] +mod tests { + use std::thread; + use std::time::Duration; + + use async_std::prelude::{FutureExt, StreamExt}; + + use super::*; + + #[test] + fn test_cancellation() { + let source = StopSource::new(); + let stop_token = source.token(); + + let pending_stream1 = async_std::stream::pending::<()>(); + let pending_stream2 = async_std::stream::pending::<()>(); + let pending_future1 = async_std::future::pending::<()>(); + let pending_future2 = async_std::future::pending::<()>(); + let wrapped_stream1 = pending_stream1.stop_on(stop_token.clone()); + let wrapped_stream2 = pending_stream2.stop_on(stop_token.clone()); + let wrapped_future1 = pending_future1.stop_on(stop_token.clone()); + let wrapped_future2 = pending_future2.stop_on(stop_token); + + let join_future = wrapped_stream1 + .last() + .join(wrapped_stream2.last()) + .join(wrapped_future1) + .join(wrapped_future2); + + thread::spawn(move || { + let source = source; + thread::sleep(Duration::from_secs(1)); + drop(source); + }); + + let res = async_std::task::block_on(join_future); + assert_eq!(res, (((None, None), None), None)); + } + + #[test] + fn test_never() { + let pending_future = async_std::future::pending::<()>(); + let wrapped_future = pending_future.stop_on(StopToken::never()); + + let res = async_std::task::block_on(wrapped_future.timeout(Duration::from_secs(1))); + assert!(res.is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 3245f23f..516128e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ mod route; mod router; mod server; +pub mod cancellation; pub mod convert; pub mod listener; pub mod log; From dd219e55dd5c09e42fdbf58d880155ba2715e761 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 12:51:28 +0900 Subject: [PATCH 02/12] Add stop_token to server --- src/server.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/server.rs b/src/server.rs index 1e6f8c1a..46dbdbf5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ use async_std::io; use async_std::sync::Arc; +use crate::cancellation::StopToken; #[cfg(feature = "cookies")] use crate::cookies; use crate::listener::{Listener, ToListener}; @@ -38,6 +39,7 @@ pub struct Server { /// We don't use a Mutex around the Vec here because adding a middleware during execution should be an error. #[allow(clippy::rc_buffer)] middleware: Arc>>>, + pub(crate) stop_token: StopToken, } impl Server<()> { @@ -113,6 +115,7 @@ where Arc::new(log::LogMiddleware::new()), ]), state, + stop_token: StopToken::never(), } } @@ -286,6 +289,7 @@ where router, state, middleware, + stop_token: _, } = self.clone(); let method = req.method().to_owned(); @@ -317,6 +321,11 @@ where pub fn state(&self) -> &State { &self.state } + + pub fn stop_on(&mut self, stop_token: StopToken) -> &mut Self { + self.stop_token = stop_token; + self + } } impl std::fmt::Debug for Server { @@ -331,6 +340,7 @@ impl Clone for Server { router: self.router.clone(), state: self.state.clone(), middleware: self.middleware.clone(), + stop_token: self.stop_token.clone(), } } } From e6e987cc6e564dbc0183495d7c26848337df7a4c Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 12:51:36 +0900 Subject: [PATCH 03/12] Use stop_token in listeners --- src/listener/tcp_listener.rs | 3 ++- src/listener/unix_listener.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 7b86a013..7e1973cf 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -1,5 +1,6 @@ use super::{is_transient_error, ListenInfo}; +use crate::cancellation::StopStreamExt; use crate::listener::Listener; use crate::{log, Server}; @@ -98,7 +99,7 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = listener.incoming(); + let mut incoming = listener.incoming().stop_on(server.stop_token.clone()); while let Some(stream) = incoming.next().await { match stream { diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index d99a21d3..a767d223 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -1,5 +1,6 @@ use super::{is_transient_error, ListenInfo}; +use crate::cancellation::StopStreamExt; use crate::listener::Listener; use crate::{log, Server}; @@ -96,7 +97,7 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = listener.incoming(); + let mut incoming = listener.incoming().stop_on(server.stop_token.clone()); while let Some(stream) = incoming.next().await { match stream { From 3e5198f42a77e566eec29bb1e8dd0613d217d7ea Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 13:19:06 +0900 Subject: [PATCH 04/12] Add documentations for cancellation module --- src/cancellation.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/cancellation.rs b/src/cancellation.rs index edc06f5d..d79d291c 100644 --- a/src/cancellation.rs +++ b/src/cancellation.rs @@ -1,3 +1,5 @@ +//! Future and Stream cancellation + use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; @@ -9,6 +11,7 @@ use async_std::task::{Context, Poll}; use event_listener::{Event, EventListener}; use pin_project_lite::pin_project; +/// StopSource produces [StopToken] and cancels all of its tokens on drop. #[derive(Debug)] pub struct StopSource { stopped: Arc, @@ -16,6 +19,7 @@ pub struct StopSource { } impl StopSource { + /// Create a new StopSource pub fn new() -> Self { Self { stopped: Arc::new(AtomicBool::new(false)), @@ -23,6 +27,9 @@ impl StopSource { } } + /// Produce a new [StopToken], associated with this source. + /// + /// Once this source is dropped, all associated [StopToken] futures will complete. pub fn token(&self) -> StopToken { StopToken { stopped: self.stopped.clone(), @@ -40,6 +47,7 @@ impl Drop for StopSource { } pin_project! { + /// StopToken is a future which completes when the associated [StopSource] is dropped. #[derive(Debug)] pub struct StopToken { #[pin] @@ -51,6 +59,7 @@ pin_project! { } impl StopToken { + /// Produce a StopToken that associates with no [StopSource], and never completes. pub fn never() -> Self { let event = Event::new(); Self { @@ -86,6 +95,9 @@ impl Future for StopToken { } pin_project! { + /// A stream that early exits when inner [StopToken] completes. + /// + /// Users usually do not need to construct this type manually, but rather use the [StopStreamExt::stop_on] method instead. #[derive(Debug)] pub struct StopStream { #[pin] @@ -96,6 +108,7 @@ pin_project! { } impl StopStream { + /// Wraps a stream to exit early when `stop_token` completes. pub fn new(stream: S, stop_token: StopToken) -> Self { Self { stream, stop_token } } @@ -117,7 +130,9 @@ where } } +/// Stream extensions to generate [StopStream] that exits early when `stop_token` completes. pub trait StopStreamExt: Sized { + /// Wraps a stream to exit early when `stop_token` completes. fn stop_on(self, stop_token: StopToken) -> StopStream { StopStream::new(self, stop_token) } @@ -126,6 +141,9 @@ pub trait StopStreamExt: Sized { impl StopStreamExt for S where S: Stream {} pin_project! { + /// A future that early exits when inner [StopToken] completes. + /// + /// Users usually do not need to construct this type manually, but rather use the [StopFutureExt::stop_on] method instead. #[derive(Debug)] pub struct StopFuture { #[pin] @@ -136,6 +154,7 @@ pin_project! { } impl StopFuture { + /// Wraps a future to exit early when `stop_token` completes. pub fn new(future: F, stop_token: StopToken) -> Self { Self { future, stop_token } } @@ -160,7 +179,9 @@ where } } +/// Future extensions to generate [StopFuture] that exits early when `stop_token` completes. pub trait StopFutureExt: Sized { + /// Wraps a future to exit early when `stop_token` completes. fn stop_on(self, stop_token: StopToken) -> StopFuture { StopFuture::new(self, stop_token) } From d23be8e9af22b50b614c724890db5a7f31f9fdf4 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 13:23:14 +0900 Subject: [PATCH 05/12] Add documentation for stop_on method --- src/server.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/server.rs b/src/server.rs index 46dbdbf5..e656078e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -322,6 +322,24 @@ where &self.state } + /// Stops the server when given `stop_token` completes. + /// + /// # Example + /// + /// ```rust + /// use tide::cancellation::StopSource; + /// + /// let mut app = tide::new(); + /// + /// let stop_source = StopSource::new(); + /// + /// app.stop_on(stop_source.token()); + /// + /// // Runs server... + /// + /// // When something happens + /// drop(stop_source); + /// ``` pub fn stop_on(&mut self, stop_token: StopToken) -> &mut Self { self.stop_token = stop_token; self From 22c53916b2b90e498cf0f62b6036a2bf7b05f358 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 13:23:58 +0900 Subject: [PATCH 06/12] Impl Default trait for StopSource --- src/cancellation.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/cancellation.rs b/src/cancellation.rs index d79d291c..51b9945e 100644 --- a/src/cancellation.rs +++ b/src/cancellation.rs @@ -39,6 +39,12 @@ impl StopSource { } } +impl Default for StopSource { + fn default() -> Self { + Self::new() + } +} + impl Drop for StopSource { fn drop(&mut self) { self.stopped.store(true, Ordering::SeqCst); From cd74324aeb48d5d33604e4bd7fcbeca2beb057bf Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 14:36:43 +0900 Subject: [PATCH 07/12] Replace cancellation module with stopper --- Cargo.toml | 2 +- src/cancellation.rs | 245 ---------------------------------- src/lib.rs | 3 +- src/listener/tcp_listener.rs | 3 +- src/listener/unix_listener.rs | 3 +- src/server.rs | 25 ++-- 6 files changed, 18 insertions(+), 263 deletions(-) delete mode 100644 src/cancellation.rs diff --git a/Cargo.toml b/Cargo.toml index c3a23356..54395599 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,6 @@ async-session = { version = "3.0", optional = true } async-sse = "4.0.1" async-std = { version = "1.6.5", features = ["unstable"] } async-trait = "0.1.41" -event-listener = "2.5.1" femme = { version = "2.1.1", optional = true } futures-util = "0.3.6" http-client = { version = "6.1.0", default-features = false } @@ -49,6 +48,7 @@ pin-project-lite = "0.2.0" route-recognizer = "0.2.0" serde = "1.0.117" serde_json = "1.0.59" +stopper = "0.2.0" [dev-dependencies] async-std = { version = "1.6.5", features = ["unstable", "attributes"] } diff --git a/src/cancellation.rs b/src/cancellation.rs deleted file mode 100644 index 51b9945e..00000000 --- a/src/cancellation.rs +++ /dev/null @@ -1,245 +0,0 @@ -//! Future and Stream cancellation - -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; - -use async_std::future::Future; -use async_std::stream::Stream; -use async_std::sync::Arc; -use async_std::task::{Context, Poll}; - -use event_listener::{Event, EventListener}; -use pin_project_lite::pin_project; - -/// StopSource produces [StopToken] and cancels all of its tokens on drop. -#[derive(Debug)] -pub struct StopSource { - stopped: Arc, - event: Arc, -} - -impl StopSource { - /// Create a new StopSource - pub fn new() -> Self { - Self { - stopped: Arc::new(AtomicBool::new(false)), - event: Arc::new(Event::new()), - } - } - - /// Produce a new [StopToken], associated with this source. - /// - /// Once this source is dropped, all associated [StopToken] futures will complete. - pub fn token(&self) -> StopToken { - StopToken { - stopped: self.stopped.clone(), - event_listener: self.event.listen(), - event: self.event.clone(), - } - } -} - -impl Default for StopSource { - fn default() -> Self { - Self::new() - } -} - -impl Drop for StopSource { - fn drop(&mut self) { - self.stopped.store(true, Ordering::SeqCst); - self.event.notify(usize::MAX); - } -} - -pin_project! { - /// StopToken is a future which completes when the associated [StopSource] is dropped. - #[derive(Debug)] - pub struct StopToken { - #[pin] - stopped: Arc, - #[pin] - event_listener: EventListener, - event: Arc, - } -} - -impl StopToken { - /// Produce a StopToken that associates with no [StopSource], and never completes. - pub fn never() -> Self { - let event = Event::new(); - Self { - stopped: Arc::new(AtomicBool::new(false)), - event_listener: event.listen(), - event: Arc::new(event), - } - } -} - -impl Clone for StopToken { - fn clone(&self) -> Self { - Self { - stopped: self.stopped.clone(), - event_listener: self.event.listen(), - event: self.event.clone(), - } - } -} - -impl Future for StopToken { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - let _ = Future::poll(Pin::new(&mut this.event_listener), cx); - if this.stopped.load(Ordering::Relaxed) { - Poll::Ready(()) - } else { - Poll::Pending - } - } -} - -pin_project! { - /// A stream that early exits when inner [StopToken] completes. - /// - /// Users usually do not need to construct this type manually, but rather use the [StopStreamExt::stop_on] method instead. - #[derive(Debug)] - pub struct StopStream { - #[pin] - stream: S, - #[pin] - stop_token: StopToken, - } -} - -impl StopStream { - /// Wraps a stream to exit early when `stop_token` completes. - pub fn new(stream: S, stop_token: StopToken) -> Self { - Self { stream, stop_token } - } -} - -impl Stream for StopStream -where - S: Stream, -{ - type Item = S::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - if Future::poll(Pin::new(&mut this.stop_token), cx).is_ready() { - Poll::Ready(None) - } else { - this.stream.poll_next(cx) - } - } -} - -/// Stream extensions to generate [StopStream] that exits early when `stop_token` completes. -pub trait StopStreamExt: Sized { - /// Wraps a stream to exit early when `stop_token` completes. - fn stop_on(self, stop_token: StopToken) -> StopStream { - StopStream::new(self, stop_token) - } -} - -impl StopStreamExt for S where S: Stream {} - -pin_project! { - /// A future that early exits when inner [StopToken] completes. - /// - /// Users usually do not need to construct this type manually, but rather use the [StopFutureExt::stop_on] method instead. - #[derive(Debug)] - pub struct StopFuture { - #[pin] - future: F, - #[pin] - stop_token: StopToken, - } -} - -impl StopFuture { - /// Wraps a future to exit early when `stop_token` completes. - pub fn new(future: F, stop_token: StopToken) -> Self { - Self { future, stop_token } - } -} - -impl Future for StopFuture -where - F: Future, -{ - type Output = Option; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - if Future::poll(Pin::new(&mut this.stop_token), cx).is_ready() { - Poll::Ready(None) - } else { - match this.future.poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(output) => Poll::Ready(Some(output)), - } - } - } -} - -/// Future extensions to generate [StopFuture] that exits early when `stop_token` completes. -pub trait StopFutureExt: Sized { - /// Wraps a future to exit early when `stop_token` completes. - fn stop_on(self, stop_token: StopToken) -> StopFuture { - StopFuture::new(self, stop_token) - } -} - -impl StopFutureExt for F where F: Future {} - -#[cfg(test)] -mod tests { - use std::thread; - use std::time::Duration; - - use async_std::prelude::{FutureExt, StreamExt}; - - use super::*; - - #[test] - fn test_cancellation() { - let source = StopSource::new(); - let stop_token = source.token(); - - let pending_stream1 = async_std::stream::pending::<()>(); - let pending_stream2 = async_std::stream::pending::<()>(); - let pending_future1 = async_std::future::pending::<()>(); - let pending_future2 = async_std::future::pending::<()>(); - let wrapped_stream1 = pending_stream1.stop_on(stop_token.clone()); - let wrapped_stream2 = pending_stream2.stop_on(stop_token.clone()); - let wrapped_future1 = pending_future1.stop_on(stop_token.clone()); - let wrapped_future2 = pending_future2.stop_on(stop_token); - - let join_future = wrapped_stream1 - .last() - .join(wrapped_stream2.last()) - .join(wrapped_future1) - .join(wrapped_future2); - - thread::spawn(move || { - let source = source; - thread::sleep(Duration::from_secs(1)); - drop(source); - }); - - let res = async_std::task::block_on(join_future); - assert_eq!(res, (((None, None), None), None)); - } - - #[test] - fn test_never() { - let pending_future = async_std::future::pending::<()>(); - let wrapped_future = pending_future.stop_on(StopToken::never()); - - let res = async_std::task::block_on(wrapped_future.timeout(Duration::from_secs(1))); - assert!(res.is_err()); - } -} diff --git a/src/lib.rs b/src/lib.rs index 516128e9..6acd12e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,7 +76,6 @@ mod route; mod router; mod server; -pub mod cancellation; pub mod convert; pub mod listener; pub mod log; @@ -99,6 +98,8 @@ pub use server::Server; pub use http_types::{self as http, Body, Error, Status, StatusCode}; +pub use stopper; + /// Create a new Tide server. /// /// # Examples diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 7e1973cf..6b906690 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -1,6 +1,5 @@ use super::{is_transient_error, ListenInfo}; -use crate::cancellation::StopStreamExt; use crate::listener::Listener; use crate::{log, Server}; @@ -99,7 +98,7 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = listener.incoming().stop_on(server.stop_token.clone()); + let mut incoming = server.stopper.clone().stop_stream(listener.incoming()); while let Some(stream) = incoming.next().await { match stream { diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index a767d223..805d2182 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -1,6 +1,5 @@ use super::{is_transient_error, ListenInfo}; -use crate::cancellation::StopStreamExt; use crate::listener::Listener; use crate::{log, Server}; @@ -97,7 +96,7 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = listener.incoming().stop_on(server.stop_token.clone()); + let mut incoming = server.stopper.clone().stop_stream(listener.incoming()); while let Some(stream) = incoming.next().await { match stream { diff --git a/src/server.rs b/src/server.rs index e656078e..5e3d4840 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,8 @@ use async_std::io; use async_std::sync::Arc; -use crate::cancellation::StopToken; +use stopper::Stopper; + #[cfg(feature = "cookies")] use crate::cookies; use crate::listener::{Listener, ToListener}; @@ -39,7 +40,7 @@ pub struct Server { /// We don't use a Mutex around the Vec here because adding a middleware during execution should be an error. #[allow(clippy::rc_buffer)] middleware: Arc>>>, - pub(crate) stop_token: StopToken, + pub(crate) stopper: Stopper, } impl Server<()> { @@ -115,7 +116,7 @@ where Arc::new(log::LogMiddleware::new()), ]), state, - stop_token: StopToken::never(), + stopper: Stopper::new(), } } @@ -289,7 +290,7 @@ where router, state, middleware, - stop_token: _, + stopper: _, } = self.clone(); let method = req.method().to_owned(); @@ -322,26 +323,26 @@ where &self.state } - /// Stops the server when given `stop_token` completes. + /// Stops the server when given `stopper` stops. /// /// # Example /// /// ```rust - /// use tide::cancellation::StopSource; + /// use tide::stopper::Stopper; /// /// let mut app = tide::new(); /// - /// let stop_source = StopSource::new(); + /// let stopper = Stopper::new(); /// - /// app.stop_on(stop_source.token()); + /// app.with_stopper(stopper.clone()); /// /// // Runs server... /// /// // When something happens - /// drop(stop_source); + /// stopper.stop(); /// ``` - pub fn stop_on(&mut self, stop_token: StopToken) -> &mut Self { - self.stop_token = stop_token; + pub fn with_stopper(&mut self, stopper: Stopper) -> &mut Self { + self.stopper = stopper; self } } @@ -358,7 +359,7 @@ impl Clone for Server { router: self.router.clone(), state: self.state.clone(), middleware: self.middleware.clone(), - stop_token: self.stop_token.clone(), + stopper: self.stopper.clone(), } } } From a8d1707d9f1097345eb36a24e848d055e51d2874 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 18:21:25 +0900 Subject: [PATCH 08/12] Make stopper optional --- src/listener/tcp_listener.rs | 9 ++++++++- src/listener/unix_listener.rs | 9 ++++++++- src/server.rs | 6 +++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 6b906690..c12e4a17 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -9,6 +9,8 @@ use async_std::net::{self, SocketAddr, TcpStream}; use async_std::prelude::*; use async_std::{io, task}; +use futures_util::future::Either; + /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::net::TcpListener]. It is implemented as an /// enum in order to allow creation of a tide::listener::TcpListener @@ -98,7 +100,12 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = server.stopper.clone().stop_stream(listener.incoming()); + let incoming = listener.incoming(); + let mut incoming = if let Some(stopper) = server.stopper.clone() { + Either::Left(stopper.stop_stream(incoming)) + } else { + Either::Right(incoming) + }; while let Some(stream) = incoming.next().await { match stream { diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index 805d2182..a2058e30 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -10,6 +10,8 @@ use async_std::path::PathBuf; use async_std::prelude::*; use async_std::{io, task}; +use futures_util::future::Either; + /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an /// enum in order to allow creation of a tide::listener::UnixListener @@ -96,7 +98,12 @@ where .take() .expect("`Listener::bind` must be called before `Listener::accept`"); - let mut incoming = server.stopper.clone().stop_stream(listener.incoming()); + let incoming = listener.incoming(); + let mut incoming = if let Some(stopper) = server.stopper.clone() { + Either::Left(stopper.stop_stream(incoming)) + } else { + Either::Right(incoming) + }; while let Some(stream) = incoming.next().await { match stream { diff --git a/src/server.rs b/src/server.rs index 5e3d4840..0dc96ea1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -40,7 +40,7 @@ pub struct Server { /// We don't use a Mutex around the Vec here because adding a middleware during execution should be an error. #[allow(clippy::rc_buffer)] middleware: Arc>>>, - pub(crate) stopper: Stopper, + pub(crate) stopper: Option, } impl Server<()> { @@ -116,7 +116,7 @@ where Arc::new(log::LogMiddleware::new()), ]), state, - stopper: Stopper::new(), + stopper: None, } } @@ -342,7 +342,7 @@ where /// stopper.stop(); /// ``` pub fn with_stopper(&mut self, stopper: Stopper) -> &mut Self { - self.stopper = stopper; + self.stopper = Some(stopper); self } } From 7b5b1468a424668c471f65ac9f8655b916764695 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 18:23:58 +0900 Subject: [PATCH 09/12] Add join handles to wait tasks --- src/listener/tcp_listener.rs | 22 +++++++++++++++++++--- src/listener/unix_listener.rs | 22 +++++++++++++++++++--- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index c12e4a17..533249f4 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -10,6 +10,7 @@ use async_std::prelude::*; use async_std::{io, task}; use futures_util::future::Either; +use futures_util::stream::FuturesUnordered; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::net::TcpListener]. It is implemented as an @@ -24,6 +25,7 @@ pub struct TcpListener { listener: Option, server: Option>, info: Option, + join_handles: Vec>, } impl TcpListener { @@ -33,6 +35,7 @@ impl TcpListener { listener: None, server: None, info: None, + join_handles: Vec::new(), } } @@ -42,11 +45,15 @@ impl TcpListener { listener: Some(tcp_listener.into()), server: None, info: None, + join_handles: Vec::new(), } } } -fn handle_tcp(app: Server, stream: TcpStream) { +fn handle_tcp( + app: Server, + stream: TcpStream, +) -> task::JoinHandle<()> { task::spawn(async move { let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -60,7 +67,7 @@ fn handle_tcp(app: Server, stream: if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); } - }); + }) } #[async_trait::async_trait] @@ -118,10 +125,19 @@ where } Ok(stream) => { - handle_tcp(server.clone(), stream); + let handle = handle_tcp(server.clone(), stream); + self.join_handles.push(handle); } }; } + + let join_handles = std::mem::take(&mut self.join_handles); + join_handles + .into_iter() + .collect::>>() + .collect::<()>() + .await; + Ok(()) } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index a2058e30..aa309577 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -11,6 +11,7 @@ use async_std::prelude::*; use async_std::{io, task}; use futures_util::future::Either; +use futures_util::stream::FuturesUnordered; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an @@ -25,6 +26,7 @@ pub struct UnixListener { listener: Option, server: Option>, info: Option, + join_handles: Vec>, } impl UnixListener { @@ -34,6 +36,7 @@ impl UnixListener { listener: None, server: None, info: None, + join_handles: Vec::new(), } } @@ -43,11 +46,15 @@ impl UnixListener { listener: Some(unix_listener.into()), server: None, info: None, + join_handles: Vec::new(), } } } -fn handle_unix(app: Server, stream: UnixStream) { +fn handle_unix( + app: Server, + stream: UnixStream, +) -> task::JoinHandle<()> { task::spawn(async move { let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -61,7 +68,7 @@ fn handle_unix(app: Server, stream: if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); } - }); + }) } #[async_trait::async_trait] @@ -116,10 +123,19 @@ where } Ok(stream) => { - handle_unix(server.clone(), stream); + let handle = handle_unix(server.clone(), stream); + self.join_handles.push(handle); } }; } + + let join_handles = std::mem::take(&mut self.join_handles); + join_handles + .into_iter() + .collect::>>() + .collect::<()>() + .await; + Ok(()) } From 5c119793a7e432eb0ceecf04bfbb80fd3d7000f5 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 18:40:48 +0900 Subject: [PATCH 10/12] Add test for cancellation --- tests/cancellation.rs | 57 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/cancellation.rs diff --git a/tests/cancellation.rs b/tests/cancellation.rs new file mode 100644 index 00000000..2996f61e --- /dev/null +++ b/tests/cancellation.rs @@ -0,0 +1,57 @@ +mod test_utils; +use async_std::prelude::*; +use async_std::task; +use std::time::Duration; + +use tide::stopper::Stopper; +use tide::Response; + +#[async_std::test] +async fn cancellation() -> Result<(), http_types::Error> { + let port = test_utils::find_port().await; + let stopper = Stopper::new(); + let stopper_ = stopper.clone(); + + let server = task::spawn(async move { + let mut app = tide::new(); + app.with_stopper(stopper_); + app.at("/").get(|_| async { + task::sleep(Duration::from_secs(1)).await; + Ok(Response::new(200)) + }); + app.listen(("localhost", port)).await?; + tide::Result::Ok(()) + }); + + let client1 = task::spawn(async move { + task::sleep(Duration::from_millis(100)).await; + let res = surf::get(format!("http://localhost:{}", port)) + .await + .unwrap(); + assert_eq!(res.status(), 200); + async_std::future::pending().await + }); + + let client2 = task::spawn(async move { + task::sleep(Duration::from_millis(200)).await; + let res = surf::get(format!("http://localhost:{}", port)) + .await + .unwrap(); + assert_eq!(res.status(), 200); + async_std::future::pending().await + }); + + let stop = task::spawn(async move { + task::sleep(Duration::from_millis(300)).await; + stopper.stop(); + Ok(()) + }); + + server + .try_join(stop) + .race(client1.try_join(client2)) + .timeout(Duration::from_secs(2)) + .await??; + + Ok(()) +} From 54dbb93f7d7a0f258f710e23c7dc6c0cc31b3afb Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 18:28:26 +0900 Subject: [PATCH 11/12] Use draft async-h1 --- Cargo.toml | 4 +++- src/listener/tcp_listener.rs | 18 +++++++++++++----- src/listener/unix_listener.rs | 18 +++++++++++++----- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 54395599..fab5e45b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,9 @@ sessions = ["async-session", "cookies"] unstable = [] [dependencies] -async-h1 = { version = "2.3.0", optional = true } +# async-h1 = { version = "2.3.0", optional = true } +# FIXME: for proposal purpose only +async-h1 = { git = "https://github.com/pbzweihander/async-h1.git", branch = "cancellation", optional = true } async-session = { version = "3.0", optional = true } async-sse = "4.0.1" async-std = { version = "1.6.5", features = ["unstable"] } diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 533249f4..9ae9dc8d 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -58,11 +58,19 @@ fn handle_tcp( let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); - let fut = async_h1::accept(stream, |mut req| async { - req.set_local_addr(local_addr); - req.set_peer_addr(peer_addr); - app.respond(req).await - }); + let opts = async_h1::ServerOptions { + stopper: app.stopper.clone(), + ..Default::default() + }; + let fut = async_h1::accept_with_opts( + stream, + |mut req| async { + req.set_local_addr(local_addr); + req.set_peer_addr(peer_addr); + app.respond(req).await + }, + opts, + ); if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index aa309577..50233ca8 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -59,11 +59,19 @@ fn handle_unix( let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); - let fut = async_h1::accept(stream, |mut req| async { - req.set_local_addr(local_addr.as_ref()); - req.set_peer_addr(peer_addr.as_ref()); - app.respond(req).await - }); + let opts = async_h1::ServerOptions { + stopper: app.stopper.clone(), + ..Default::default() + }; + let fut = async_h1::accept_with_opts( + stream, + |mut req| async { + req.set_local_addr(local_addr.as_ref()); + req.set_peer_addr(peer_addr.as_ref()); + app.respond(req).await + }, + opts, + ); if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); From 595e47d938d4cda596d40a75b269cdf8fc0283ec Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 25 Jun 2021 21:03:15 +0900 Subject: [PATCH 12/12] Wait tasks with waitgroup --- Cargo.toml | 1 + src/listener/tcp_listener.rs | 24 ++++++++++-------------- src/listener/unix_listener.rs | 24 ++++++++++-------------- 3 files changed, 21 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fab5e45b..6512284c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ route-recognizer = "0.2.0" serde = "1.0.117" serde_json = "1.0.59" stopper = "0.2.0" +waitgroup = "0.1.2" [dev-dependencies] async-std = { version = "1.6.5", features = ["unstable", "attributes"] } diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 9ae9dc8d..5b444ff3 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -10,7 +10,8 @@ use async_std::prelude::*; use async_std::{io, task}; use futures_util::future::Either; -use futures_util::stream::FuturesUnordered; + +use waitgroup::{WaitGroup, Worker}; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::net::TcpListener]. It is implemented as an @@ -25,7 +26,6 @@ pub struct TcpListener { listener: Option, server: Option>, info: Option, - join_handles: Vec>, } impl TcpListener { @@ -35,7 +35,6 @@ impl TcpListener { listener: None, server: None, info: None, - join_handles: Vec::new(), } } @@ -45,7 +44,6 @@ impl TcpListener { listener: Some(tcp_listener.into()), server: None, info: None, - join_handles: Vec::new(), } } } @@ -53,8 +51,11 @@ impl TcpListener { fn handle_tcp( app: Server, stream: TcpStream, -) -> task::JoinHandle<()> { + wait_group_worker: Worker, +) { task::spawn(async move { + let _wait_group_worker = wait_group_worker; + let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -75,7 +76,7 @@ fn handle_tcp( if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); } - }) + }); } #[async_trait::async_trait] @@ -121,6 +122,7 @@ where } else { Either::Right(incoming) }; + let wait_group = WaitGroup::new(); while let Some(stream) = incoming.next().await { match stream { @@ -133,18 +135,12 @@ where } Ok(stream) => { - let handle = handle_tcp(server.clone(), stream); - self.join_handles.push(handle); + handle_tcp(server.clone(), stream, wait_group.worker()); } }; } - let join_handles = std::mem::take(&mut self.join_handles); - join_handles - .into_iter() - .collect::>>() - .collect::<()>() - .await; + wait_group.wait().await; Ok(()) } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index 50233ca8..9b6c6e4d 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -11,7 +11,8 @@ use async_std::prelude::*; use async_std::{io, task}; use futures_util::future::Either; -use futures_util::stream::FuturesUnordered; + +use waitgroup::{WaitGroup, Worker}; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an @@ -26,7 +27,6 @@ pub struct UnixListener { listener: Option, server: Option>, info: Option, - join_handles: Vec>, } impl UnixListener { @@ -36,7 +36,6 @@ impl UnixListener { listener: None, server: None, info: None, - join_handles: Vec::new(), } } @@ -46,7 +45,6 @@ impl UnixListener { listener: Some(unix_listener.into()), server: None, info: None, - join_handles: Vec::new(), } } } @@ -54,8 +52,11 @@ impl UnixListener { fn handle_unix( app: Server, stream: UnixStream, -) -> task::JoinHandle<()> { + wait_group_worker: Worker, +) { task::spawn(async move { + let _wait_group_worker = wait_group_worker; + let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -76,7 +77,7 @@ fn handle_unix( if let Err(error) = fut.await { log::error!("async-h1 error", { error: error.to_string() }); } - }) + }); } #[async_trait::async_trait] @@ -119,6 +120,7 @@ where } else { Either::Right(incoming) }; + let wait_group = WaitGroup::new(); while let Some(stream) = incoming.next().await { match stream { @@ -131,18 +133,12 @@ where } Ok(stream) => { - let handle = handle_unix(server.clone(), stream); - self.join_handles.push(handle); + handle_unix(server.clone(), stream, wait_group.worker()); } }; } - let join_handles = std::mem::take(&mut self.join_handles); - join_handles - .into_iter() - .collect::>>() - .collect::<()>() - .await; + wait_group.wait().await; Ok(()) }