Skip to content

Commit 990b992

Browse files
committed
Shut down the server if any of the tasks crashes
1 parent f2221d3 commit 990b992

File tree

7 files changed

+54
-20
lines changed

7 files changed

+54
-20
lines changed

crates/cli/src/commands/server.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ impl Options {
311311
shutdown.hard_shutdown_token(),
312312
));
313313

314-
shutdown.run().await;
314+
let exit_code = shutdown.run().await;
315315

316-
Ok(ExitCode::SUCCESS)
316+
Ok(exit_code)
317317
}
318318
}

crates/cli/src/shutdown.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// SPDX-License-Identifier: AGPL-3.0-only
44
// Please see LICENSE in the repository root for full details.
55

6-
use std::time::Duration;
6+
use std::{process::ExitCode, time::Duration};
77

88
use tokio::signal::unix::{Signal, SignalKind};
99
use tokio_util::{sync::CancellationToken, task::TaskTracker};
@@ -74,14 +74,22 @@ impl ShutdownManager {
7474
}
7575

7676
/// Run until we finish completely shutting down.
77-
pub async fn run(mut self) {
77+
pub async fn run(mut self) -> ExitCode {
7878
// Wait for a first signal and trigger the soft shutdown
79-
tokio::select! {
79+
let likely_crashed = tokio::select! {
80+
() = self.soft_shutdown_token.cancelled() => {
81+
tracing::warn!("Another task triggered a shutdown, it likely crashed! Shutting down");
82+
true
83+
},
84+
8085
_ = self.sigterm.recv() => {
8186
tracing::info!("Shutdown signal received (SIGTERM), shutting down");
87+
false
8288
},
89+
8390
_ = self.sigint.recv() => {
8491
tracing::info!("Shutdown signal received (SIGINT), shutting down");
92+
false
8593
},
8694
};
8795

@@ -112,5 +120,11 @@ impl ShutdownManager {
112120
self.task_tracker().wait().await;
113121

114122
tracing::info!("All tasks are done, exitting");
123+
124+
if likely_crashed {
125+
ExitCode::FAILURE
126+
} else {
127+
ExitCode::SUCCESS
128+
}
115129
}
116130
}

crates/handlers/src/activity_tracker/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ impl ActivityTracker {
182182
interval: std::time::Duration,
183183
cancellation_token: CancellationToken,
184184
) {
185+
// This guard on the shutdown token is to ensure that if this task crashes for
186+
// any reason, the server will shut down
187+
let _guard = cancellation_token.clone().drop_guard();
188+
185189
loop {
186190
tokio::select! {
187191
biased;

crates/handlers/src/activity_tracker/worker.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ impl Worker {
9393
mut receiver: tokio::sync::mpsc::Receiver<Message>,
9494
cancellation_token: CancellationToken,
9595
) {
96+
// This guard on the shutdown token is to ensure that if this task crashes for
97+
// any reason, the server will shut down
98+
let _guard = cancellation_token.clone().drop_guard();
99+
96100
loop {
97101
let message = tokio::select! {
98102
// Because we want the cancellation token to trigger only once,

crates/listener/src/server.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,10 @@ pub async fn run_servers<S, B>(
314314
B::Data: Send,
315315
B::Error: std::error::Error + Send + Sync + 'static,
316316
{
317+
// This guard on the shutdown token is to ensure that if this task crashes for
318+
// any reason, the server will shut down
319+
let _guard = soft_shutdown_token.clone().drop_guard();
320+
317321
// Create a stream of accepted connections out of the listeners
318322
let mut accept_stream: SelectAll<_> = listeners
319323
.into_iter()
@@ -360,7 +364,7 @@ pub async fn run_servers<S, B>(
360364
connection_tasks.spawn(conn);
361365
},
362366
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
363-
Some(Err(e)) => tracing::error!("Join error: {e}"),
367+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
364368
None => tracing::error!("Join set was polled even though it was empty"),
365369
}
366370
},
@@ -369,8 +373,8 @@ pub async fn run_servers<S, B>(
369373
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
370374
match res {
371375
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
372-
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
373-
Some(Err(e)) => tracing::error!("Join error: {e}"),
376+
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
377+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
374378
None => tracing::error!("Join set was polled even though it was empty"),
375379
}
376380
},
@@ -412,7 +416,7 @@ pub async fn run_servers<S, B>(
412416
connection_tasks.spawn(conn);
413417
}
414418
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
415-
Some(Err(e)) => tracing::error!("Join error: {e}"),
419+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
416420
None => tracing::error!("Join set was polled even though it was empty"),
417421
}
418422
},
@@ -421,8 +425,8 @@ pub async fn run_servers<S, B>(
421425
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
422426
match res {
423427
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
424-
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
425-
Some(Err(e)) => tracing::error!("Join error: {e}"),
428+
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
429+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
426430
None => tracing::error!("Join set was polled even though it was empty"),
427431
}
428432
},

crates/tasks/src/lib.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,7 @@ pub async fn init(
125125
mas_storage::queue::CleanupExpiredTokensJob,
126126
);
127127

128-
task_tracker.spawn(async move {
129-
if let Err(e) = worker.run().await {
130-
tracing::error!(
131-
error = &e as &dyn std::error::Error,
132-
"Failed to run new queue"
133-
);
134-
}
135-
});
128+
task_tracker.spawn(worker.run());
136129

137130
Ok(())
138131
}

crates/tasks/src/new_queue.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ pub struct QueueWorker {
190190
am_i_leader: bool,
191191
last_heartbeat: DateTime<Utc>,
192192
cancellation_token: CancellationToken,
193+
cancellation_guard: tokio_util::sync::DropGuard,
193194
state: State,
194195
schedules: Vec<ScheduleDefinition>,
195196
tracker: JobTracker,
@@ -240,6 +241,10 @@ impl QueueWorker {
240241
tracing::info!("Registered worker");
241242
let now = clock.now();
242243

244+
// We put a cancellation drop guard in the structure, so that when it gets
245+
// dropped, we're sure to cancel the token
246+
let cancellation_guard = cancellation_token.clone().drop_guard();
247+
243248
Ok(Self {
244249
rng,
245250
clock,
@@ -248,6 +253,7 @@ impl QueueWorker {
248253
am_i_leader: false,
249254
last_heartbeat: now,
250255
cancellation_token,
256+
cancellation_guard,
251257
state,
252258
schedules: Vec::new(),
253259
tracker: JobTracker::default(),
@@ -285,7 +291,16 @@ impl QueueWorker {
285291
self
286292
}
287293

288-
pub async fn run(&mut self) -> Result<(), QueueRunnerError> {
294+
/// Runs the queue worker, taking ownership of it.
///
/// Delegates to `run_inner` and, if it returns an error, logs that error
/// instead of propagating it, so this future always resolves to `()`.
/// Because `self` is taken by value, the worker (including its
/// `cancellation_guard` field) is dropped once this function returns.
pub async fn run(mut self) {
295+
if let Err(e) = self.run_inner().await {
296+
tracing::error!(
297+
error = &e as &dyn std::error::Error,
298+
"Failed to run new queue"
299+
);
300+
}
301+
}
302+
303+
async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
289304
self.setup_schedules().await?;
290305

291306
while !self.cancellation_token.is_cancelled() {

0 commit comments

Comments
 (0)