diff --git a/.gitignore b/.gitignore index 00541512b..f23a7b55e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ dist/ npm-debug.log* Cargo.lock .DS_Store -.idea \ No newline at end of file +.idea +.vscode/launch.json diff --git a/Cargo.toml b/Cargo.toml index 7e34b7193..30127d8ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ async-sse = "4.0.1" async-std = { version = "1.6.5", features = ["unstable"] } async-trait = "0.1.41" femme = { version = "2.1.1", optional = true } +futures = "0.3.7" futures-util = "0.3.6" http-client = { version = "6.1.0", default-features = false } http-types = "2.5.0" diff --git a/src/cancelation_token.rs b/src/cancelation_token.rs new file mode 100644 index 000000000..58d4d46e0 --- /dev/null +++ b/src/cancelation_token.rs @@ -0,0 +1,60 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; + +#[derive(Debug)] +pub struct CancelationToken { + shared_state: Arc> +} + +#[derive(Debug)] +struct CancelationTokenState { + canceled: bool, + waker: Option +} + +/// Future that allows gracefully shutting down the server +impl CancelationToken { + pub fn new() -> CancelationToken { + CancelationToken { + shared_state: Arc::new(Mutex::new(CancelationTokenState { + canceled: false, + waker: None + })) + } + } + + /// Call to shut down the server + pub fn complete(&self) { + let mut shared_state = self.shared_state.lock().unwrap(); + + shared_state.canceled = true; + if let Some(waker) = shared_state.waker.take() { + waker.wake() + } + } +} + +impl Future for CancelationToken { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut shared_state = self.shared_state.lock().unwrap(); + + if shared_state.canceled { + Poll::Ready(()) + } else { + shared_state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} + +impl Clone for CancelationToken { + fn clone(&self) -> Self { + CancelationToken { + shared_state: self.shared_state.clone() + } + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 20a113962..de26cc85d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,6 +60,7 @@ #![doc(html_favicon_url = "https://yoshuawuyts.com/assets/http-rs/favicon.ico")] #![doc(html_logo_url = "https://yoshuawuyts.com/assets/http-rs/logo-rounded.png")] +mod cancelation_token; #[cfg(feature = "cookies")] mod cookies; mod endpoint; @@ -88,6 +89,7 @@ pub mod utils; #[cfg(feature = "sessions")] pub mod sessions; +pub use cancelation_token::CancelationToken; pub use endpoint::Endpoint; pub use middleware::{Middleware, Next}; pub use redirect::Redirect; diff --git a/src/listener/concurrent_listener.rs b/src/listener/concurrent_listener.rs index 1cff9b0d7..e1eb2e1c0 100644 --- a/src/listener/concurrent_listener.rs +++ b/src/listener/concurrent_listener.rs @@ -1,9 +1,9 @@ use crate::listener::{Listener, ToListener}; -use crate::Server; +use crate::{CancelationToken, Server}; use std::fmt::{self, Debug, Display, Formatter}; -use async_std::io; +use async_std::{io, task}; use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt}; /// ConcurrentListener allows tide to listen on any number of transports @@ -79,17 +79,29 @@ impl ConcurrentListener { #[async_trait::async_trait] impl Listener for ConcurrentListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { + async fn listen(&mut self, app: Server, cancelation_token: CancelationToken) -> io::Result<()> { let mut futures_unordered = FuturesUnordered::new(); + let mut cancelation_tokens = Vec::new(); + for listener in self.0.iter_mut() { let app = app.clone(); - futures_unordered.push(listener.listen(app)); + let sub_cancelation_token = CancelationToken::new(); + futures_unordered.push(listener.listen(app, sub_cancelation_token.clone())); + cancelation_tokens.push(sub_cancelation_token); } + task::spawn(async move { + cancelation_token.await; + for sub_cancelation_token in cancelation_tokens.iter_mut() { + sub_cancelation_token.complete(); + } + }); + while let Some(result) = futures_unordered.next().await { result?; } + Ok(()) } } diff --git a/src/listener/failover_listener.rs b/src/listener/failover_listener.rs index 4ab1bd242..267e01ab3 100644 --- a/src/listener/failover_listener.rs +++ b/src/listener/failover_listener.rs @@ -1,9 +1,9 @@ use crate::listener::{Listener, ToListener}; -use crate::Server; +use crate::{CancelationToken, Server}; use std::fmt::{self, Debug, Display, Formatter}; -use async_std::io; +use async_std::{io, task}; /// FailoverListener allows tide to attempt to listen in a sequential /// order to any number of ports/addresses. The first successful @@ -81,10 +81,14 @@ impl FailoverListener { #[async_trait::async_trait] impl Listener for FailoverListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { + async fn listen(&mut self, app: Server, cancelation_token: CancelationToken) -> io::Result<()> { + + let mut cancelation_tokens = Vec::new(); + for listener in self.0.iter_mut() { let app = app.clone(); - match listener.listen(app).await { + let sub_cancelation_token = CancelationToken::new(); + match listener.listen(app, sub_cancelation_token.clone()).await { Ok(_) => return Ok(()), Err(e) => { crate::log::info!("unable to listen", { @@ -93,8 +97,16 @@ impl Listener for FailoverListener< }); } } + cancelation_tokens.push(sub_cancelation_token); } + task::spawn(async move { + cancelation_token.await; + for sub_cancelation_token in cancelation_tokens.iter_mut() { + sub_cancelation_token.complete(); + } + }); + Err(io::Error::new( io::ErrorKind::AddrNotAvailable, "unable to bind to any supplied listener spec", diff --git a/src/listener/mod.rs b/src/listener/mod.rs index 82033de15..27715f9dd 100644 --- a/src/listener/mod.rs +++ b/src/listener/mod.rs @@ -12,7 +12,7 @@ mod to_listener_impls; #[cfg(all(unix, feature = "h1-server"))] mod unix_listener; -use crate::Server; +use crate::{CancelationToken, Server}; use async_std::io; pub use concurrent_listener::ConcurrentListener; @@ -37,7 +37,7 @@ pub trait Listener: /// This is the primary entrypoint for the Listener trait. listen /// is called exactly once, and is expected to spawn tasks for /// each incoming connection. - async fn listen(&mut self, app: Server) -> io::Result<()>; + async fn listen(&mut self, app: Server, cancelation_token: CancelationToken) -> io::Result<()>; } /// crate-internal shared logic used by tcp and unix listeners to diff --git a/src/listener/parsed_listener.rs b/src/listener/parsed_listener.rs index 4b0e186a9..6a48938c6 100644 --- a/src/listener/parsed_listener.rs +++ b/src/listener/parsed_listener.rs @@ -1,7 +1,7 @@ #[cfg(unix)] use super::UnixListener; use super::{Listener, TcpListener}; -use crate::Server; +use crate::{CancelationToken, Server}; use async_std::io; use std::fmt::{self, Display, Formatter}; @@ -32,11 +32,11 @@ impl Display for ParsedListener { #[async_trait::async_trait] impl Listener for ParsedListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { + async fn listen(&mut self, app: Server, cancelation_token: CancelationToken) -> io::Result<()> { match self { #[cfg(unix)] - Self::Unix(u) => u.listen(app).await, - Self::Tcp(t) => t.listen(app).await, + Self::Unix(u) => u.listen(app, cancelation_token).await, + Self::Tcp(t) => t.listen(app, cancelation_token).await, } } } diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index db68530ee..3c792fe36 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -1,13 +1,14 @@ use super::is_transient_error; use crate::listener::Listener; -use crate::{log, Server}; +use crate::{CancelationToken, log, Server}; use std::fmt::{self, Display, Formatter}; use async_std::net::{self, SocketAddr, TcpStream}; use async_std::prelude::*; use async_std::{io, task}; +use futures::future::{self, Either}; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::net::TcpListener]. It is implemented as an @@ -70,28 +71,37 @@ fn handle_tcp(app: Server, stream: #[async_trait::async_trait] impl Listener for TcpListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { + async fn listen(&mut self, app: Server, cancelation_token: CancelationToken) -> io::Result<()> { self.connect().await?; let listener = self.listener()?; crate::log::info!("Server listening on {}", self); let mut incoming = listener.incoming(); - while let Some(stream) = incoming.next().await { - match stream { - Err(ref e) if is_transient_error(e) => continue, - Err(error) => { - let delay = std::time::Duration::from_millis(500); - crate::log::error!("Error: {}. Pausing for {:?}.", error, delay); - task::sleep(delay).await; - continue; - } - - Ok(stream) => { - handle_tcp(app.clone(), stream); + 'serve_loop: + while let Either::Left(result) = future::select(incoming.next(), cancelation_token.clone()).await { + match result.0 { + Some(stream) => { + match stream { + Err(ref e) if is_transient_error(e) => continue, + Err(error) => { + let delay = std::time::Duration::from_millis(500); + crate::log::error!("Error: {}. Pausing for {:?}.", error, delay); + task::sleep(delay).await; + continue; + } + + Ok(stream) => { + handle_tcp(app.clone(), stream); + } + }; + }, + None => { + break 'serve_loop; } }; } + Ok(()) } } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index 72aff852d..33a612667 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -1,13 +1,14 @@ use super::is_transient_error; use crate::listener::Listener; -use crate::{log, Server}; +use crate::{CancelationToken, log, Server}; use std::fmt::{self, Display, Formatter}; use async_std::os::unix::net::{self, SocketAddr, UnixStream}; use async_std::prelude::*; use async_std::{io, path::PathBuf, task}; +use futures::future::{self, Either}; /// This represents a tide [Listener](crate::listener::Listener) that /// wraps an [async_std::os::unix::net::UnixListener]. It is implemented as an @@ -83,24 +84,32 @@ fn handle_unix(app: Server, stream: #[async_trait::async_trait] impl Listener for UnixListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { + async fn listen(&mut self, app: Server, cancelation_token: CancelationToken) -> io::Result<()> { self.connect().await?; crate::log::info!("Server listening on {}", self); let listener = self.listener()?; let mut incoming = listener.incoming(); - while let Some(stream) = incoming.next().await { - match stream { - Err(ref e) if is_transient_error(e) => continue, - Err(error) => { - let delay = std::time::Duration::from_millis(500); - crate::log::error!("Error: {}. Pausing for {:?}.", error, delay); - task::sleep(delay).await; - continue; - } - - Ok(stream) => { - handle_unix(app.clone(), stream); + 'serve_loop: + while let Either::Left(result) = future::select(incoming.next(), cancelation_token.clone()).await { + match result.0 { + Some(stream) => { + match stream { + Err(ref e) if is_transient_error(e) => continue, + Err(error) => { + let delay = std::time::Duration::from_millis(500); + crate::log::error!("Error: {}. Pausing for {:?}.", error, delay); + task::sleep(delay).await; + continue; + } + + Ok(stream) => { + handle_unix(app.clone(), stream); + } + }; + }, + None => { + break 'serve_loop; } }; } diff --git a/src/server.rs b/src/server.rs index 853d39e34..5f0147d64 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,7 +9,7 @@ use crate::listener::{Listener, ToListener}; use crate::log; use crate::middleware::{Middleware, Next}; use crate::router::{Router, Selection}; -use crate::{Endpoint, Request, Route}; +use crate::{CancelationToken, Endpoint, Request, Route}; /// An HTTP server. /// @@ -187,7 +187,12 @@ impl Server { /// Asynchronously serve the app with the supplied listener. For more details, see [Listener] and [ToListener] pub async fn listen>(self, listener: TL) -> io::Result<()> { - listener.to_listener()?.listen(self).await + self.listen_with_cancelation_token(listener, CancelationToken::new()).await + } + + /// Asynchronously serve the app with the supplied listener and canelation token. For more details, see [Listener] and [ToListener] + pub async fn listen_with_cancelation_token>(self, listener: TL, cancelation_token: CancelationToken) -> io::Result<()> { + listener.to_listener()?.listen(self, cancelation_token).await } /// Respond to a `Request` with a `Response`. diff --git a/tests/server.rs b/tests/server.rs index 070557451..9bf53c30d 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1,5 +1,4 @@ mod test_utils; -use async_std::prelude::*; use async_std::task; use std::time::Duration; @@ -9,58 +8,62 @@ use tide::{Body, Request}; #[test] fn hello_world() -> tide::Result<()> { task::block_on(async { + + let cancelation_token = tide::CancelationToken::new(); + let port = test_utils::find_port().await; - let server = task::spawn(async move { - let mut app = tide::new(); - app.at("/").get(move |mut req: Request<()>| async move { - assert_eq!(req.body_string().await.unwrap(), "nori".to_string()); - assert!(req.local_addr().unwrap().contains(&port.to_string())); - assert!(req.peer_addr().is_some()); - Ok("says hello") - }); - app.listen(("localhost", port)).await?; - Result::<(), http_types::Error>::Ok(()) - }); - let client = task::spawn(async move { - task::sleep(Duration::from_millis(100)).await; - let string = surf::get(format!("http://localhost:{}", port)) - .body(Body::from_string("nori".to_string())) - .recv_string() - .await - .unwrap(); - assert_eq!(string, "says hello"); - Ok(()) + let mut app = tide::new(); + app.at("/").get(move |mut req: Request<()>| async move { + assert_eq!(req.body_string().await.unwrap(), "nori".to_string()); + assert!(req.local_addr().unwrap().contains(&port.to_string())); + assert!(req.peer_addr().is_some()); + Ok("says hello") }); - server.race(client).await + let server = app.listen_with_cancelation_token(("localhost", port), cancelation_token.clone()); + let server = task::spawn(server); + + task::sleep(Duration::from_millis(100)).await; + let string = surf::get(format!("http://localhost:{}", port)) + .body(Body::from_string("nori".to_string())) + .recv_string() + .await + .unwrap(); + assert_eq!(string, "says hello"); + + cancelation_token.complete(); + + server.await.expect("Server did not complete gracefully"); + Result::<(), http_types::Error>::Ok(()) }) } #[test] fn echo_server() -> tide::Result<()> { task::block_on(async { + + let cancelation_token = tide::CancelationToken::new(); + let port = test_utils::find_port().await; - let server = task::spawn(async move { - let mut app = tide::new(); - app.at("/").get(|req| async move { Ok(req) }); + let mut app = tide::new(); + app.at("/").get(|req| async move { Ok(req) }); - app.listen(("localhost", port)).await?; - Result::<(), http_types::Error>::Ok(()) - }); + let server = app.listen_with_cancelation_token(("localhost", port), cancelation_token.clone()); + let server = task::spawn(server); - let client = task::spawn(async move { - task::sleep(Duration::from_millis(100)).await; - let string = surf::get(format!("http://localhost:{}", port)) - .body(Body::from_string("chashu".to_string())) - .recv_string() - .await - .unwrap(); - assert_eq!(string, "chashu".to_string()); - Ok(()) - }); + task::sleep(Duration::from_millis(100)).await; + let string = surf::get(format!("http://localhost:{}", port)) + .body(Body::from_string("chashu".to_string())) + .recv_string() + .await + .unwrap(); + assert_eq!(string, "chashu".to_string()); - server.race(client).await + cancelation_token.complete(); + + server.await.expect("Server did not complete gracefully"); + Result::<(), http_types::Error>::Ok(()) }) } @@ -72,30 +75,31 @@ fn json() -> tide::Result<()> { } task::block_on(async { + + let cancelation_token = tide::CancelationToken::new(); + let port = test_utils::find_port().await; - let server = task::spawn(async move { - let mut app = tide::new(); - app.at("/").get(|mut req: Request<()>| async move { - let mut counter: Counter = req.body_json().await.unwrap(); - assert_eq!(counter.count, 0); - counter.count = 1; - Ok(Body::from_json(&counter)?) - }); - app.listen(("localhost", port)).await?; - Result::<(), http_types::Error>::Ok(()) - }); - let client = task::spawn(async move { - task::sleep(Duration::from_millis(100)).await; - let counter: Counter = surf::get(format!("http://localhost:{}", &port)) - .body(Body::from_json(&Counter { count: 0 })?) - .recv_json() - .await - .unwrap(); - assert_eq!(counter.count, 1); - Ok(()) + let mut app = tide::new(); + app.at("/").get(|mut req: Request<()>| async move { + let mut counter: Counter = req.body_json().await.unwrap(); + assert_eq!(counter.count, 0); + counter.count = 1; + Ok(Body::from_json(&counter)?) }); - server.race(client).await + let server = app.listen_with_cancelation_token(("localhost", port), cancelation_token.clone()); + let server = task::spawn(server); + + task::sleep(Duration::from_millis(100)).await; + let counter: Counter = surf::get(format!("http://localhost:{}", &port)) + .body(Body::from_json(&Counter { count: 0 })?) + .recv_json() + .await + .unwrap(); + assert_eq!(counter.count, 1); + + server.await.expect("Server did not complete gracefully"); + Result::<(), http_types::Error>::Ok(()) }) }