From 22489ac250409de3be6a3107ef654780ec24dfbf Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 9 Oct 2024 17:48:59 +0200 Subject: [PATCH] Use CancellationToken and a TaskTracker to handle graceful shutdowns --- Cargo.lock | 6 +- Cargo.toml | 5 + crates/cli/Cargo.toml | 13 +- crates/cli/src/commands/server.rs | 44 +++-- crates/cli/src/main.rs | 1 + crates/cli/src/shutdown.rs | 116 ++++++++++++ crates/handlers/Cargo.toml | 1 + crates/handlers/src/activity_tracker/mod.rs | 66 ++++--- .../handlers/src/activity_tracker/worker.rs | 51 +++--- crates/handlers/src/test_utils.rs | 19 +- crates/handlers/src/views/index.rs | 1 + crates/listener/Cargo.toml | 2 +- crates/listener/examples/demo/main.rs | 27 ++- crates/listener/src/lib.rs | 1 - crates/listener/src/server.rs | 88 ++++----- crates/listener/src/shutdown.rs | 172 ------------------ 16 files changed, 310 insertions(+), 303 deletions(-) create mode 100644 crates/cli/src/shutdown.rs delete mode 100644 crates/listener/src/shutdown.rs diff --git a/Cargo.lock b/Cargo.lock index b8b87f1b4..9a6e431a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3261,6 +3261,7 @@ dependencies = [ "serde_yaml", "sqlx", "tokio", + "tokio-util", "tower 0.5.1", "tower-http", "tracing", @@ -3393,6 +3394,7 @@ dependencies = [ "thiserror", "time", "tokio", + "tokio-util", "tower 0.5.1", "tower-http", "tracing", @@ -3563,7 +3565,6 @@ version = "0.12.0" dependencies = [ "anyhow", "bytes", - "event-listener 5.3.1", "futures-util", "http-body", "hyper", @@ -3576,6 +3577,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-test", + "tokio-util", "tower 0.5.1", "tower-http", "tracing", @@ -6318,6 +6320,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 7950c5cce..ec815f484 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -259,6 +259,11 @@ version = "1.0.64" version = "1.40.0" features = ["full"] +# Useful async utilities +[workspace.dependencies.tokio-util] +version = "0.7.12" +features = ["rt"] + # Tower services [workspace.dependencies.tower] version = "0.5.1" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index a1bac3241..d5bc9dac8 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -36,6 +36,7 @@ serde_json.workspace = true serde_yaml = "0.9.34" sqlx.workspace = true tokio.workspace = true +tokio-util.workspace = true tower.workspace = true tower-http.workspace = true url.workspace = true @@ -48,12 +49,20 @@ tracing-opentelemetry.workspace = true opentelemetry.workspace = true opentelemetry-http.workspace = true opentelemetry-jaeger-propagator = "0.3.0" -opentelemetry-otlp = { version = "0.17.0", default-features = false, features = ["trace", "metrics", "http-proto"] } +opentelemetry-otlp = { version = "0.17.0", default-features = false, features = [ + "trace", + "metrics", + "http-proto", +] } opentelemetry-prometheus = "0.17.0" opentelemetry-resource-detectors = "0.3.0" opentelemetry-semantic-conventions.workspace = true opentelemetry-stdout = { version = "0.5.0", features = ["trace", "metrics"] } -opentelemetry_sdk = { version = "0.24.1", features = ["trace", "metrics", "rt-tokio"] } +opentelemetry_sdk = { version = "0.24.1", features = [ + "trace", + "metrics", + "rt-tokio", +] } prometheus = "0.13.4" sentry.workspace = true sentry-tracing.workspace = true diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 2f96b2e6f..8386b4689 100644 --- a/crates/cli/src/commands/server.rs +++ 
b/crates/cli/src/commands/server.rs @@ -14,7 +14,7 @@ use mas_config::{ AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config, }; use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache}; -use mas_listener::{server::Server, shutdown::ShutdownStream}; +use mas_listener::server::Server; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; use mas_storage::SystemClock; @@ -24,11 +24,11 @@ use rand::{ thread_rng, }; use sqlx::migrate::Migrate; -use tokio::signal::unix::SignalKind; use tracing::{info, info_span, warn, Instrument}; use crate::{ app_state::AppState, + shutdown::ShutdownManager, util::{ database_pool_from_config, mailer_from_config, password_manager_from_config, policy_factory_from_config, register_sighup, site_config_from_config, @@ -61,6 +61,7 @@ impl Options { #[allow(clippy::too_many_lines)] pub async fn run(self, figment: &Figment) -> anyhow::Result { let span = info_span!("cli.run.init").entered(); + let shutdown = ShutdownManager::new()?; let config = AppConfig::extract(figment)?; if self.migrate { @@ -173,8 +174,21 @@ impl Options { url_builder.clone(), ) .await?; - // TODO: grab the handle - tokio::spawn(monitor.run()); + + // XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns, + // ideally we'd just give it a cancellation token + let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned(); + shutdown.task_tracker().spawn(async move { + if let Err(e) = monitor + .run_with_signal(async move { + shutdown_future.await; + Ok(()) + }) + .await + { + tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed"); + } + }); } let listeners_config = config.http.listeners.clone(); @@ -186,7 +200,12 @@ impl Options { // Initialize the activity tracker // Activity is flushed every minute - let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60)); + let activity_tracker = ActivityTracker::new( + pool.clone(), + Duration::from_secs(60), + shutdown.task_tracker(), + shutdown.soft_shutdown_token(), + ); let trusted_proxies = config.http.trusted_proxies.clone(); // Build a rate limiter. @@ -302,16 +321,17 @@ impl Options { .flatten_ok() .collect::, _>>()?; - let shutdown = ShutdownStream::default() - .with_timeout(Duration::from_secs(60)) - .with_signal(SignalKind::terminate())? - .with_signal(SignalKind::interrupt())?; - span.exit(); - mas_listener::server::run_servers(servers, shutdown).await; + shutdown + .task_tracker() + .spawn(mas_listener::server::run_servers( + servers, + shutdown.soft_shutdown_token(), + shutdown.hard_shutdown_token(), + )); - state.activity_tracker.shutdown().await; + shutdown.run().await; Ok(ExitCode::SUCCESS) } diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index e1d037ffb..eee5a73a4 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -22,6 +22,7 @@ mod app_state; mod commands; mod sentry_transport; mod server; +mod shutdown; mod sync; mod telemetry; mod util; diff --git a/crates/cli/src/shutdown.rs b/crates/cli/src/shutdown.rs new file mode 100644 index 000000000..080be0f2e --- /dev/null +++ b/crates/cli/src/shutdown.rs @@ -0,0 +1,116 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. 
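+
+//! Helpers to drive graceful shutdowns of the service.
+//!
+//! The example below is a condensed sketch of the wiring done in
+//! `crates/cli/src/commands/server.rs`; `servers` stands for the listeners
+//! built from the configuration and is assumed to exist already:
+//!
+//! ```ignore
+//! let shutdown = ShutdownManager::new()?;
+//!
+//! // Long-running work goes through the task tracker and watches the
+//! // soft-shutdown token so it can wind down cleanly.
+//! shutdown.task_tracker().spawn(mas_listener::server::run_servers(
+//!     servers,
+//!     shutdown.soft_shutdown_token(),
+//!     shutdown.hard_shutdown_token(),
+//! ));
+//!
+//! // `run` waits for SIGTERM/SIGINT, cancels the soft-shutdown token, then
+//! // waits for tracked tasks (or a second signal / the timeout) before
+//! // cancelling the hard-shutdown token.
+//! shutdown.run().await;
+//! ```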
+
+use std::time::Duration;
+
+use tokio::signal::unix::{Signal, SignalKind};
+use tokio_util::{sync::CancellationToken, task::TaskTracker};
+
+/// A helper to manage graceful shutdowns, and to track the tasks taking part
+/// in them.
+///
+/// It listens for SIGTERM and SIGINT, and triggers a soft shutdown on the
+/// first signal, then a hard shutdown on the second signal or after a
+/// timeout.
+///
+/// Users of this manager should use the `soft_shutdown_token` to react to a
+/// soft shutdown, which should gracefully finish requests and close
+/// connections, and the `hard_shutdown_token` to react to a hard shutdown,
+/// which should abort any remaining requests and drop the connections.
+///
+/// They should also spawn long-running work through the `task_tracker`, so
+/// that the manager knows when the soft shutdown has completed.
+pub struct ShutdownManager {
+    hard_shutdown_token: CancellationToken,
+    soft_shutdown_token: CancellationToken,
+    task_tracker: TaskTracker,
+    sigterm: Signal,
+    sigint: Signal,
+    timeout: Duration,
+}
+
+impl ShutdownManager {
+    /// Create a new shutdown manager, installing the signal handlers
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the signal handler could not be installed
+    pub fn new() -> Result<Self, std::io::Error> {
+        let hard_shutdown_token = CancellationToken::new();
+        let soft_shutdown_token = hard_shutdown_token.child_token();
+        let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
+        let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
+        let timeout = Duration::from_secs(60);
+        let task_tracker = TaskTracker::new();
+
+        Ok(Self {
+            hard_shutdown_token,
+            soft_shutdown_token,
+            task_tracker,
+            sigterm,
+            sigint,
+            timeout,
+        })
+    }
+
+    /// Get a reference to the task tracker
+    #[must_use]
+    pub fn task_tracker(&self) -> &TaskTracker {
+        &self.task_tracker
+    }
+
+    /// Get a cancellation token that can be used to react to a hard shutdown
+    #[must_use]
+    pub fn hard_shutdown_token(&self) -> CancellationToken {
+        self.hard_shutdown_token.clone()
+    }
+
+    /// Get a cancellation token that can be used to react to a soft shutdown
+    #[must_use]
+    pub fn soft_shutdown_token(&self) -> CancellationToken {
+        self.soft_shutdown_token.clone()
+    }
+
+    /// Run until we finish completely shutting down.
+    pub async fn run(mut self) {
+        // Wait for a first signal and trigger the soft shutdown
+        tokio::select! {
+            _ = self.sigterm.recv() => {
+                tracing::info!("Shutdown signal received (SIGTERM), shutting down");
+            },
+            _ = self.sigint.recv() => {
+                tracing::info!("Shutdown signal received (SIGINT), shutting down");
+            },
+        };
+
+        self.soft_shutdown_token.cancel();
+        self.task_tracker.close();
+
+        // Start the timeout
+        let timeout = tokio::time::sleep(self.timeout);
+        tokio::select! {
+            _ = self.sigterm.recv() => {
+                tracing::warn!("Second shutdown signal received (SIGTERM), abort");
+            },
+            _ = self.sigint.recv() => {
+                tracing::warn!("Second shutdown signal received (SIGINT), abort");
+            },
+            () = timeout => {
+                tracing::warn!("Shutdown timeout reached, abort");
+            },
+            () = self.task_tracker.wait() => {
+                // This is the "happy path", we have gracefully shut down
+            },
+        }
+
+        self.hard_shutdown_token().cancel();
+
+        // TODO: we may want to have a timeout on the task tracker, in case we
+        // have really stuck tasks on it
+        self.task_tracker().wait().await;
+
+        tracing::info!("All tasks are done, exiting");
+    }
+}
diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml
index d24210cb7..46a4788c1 100644
--- a/crates/handlers/Cargo.toml
+++ b/crates/handlers/Cargo.toml
@@ -14,6 +14,7 @@ workspace = true
 [dependencies]
 # Async runtime
 tokio.workspace = true
+tokio-util.workspace = true
 futures-util = "0.3.31"
 async-trait.workspace = true

diff --git a/crates/handlers/src/activity_tracker/mod.rs b/crates/handlers/src/activity_tracker/mod.rs
index d2d1e683a..d314fc16e 100644
--- a/crates/handlers/src/activity_tracker/mod.rs
+++ b/crates/handlers/src/activity_tracker/mod.rs
@@ -13,6 +13,7 @@ use chrono::{DateTime, Utc};
 use mas_data_model::{BrowserSession, CompatSession, Session};
 use mas_storage::Clock;
 use sqlx::PgPool;
+use tokio_util::{sync::CancellationToken, task::TaskTracker};
 use ulid::Ulid;

 pub use self::bound::Bound;
@@ -45,7 +46,6 @@ enum Message {
         ip: Option<IpAddr>,
     },
     Flush(tokio::sync::oneshot::Sender<()>),
-    Shutdown(tokio::sync::oneshot::Sender<()>),
 }

 #[derive(Clone)]
@@ -54,16 +54,29 @@ pub struct ActivityTracker {
 }

 impl ActivityTracker {
-    /// Create a new activity tracker, spawning the worker.
+    /// Create a new activity tracker
+    ///
+    /// It will spawn the background worker and the flush loop on the given
+    /// task tracker; both shut themselves down, flushing one last time, when
+    /// the cancellation token is cancelled.
     #[must_use]
-    pub fn new(pool: PgPool, flush_interval: std::time::Duration) -> Self {
+    pub fn new(
+        pool: PgPool,
+        flush_interval: std::time::Duration,
+        task_tracker: &TaskTracker,
+        cancellation_token: CancellationToken,
+    ) -> Self {
         let worker = Worker::new(pool);
         let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
         let tracker = ActivityTracker { channel: sender };

         // Spawn the flush loop and the worker
-        tokio::spawn(tracker.clone().flush_loop(flush_interval));
-        tokio::spawn(worker.run(receiver));
+        task_tracker.spawn(
+            tracker
+                .clone()
+                .flush_loop(flush_interval, cancellation_token.clone()),
+        );
+        task_tracker.spawn(worker.run(receiver, cancellation_token));

         tracker
     }
@@ -148,50 +161,47 @@ impl ActivityTracker {
         match res {
             Ok(()) => {
                 if let Err(e) = rx.await {
-                    tracing::error!("Failed to flush activity tracker: {}", e);
+                    tracing::error!(
+                        error = &e as &dyn std::error::Error,
+                        "Failed to flush activity tracker"
+                    );
                 }
             }
             Err(e) => {
-                tracing::error!("Failed to flush activity tracker: {}", e);
+                tracing::error!(
+                    error = &e as &dyn std::error::Error,
+                    "Failed to flush activity tracker"
+                );
             }
         }
     }

     /// Regularly flush the activity tracker.
-    async fn flush_loop(self, interval: std::time::Duration) {
+    async fn flush_loop(
+        self,
+        interval: std::time::Duration,
+        cancellation_token: CancellationToken,
+    ) {
         loop {
             tokio::select! {
                 biased;
+                () = cancellation_token.cancelled() => {
+                    // The cancellation token was cancelled, so we should exit
+                    return;
+                }
+                // First check if the channel is closed, then check if the timer expired
                 () = self.channel.closed() => {
                     // The channel was closed, so we should exit
-                    break;
+                    return;
                 }
+
                 () = tokio::time::sleep(interval) => {
                     self.flush().await;
                 }
             }
         }
     }
-
-    /// Shutdown the activity tracker.
-    ///
-    /// This will wait for all pending messages to be processed.
-    pub async fn shutdown(&self) {
-        let (tx, rx) = tokio::sync::oneshot::channel();
-        let res = self.channel.send(Message::Shutdown(tx)).await;
-
-        match res {
-            Ok(()) => {
-                if let Err(e) = rx.await {
-                    tracing::error!("Failed to shutdown activity tracker: {}", e);
-                }
-            }
-            Err(e) => {
-                tracing::error!("Failed to shutdown activity tracker: {}", e);
-            }
-        }
-    }
 }
diff --git a/crates/handlers/src/activity_tracker/worker.rs b/crates/handlers/src/activity_tracker/worker.rs
index ad77fe2f7..5f2e1caac 100644
--- a/crates/handlers/src/activity_tracker/worker.rs
+++ b/crates/handlers/src/activity_tracker/worker.rs
@@ -13,6 +13,7 @@ use opentelemetry::{
     Key,
 };
 use sqlx::PgPool;
+use tokio_util::sync::CancellationToken;
 use ulid::Ulid;

 use crate::activity_tracker::{Message, SessionKind};
@@ -88,9 +89,30 @@ impl Worker {
         }
     }

-    pub(super) async fn run(mut self, mut receiver: tokio::sync::mpsc::Receiver<Message>) {
-        let mut shutdown_notifier = None;
-        while let Some(message) = receiver.recv().await {
+    pub(super) async fn run(
+        mut self,
+        mut receiver: tokio::sync::mpsc::Receiver<Message>,
+        cancellation_token: CancellationToken,
+    ) {
+        loop {
+            let message = tokio::select! {
+                // Because we want the cancellation token to trigger only once,
+                // we check whether the channel has already been closed
+                () = cancellation_token.cancelled(), if !receiver.is_closed() => {
+                    // We only close the channel, which will make it flush all
+                    // the pending messages
+                    receiver.close();
+                    tracing::debug!("Shutting down activity tracker");
+                    continue;
+                },
+
+                message = receiver.recv() => {
+                    // The channel is closed and fully drained, break out of the loop
+                    let Some(message) = message else { break };
+                    message
+                }
+            };
+
             match message {
                 Message::Record {
                     kind,
@@ -129,37 +151,18 @@ impl Worker {
                     record.end_time = date_time.max(record.end_time);
                 }
+
                 Message::Flush(tx) => {
                     self.message_counter.add(1, &[TYPE.string("flush")]);
                     self.flush().await;
                     let _ = tx.send(());
                 }
-                Message::Shutdown(tx) => {
-                    self.message_counter.add(1, &[TYPE.string("shutdown")]);
-
-                    let old_tx = shutdown_notifier.replace(tx);
-                    if let Some(old_tx) = old_tx {
-                        tracing::warn!("Activity tracker shutdown requested while another shutdown was already in progress");
-                        // Still send the shutdown signal to the previous notifier. This means we
-                        // send the shutdown signal before we flush the activity tracker, but that
-                        // should be fine, since there should not be multiple shutdown requests.
-                        let _ = old_tx.send(());
-                    }
-                    receiver.close();
-                }
             }
         }

+        // Flush one last time
         self.flush().await;
-
-        if let Some(shutdown_notifier) = shutdown_notifier {
-            let _ = shutdown_notifier.send(());
-        } else {
-            // This should never happen, since we set the shutdown notifier when we receive
-            // the first shutdown message
-            tracing::warn!("Activity tracker shutdown requested but no shutdown notifier was set");
-        }
     }

     /// Flush the activity tracker.
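The worker above relies on a small drain-on-cancel idiom: react to the cancellation token once, close the channel, then keep receiving until the queue is empty. Stripped of the activity-tracker specifics, the pattern looks roughly like this; this is a sketch rather than code from this repository, and `handle` stands in for whatever processing the messages need:

use tokio_util::sync::CancellationToken;

// Placeholder for the real message processing.
async fn handle(msg: String) {
    println!("{msg}");
}

async fn drain_on_cancel(mut rx: tokio::sync::mpsc::Receiver<String>, cancel: CancellationToken) {
    loop {
        tokio::select! {
            // Guarded so this branch fires at most once: after `close()` the
            // channel reports itself as closed and only `recv()` is polled.
            () = cancel.cancelled(), if !rx.is_closed() => {
                // Stop new sends, but keep draining what is already queued.
                rx.close();
            }
            msg = rx.recv() => {
                // `None` means the channel is closed *and* empty: we are done.
                let Some(msg) = msg else { break };
                handle(msg).await;
            }
        }
    }
}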
diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 1e602411e..40283a6ff 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -44,6 +44,10 @@ use rand::SeedableRng; use rand_chacha::ChaChaRng; use serde::{de::DeserializeOwned, Serialize}; use sqlx::PgPool; +use tokio_util::{ + sync::{CancellationToken, DropGuard}, + task::TaskTracker, +}; use tower::{Layer, Service, ServiceExt}; use url::Url; @@ -105,6 +109,9 @@ pub(crate) struct TestState { pub limiter: Limiter, pub clock: Arc, pub rng: Arc>, + + #[allow(dead_code)] // It is used, as it will cancel the CancellationToken when dropped + cancellation_drop_guard: Arc, } fn workspace_root() -> camino::Utf8PathBuf { @@ -147,6 +154,9 @@ impl TestState { ) -> Result { let workspace_root = workspace_root(); + let task_tracker = TaskTracker::new(); + let shutdown_token = CancellationToken::new(); + let url_builder = UrlBuilder::new("https://example.com/".parse()?, None, None); let templates = Templates::load( @@ -204,8 +214,12 @@ impl TestState { let graphql_schema = graphql::schema_builder().data(state).finish(); - let activity_tracker = - ActivityTracker::new(pool.clone(), std::time::Duration::from_secs(1)); + let activity_tracker = ActivityTracker::new( + pool.clone(), + std::time::Duration::from_secs(60), + &task_tracker, + shutdown_token.child_token(), + ); let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap(); @@ -227,6 +241,7 @@ impl TestState { limiter, clock, rng, + cancellation_drop_guard: Arc::new(shutdown_token.drop_guard()), }) } diff --git a/crates/handlers/src/views/index.rs b/crates/handlers/src/views/index.rs index 3b3fe43f1..fffe80a6f 100644 --- a/crates/handlers/src/views/index.rs +++ b/crates/handlers/src/views/index.rs @@ -26,6 +26,7 @@ pub async fn get( cookie_jar: CookieJar, PreferredLanguage(locale): PreferredLanguage, ) -> Result { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let (session_info, cookie_jar) = cookie_jar.session_info(); let session = session_info.load_session(&mut repo).await?; diff --git a/crates/listener/Cargo.toml b/crates/listener/Cargo.toml index c36d92eb2..252c5a623 100644 --- a/crates/listener/Cargo.toml +++ b/crates/listener/Cargo.toml @@ -13,7 +13,6 @@ workspace = true [dependencies] bytes.workspace = true -event-listener = "5.3.1" futures-util = "0.3.31" http-body.workspace = true hyper = { workspace = true, features = ["server"] } @@ -24,6 +23,7 @@ socket2 = "0.5.7" thiserror.workspace = true tokio.workspace = true tokio-rustls = "0.26.0" +tokio-util.workspace = true tower.workspace = true tower-http.workspace = true tracing.workspace = true diff --git a/crates/listener/examples/demo/main.rs b/crates/listener/examples/demo/main.rs index 3012701fc..b547d13f1 100644 --- a/crates/listener/examples/demo/main.rs +++ b/crates/listener/examples/demo/main.rs @@ -14,9 +14,9 @@ use std::{ use anyhow::Context; use hyper::{Request, Response}; -use mas_listener::{server::Server, shutdown::ShutdownStream, ConnectionInfo}; -use tokio::signal::unix::SignalKind; +use mas_listener::{server::Server, ConnectionInfo}; use tokio_rustls::rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}; +use tokio_util::sync::CancellationToken; use tower::service_fn; static CA_CERT_PEM: &[u8] = include_bytes!("./certs/ca.pem"); @@ -53,12 +53,23 @@ async fn main() -> Result<(), anyhow::Error> { tracing::info!("Listening on http://127.0.0.1:3000, 
http(proxy)://127.0.0.1:3001, https://127.0.0.1:3002 and https(proxy)://127.0.0.1:3003"); - let shutdown = ShutdownStream::default() - .with_timeout(Duration::from_secs(1)) - .with_signal(SignalKind::interrupt())? - .with_signal(SignalKind::terminate())?; - - mas_listener::server::run_servers(servers, shutdown).await; + let hard_shutdown = CancellationToken::new(); + let soft_shutdown = hard_shutdown.child_token(); + + { + let hard_shutdown = hard_shutdown.clone(); + let soft_shutdown = soft_shutdown.clone(); + tokio::spawn(async move { + tokio::signal::ctrl_c().await.unwrap(); + tracing::info!("Ctrl-C received, performing soft-shutdown"); + soft_shutdown.cancel(); + tokio::signal::ctrl_c().await.unwrap(); + tracing::info!("Ctrl-C received again, shutting down"); + hard_shutdown.cancel(); + }); + } + + mas_listener::server::run_servers(servers, hard_shutdown, soft_shutdown).await; Ok(()) } diff --git a/crates/listener/src/lib.rs b/crates/listener/src/lib.rs index 106319915..74ded2388 100644 --- a/crates/listener/src/lib.rs +++ b/crates/listener/src/lib.rs @@ -16,7 +16,6 @@ pub mod maybe_tls; pub mod proxy_protocol; pub mod rewind; pub mod server; -pub mod shutdown; pub mod unix_or_tcp; #[derive(Debug, Clone)] diff --git a/crates/listener/src/server.rs b/crates/listener/src/server.rs index 84000835f..06eb03128 100644 --- a/crates/listener/src/server.rs +++ b/crates/listener/src/server.rs @@ -7,13 +7,12 @@ use std::{ future::Future, pin::Pin, - sync::{atomic::AtomicBool, Arc}, + sync::Arc, task::{Context, Poll}, time::Duration, }; -use event_listener::{Event, EventListener}; -use futures_util::{stream::SelectAll, Stream, StreamExt}; +use futures_util::{stream::SelectAll, StreamExt}; use hyper::{Request, Response}; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -23,6 +22,7 @@ use hyper_util::{ use pin_project_lite::pin_project; use thiserror::Error; use tokio_rustls::rustls::ServerConfig; +use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned}; use tower::Service; use tower_http::add_extension::AddExtension; use tracing::Instrument; @@ -84,18 +84,24 @@ impl Server { } /// Run a single server - pub async fn run(self, shutdown: SD) - where + pub async fn run( + self, + soft_shutdown_token: CancellationToken, + hard_shutdown_token: CancellationToken, + ) where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: std::error::Error + Send + Sync + 'static, B: http_body::Body + Send + 'static, B::Data: Send, B::Error: std::error::Error + Send + Sync + 'static, - SD: Stream + Unpin, - SD::Item: std::fmt::Display, { - run_servers(std::iter::once(self), shutdown).await; + run_servers( + std::iter::once(self), + soft_shutdown_token, + hard_shutdown_token, + ) + .await; } } @@ -252,18 +258,16 @@ pin_project! { #[pin] connection: C, #[pin] - shutdown_listener: EventListener, - shutdown_in_progress: Arc, + cancellation_future: WaitForCancellationFutureOwned, did_start_shutdown: bool, } } impl AbortableConnection { - fn new(connection: C, shutdown_in_progress: &Arc, event: &Arc) -> Self { + fn new(connection: C, cancellation_token: CancellationToken) -> Self { Self { connection, - shutdown_listener: event.listen(), - shutdown_in_progress: Arc::clone(shutdown_in_progress), + cancellation_future: cancellation_token.cancelled_owned(), did_start_shutdown: false, } } @@ -286,19 +290,11 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); - // Poll the shutdown signal, so that wakers get registered. 
- // XXX: I don't think we care about the result of this poll, since it's only - // really to register wakers. But I'm not sure if it's safe to - // ignore the result. - let _ = this.shutdown_listener.poll(cx); - - if !*this.did_start_shutdown - && this - .shutdown_in_progress - .load(std::sync::atomic::Ordering::Relaxed) - { - *this.did_start_shutdown = true; - this.connection.as_mut().graceful_shutdown(); + if let Poll::Ready(()) = this.cancellation_future.poll(cx) { + if !*this.did_start_shutdown { + *this.did_start_shutdown = true; + this.connection.as_mut().graceful_shutdown(); + } } this.connection.poll(cx) @@ -306,16 +302,17 @@ where } #[allow(clippy::too_many_lines)] -pub async fn run_servers(listeners: impl IntoIterator>, mut shutdown: SD) -where +pub async fn run_servers( + listeners: impl IntoIterator>, + soft_shutdown_token: CancellationToken, + hard_shutdown_token: CancellationToken, +) where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: std::error::Error + Send + Sync + 'static, B: http_body::Body + Send + 'static, B::Data: Send, B::Error: std::error::Error + Send + Sync + 'static, - SD: Stream + Unpin, - SD::Item: std::fmt::Display, { // Create a stream of accepted connections out of the listeners let mut accept_stream: SelectAll<_> = listeners @@ -344,19 +341,13 @@ where // A JoinSet which collects connections that are being served let mut connection_tasks = tokio::task::JoinSet::new(); - // A shared atomic boolean to tell all connections to shutdown - let shutdown_in_progress = Arc::new(AtomicBool::new(false)); - let shutdown_event = Arc::new(Event::new()); - loop { tokio::select! { biased; // First look for the shutdown signal - res = shutdown.next() => { - let why = res.map_or_else(|| String::from("???"), |why| format!("{why}")); - tracing::info!("Received shutdown signal ({why})"); - + () = soft_shutdown_token.cancelled() => { + tracing::debug!("Shutting down listeners"); break; }, @@ -365,7 +356,7 @@ where match res { Some(Ok(Ok(connection))) => { tracing::trace!("Accepted connection"); - let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event); + let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token()); connection_tasks.spawn(conn); }, Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ }, @@ -385,9 +376,8 @@ where }, // Look for connections to accept - res = accept_stream.next(), if !accept_stream.is_empty() => { - // SAFETY: We shouldn't reach this branch if the stream set is empty - let Some(res) = res else { unreachable!() }; + res = accept_stream.next() => { + let Some(res) = res else { continue }; // Spawn the connection in the set, so we don't have to wait for the handshake to // accept the next connection. 
This allows us to keep track of active connections @@ -401,10 +391,6 @@ where }; } - // Tell the active connections to shutdown - shutdown_in_progress.store(true, std::sync::atomic::Ordering::Relaxed); - shutdown_event.notify(usize::MAX); - // Wait for connections to cleanup if !accept_tasks.is_empty() || !connection_tasks.is_empty() { tracing::info!( @@ -422,7 +408,7 @@ where match res { Some(Ok(Ok(connection))) => { tracing::trace!("Accepted connection"); - let conn = AbortableConnection::new(connection, &shutdown_in_progress, &shutdown_event); + let conn = AbortableConnection::new(connection, soft_shutdown_token.child_token()); connection_tasks.spawn(conn); } Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ }, @@ -441,11 +427,10 @@ where } }, - // Handle when we receive the shutdown signal again - res = shutdown.next() => { - let why = res.map_or_else(|| String::from("???"), |why| format!("{why}")); + // Handle when we are asked to hard shutdown + () = hard_shutdown_token.cancelled() => { tracing::warn!( - "Received shutdown signal again ({why}), forcing shutdown ({active} active connections, {pending} pending connections)", + "Forcing shutdown ({active} active connections, {pending} pending connections)", active = connection_tasks.len(), pending = accept_tasks.len(), ); @@ -457,5 +442,4 @@ where accept_tasks.shutdown().await; connection_tasks.shutdown().await; - tracing::info!("Shutdown complete"); } diff --git a/crates/listener/src/shutdown.rs b/crates/listener/src/shutdown.rs deleted file mode 100644 index 54775e439..000000000 --- a/crates/listener/src/shutdown.rs +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::{fmt::Display, pin::Pin, task::Poll, time::Duration}; - -use futures_util::{ready, Future, Stream}; -use tokio::{ - signal::unix::{signal, Signal, SignalKind}, - time::Sleep, -}; - -#[derive(Debug, Clone, Copy)] -pub enum ShutdownReason { - Signal(SignalKind), - Timeout, -} - -fn signal_to_str(kind: SignalKind) -> &'static str { - match kind.as_raw_value() { - libc::SIGALRM => "SIGALRM", - libc::SIGCHLD => "SIGCHLD", - libc::SIGHUP => "SIGHUP", - libc::SIGINT => "SIGINT", - libc::SIGIO => "SIGIO", - libc::SIGPIPE => "SIGPIPE", - libc::SIGQUIT => "SIGQUIT", - libc::SIGTERM => "SIGTERM", - libc::SIGUSR1 => "SIGUSR1", - libc::SIGUSR2 => "SIGUSR2", - _ => "SIG???", - } -} - -impl Display for ShutdownReason { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Signal(s) => signal_to_str(*s).fmt(f), - Self::Timeout => "timeout".fmt(f), - } - } -} - -pub enum ShutdownStreamState { - Waiting, - - Graceful { sleep: Option>> }, - - Done, -} - -impl Default for ShutdownStreamState { - fn default() -> Self { - Self::Waiting - } -} - -impl ShutdownStreamState { - fn is_graceful(&self) -> bool { - matches!(self, Self::Graceful { .. }) - } - - fn is_done(&self) -> bool { - matches!(self, Self::Done) - } - - fn get_sleep_mut(&mut self) -> Option<&mut Pin>> { - match self { - Self::Graceful { sleep } => sleep.as_mut(), - _ => None, - } - } -} - -/// A stream which is used to drive a graceful shutdown. -/// -/// It will emit 2 items: one when a first signal is caught, the other when -/// either another signal is caught, or after a timeout. 
-#[derive(Default)] -pub struct ShutdownStream { - state: ShutdownStreamState, - signals: Vec<(SignalKind, Signal)>, - timeout: Option, -} - -impl ShutdownStream { - /// Create a default shutdown stream, which listens on SIGINT and SIGTERM, - /// with a 60s timeout - /// - /// # Errors - /// - /// Returns an error if signal handlers could not be installed - pub fn new() -> Result { - let ret = Self::default() - .with_timeout(Duration::from_secs(60)) - .with_signal(SignalKind::interrupt())? - .with_signal(SignalKind::terminate())?; - - Ok(ret) - } - - /// Add a signal to register - /// - /// # Errors - /// - /// Returns an error if the signal handler could not be installed - pub fn with_signal(mut self, kind: SignalKind) -> Result { - let signal = signal(kind)?; - self.signals.push((kind, signal)); - Ok(self) - } - - #[must_use] - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.timeout = Some(timeout); - self - } -} - -impl Stream for ShutdownStream { - type Item = ShutdownReason; - - fn size_hint(&self) -> (usize, Option) { - match self.state { - ShutdownStreamState::Waiting => (2, Some(2)), - ShutdownStreamState::Graceful { .. } => (1, Some(1)), - ShutdownStreamState::Done => (0, Some(0)), - } - } - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.get_mut(); - - if this.state.is_done() { - return Poll::Ready(None); - } - - for (kind, signal) in &mut this.signals { - match signal.poll_recv(cx) { - Poll::Ready(_) => { - // We got a signal - if this.state.is_graceful() { - // If we was gracefully shutting down, mark it as done - this.state = ShutdownStreamState::Done; - } else { - // Else start the timeout - let sleep = this - .timeout - .map(|duration| Box::pin(tokio::time::sleep(duration))); - this.state = ShutdownStreamState::Graceful { sleep }; - } - - return Poll::Ready(Some(ShutdownReason::Signal(*kind))); - } - Poll::Pending => {} - } - } - - if let Some(timeout) = this.state.get_sleep_mut() { - ready!(timeout.as_mut().poll(cx)); - this.state = ShutdownStreamState::Done; - Poll::Ready(Some(ShutdownReason::Timeout)) - } else { - Poll::Pending - } - } -}
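For code that embeds `mas_listener` directly (the way the demo above does), the deleted `ShutdownStream` argument is replaced by two cancellation tokens. A minimal sketch of the wiring, assuming `server` is an already-built `Server` value; note that the soft token is derived from the hard one, so a hard cancel always implies a soft cancel:

let hard_shutdown = tokio_util::sync::CancellationToken::new();
// Cancelling the hard token also cancels the soft one, never the other way around.
let soft_shutdown = hard_shutdown.child_token();

// Keep the originals around so something (a signal handler, a test, ...) can
// cancel them later. The first argument is the soft token (stop accepting,
// drain connections), the second is the hard token (abort whatever remains).
server.run(soft_shutdown.clone(), hard_shutdown.clone()).await;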