Skip to content

Commit 22489ac

Browse files
committed
Use CancellationToken and a TaskTracker to handle graceful shutdowns
1 parent 1e1ec08 commit 22489ac

File tree

16 files changed

+310
-303
lines changed

16 files changed

+310
-303
lines changed

Cargo.lock

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ version = "1.0.64"
259259
version = "1.40.0"
260260
features = ["full"]
261261

262+
# Useful async utilities
263+
[workspace.dependencies.tokio-util]
264+
version = "0.7.12"
265+
features = ["rt"]
266+
262267
# Tower services
263268
[workspace.dependencies.tower]
264269
version = "0.5.1"

crates/cli/Cargo.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ serde_json.workspace = true
3636
serde_yaml = "0.9.34"
3737
sqlx.workspace = true
3838
tokio.workspace = true
39+
tokio-util.workspace = true
3940
tower.workspace = true
4041
tower-http.workspace = true
4142
url.workspace = true
@@ -48,12 +49,20 @@ tracing-opentelemetry.workspace = true
4849
opentelemetry.workspace = true
4950
opentelemetry-http.workspace = true
5051
opentelemetry-jaeger-propagator = "0.3.0"
51-
opentelemetry-otlp = { version = "0.17.0", default-features = false, features = ["trace", "metrics", "http-proto"] }
52+
opentelemetry-otlp = { version = "0.17.0", default-features = false, features = [
53+
"trace",
54+
"metrics",
55+
"http-proto",
56+
] }
5257
opentelemetry-prometheus = "0.17.0"
5358
opentelemetry-resource-detectors = "0.3.0"
5459
opentelemetry-semantic-conventions.workspace = true
5560
opentelemetry-stdout = { version = "0.5.0", features = ["trace", "metrics"] }
56-
opentelemetry_sdk = { version = "0.24.1", features = ["trace", "metrics", "rt-tokio"] }
61+
opentelemetry_sdk = { version = "0.24.1", features = [
62+
"trace",
63+
"metrics",
64+
"rt-tokio",
65+
] }
5766
prometheus = "0.13.4"
5867
sentry.workspace = true
5968
sentry-tracing.workspace = true

crates/cli/src/commands/server.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use mas_config::{
1414
AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config,
1515
};
1616
use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache};
17-
use mas_listener::{server::Server, shutdown::ShutdownStream};
17+
use mas_listener::server::Server;
1818
use mas_matrix_synapse::SynapseConnection;
1919
use mas_router::UrlBuilder;
2020
use mas_storage::SystemClock;
@@ -24,11 +24,11 @@ use rand::{
2424
thread_rng,
2525
};
2626
use sqlx::migrate::Migrate;
27-
use tokio::signal::unix::SignalKind;
2827
use tracing::{info, info_span, warn, Instrument};
2928

3029
use crate::{
3130
app_state::AppState,
31+
shutdown::ShutdownManager,
3232
util::{
3333
database_pool_from_config, mailer_from_config, password_manager_from_config,
3434
policy_factory_from_config, register_sighup, site_config_from_config,
@@ -61,6 +61,7 @@ impl Options {
6161
#[allow(clippy::too_many_lines)]
6262
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
6363
let span = info_span!("cli.run.init").entered();
64+
let shutdown = ShutdownManager::new()?;
6465
let config = AppConfig::extract(figment)?;
6566

6667
if self.migrate {
@@ -173,8 +174,21 @@ impl Options {
173174
url_builder.clone(),
174175
)
175176
.await?;
176-
// TODO: grab the handle
177-
tokio::spawn(monitor.run());
177+
178+
// XXX: The monitor from apalis is a bit annoying to use for graceful shutdowns,
179+
// ideally we'd just give it a cancellation token
180+
let shutdown_future = shutdown.soft_shutdown_token().cancelled_owned();
181+
shutdown.task_tracker().spawn(async move {
182+
if let Err(e) = monitor
183+
.run_with_signal(async move {
184+
shutdown_future.await;
185+
Ok(())
186+
})
187+
.await
188+
{
189+
tracing::error!(error = &e as &dyn std::error::Error, "Task worker failed");
190+
}
191+
});
178192
}
179193

180194
let listeners_config = config.http.listeners.clone();
@@ -186,7 +200,12 @@ impl Options {
186200

187201
// Initialize the activity tracker
188202
// Activity is flushed every minute
189-
let activity_tracker = ActivityTracker::new(pool.clone(), Duration::from_secs(60));
203+
let activity_tracker = ActivityTracker::new(
204+
pool.clone(),
205+
Duration::from_secs(60),
206+
shutdown.task_tracker(),
207+
shutdown.soft_shutdown_token(),
208+
);
190209
let trusted_proxies = config.http.trusted_proxies.clone();
191210

192211
// Build a rate limiter.
@@ -302,16 +321,17 @@ impl Options {
302321
.flatten_ok()
303322
.collect::<Result<Vec<_>, _>>()?;
304323

305-
let shutdown = ShutdownStream::default()
306-
.with_timeout(Duration::from_secs(60))
307-
.with_signal(SignalKind::terminate())?
308-
.with_signal(SignalKind::interrupt())?;
309-
310324
span.exit();
311325

312-
mas_listener::server::run_servers(servers, shutdown).await;
326+
shutdown
327+
.task_tracker()
328+
.spawn(mas_listener::server::run_servers(
329+
servers,
330+
shutdown.soft_shutdown_token(),
331+
shutdown.hard_shutdown_token(),
332+
));
313333

314-
state.activity_tracker.shutdown().await;
334+
shutdown.run().await;
315335

316336
Ok(ExitCode::SUCCESS)
317337
}

crates/cli/src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod app_state;
2222
mod commands;
2323
mod sentry_transport;
2424
mod server;
25+
mod shutdown;
2526
mod sync;
2627
mod telemetry;
2728
mod util;

crates/cli/src/shutdown.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright 2024 New Vector Ltd.
2+
//
3+
// SPDX-License-Identifier: AGPL-3.0-only
4+
// Please see LICENSE in the repository root for full details.
5+
6+
use std::time::Duration;
7+
8+
use tokio::signal::unix::{Signal, SignalKind};
9+
use tokio_util::{sync::CancellationToken, task::TaskTracker};
10+
11+
/// A helper to manage graceful shutdowns and track tasks that gracefully
12+
/// shutdown.
13+
///
14+
/// It will listen for SIGTERM and SIGINT signals, and will trigger a soft
15+
/// shutdown on the first signal, and a hard shutdown on the second signal or
16+
/// after a timeout.
17+
///
18+
/// Users of this manager should use the `soft_shutdown_token` to react to a
19+
/// soft shutdown, which should gracefully finish requests and close
20+
/// connections, and the `hard_shutdown_token` to react to a hard shutdown,
21+
/// which should drop all connections and finish all requests.
22+
///
23+
/// They should also use the `task_tracker` to make it track things running, so
24+
/// that it knows when the soft shutdown is over and worked.
25+
pub struct ShutdownManager {
26+
hard_shutdown_token: CancellationToken,
27+
soft_shutdown_token: CancellationToken,
28+
task_tracker: TaskTracker,
29+
sigterm: Signal,
30+
sigint: Signal,
31+
timeout: Duration,
32+
}
33+
34+
impl ShutdownManager {
35+
/// Create a new shutdown manager, installing the signal handlers
36+
///
37+
/// # Errors
38+
///
39+
/// Returns an error if the signal handler could not be installed
40+
pub fn new() -> Result<Self, std::io::Error> {
41+
let hard_shutdown_token = CancellationToken::new();
42+
let soft_shutdown_token = hard_shutdown_token.child_token();
43+
let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
44+
let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
45+
let timeout = Duration::from_secs(60);
46+
let task_tracker = TaskTracker::new();
47+
48+
Ok(Self {
49+
hard_shutdown_token,
50+
soft_shutdown_token,
51+
task_tracker,
52+
sigterm,
53+
sigint,
54+
timeout,
55+
})
56+
}
57+
58+
/// Get a reference to the task tracker
59+
#[must_use]
60+
pub fn task_tracker(&self) -> &TaskTracker {
61+
&self.task_tracker
62+
}
63+
64+
/// Get a cancellation token that can be used to react to a hard shutdown
65+
#[must_use]
66+
pub fn hard_shutdown_token(&self) -> CancellationToken {
67+
self.hard_shutdown_token.clone()
68+
}
69+
70+
/// Get a cancellation token that can be used to react to a soft shutdown
71+
#[must_use]
72+
pub fn soft_shutdown_token(&self) -> CancellationToken {
73+
self.soft_shutdown_token.clone()
74+
}
75+
76+
/// Run until we finish completely shutting down.
77+
pub async fn run(mut self) {
78+
// Wait for a first signal and trigger the soft shutdown
79+
tokio::select! {
80+
_ = self.sigterm.recv() => {
81+
tracing::info!("Shutdown signal received (SIGTERM), shutting down");
82+
},
83+
_ = self.sigint.recv() => {
84+
tracing::info!("Shutdown signal received (SIGINT), shutting down");
85+
},
86+
};
87+
88+
self.soft_shutdown_token.cancel();
89+
self.task_tracker.close();
90+
91+
// Start the timeout
92+
let timeout = tokio::time::sleep(self.timeout);
93+
tokio::select! {
94+
_ = self.sigterm.recv() => {
95+
tracing::warn!("Second shutdown signal received (SIGTERM), abort");
96+
},
97+
_ = self.sigint.recv() => {
98+
tracing::warn!("Second shutdown signal received (SIGINT), abort");
99+
},
100+
() = timeout => {
101+
tracing::warn!("Shutdown timeout reached, abort");
102+
},
103+
() = self.task_tracker.wait() => {
104+
// This is the "happy path", we have gracefully shutdown
105+
},
106+
}
107+
108+
self.hard_shutdown_token().cancel();
109+
110+
// TODO: we may want to have a time out on the task tracker, in case we have
111+
// really stuck tasks on it
112+
self.task_tracker().wait().await;
113+
114+
tracing::info!("All tasks are done, exitting");
115+
}
116+
}

crates/handlers/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ workspace = true
1414
[dependencies]
1515
# Async runtime
1616
tokio.workspace = true
17+
tokio-util.workspace = true
1718
futures-util = "0.3.31"
1819
async-trait.workspace = true
1920

0 commit comments

Comments
 (0)