Skip to content

Commit 88fb4ed

Browse files
authored
Shutdown the server if any of the tasks crashes (#3672)
1 parent a0b3a94 commit 88fb4ed

File tree

8 files changed, with 57 additions (+57) and 22 deletions (-22).

crates/cli/src/commands/server.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -311,8 +311,8 @@ impl Options {
311311
shutdown.hard_shutdown_token(),
312312
));
313313

314-
shutdown.run().await;
314+
let exit_code = shutdown.run().await;
315315

316-
Ok(ExitCode::SUCCESS)
316+
Ok(exit_code)
317317
}
318318
}

crates/cli/src/commands/worker.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,8 +80,8 @@ impl Options {
8080

8181
span.exit();
8282

83-
shutdown.run().await;
83+
let exit_code = shutdown.run().await;
8484

85-
Ok(ExitCode::SUCCESS)
85+
Ok(exit_code)
8686
}
8787
}

crates/cli/src/shutdown.rs

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
// SPDX-License-Identifier: AGPL-3.0-only
44
// Please see LICENSE in the repository root for full details.
55

6-
use std::time::Duration;
6+
use std::{process::ExitCode, time::Duration};
77

88
use tokio::signal::unix::{Signal, SignalKind};
99
use tokio_util::{sync::CancellationToken, task::TaskTracker};
@@ -74,14 +74,22 @@ impl ShutdownManager {
7474
}
7575

7676
/// Run until we finish completely shutting down.
77-
pub async fn run(mut self) {
77+
pub async fn run(mut self) -> ExitCode {
7878
// Wait for a first signal and trigger the soft shutdown
79-
tokio::select! {
79+
let likely_crashed = tokio::select! {
80+
() = self.soft_shutdown_token.cancelled() => {
81+
tracing::warn!("Another task triggered a shutdown, it likely crashed! Shutting down");
82+
true
83+
},
84+
8085
_ = self.sigterm.recv() => {
8186
tracing::info!("Shutdown signal received (SIGTERM), shutting down");
87+
false
8288
},
89+
8390
_ = self.sigint.recv() => {
8491
tracing::info!("Shutdown signal received (SIGINT), shutting down");
92+
false
8593
},
8694
};
8795

@@ -112,5 +120,11 @@ impl ShutdownManager {
112120
self.task_tracker().wait().await;
113121

114122
tracing::info!("All tasks are done, exitting");
123+
124+
if likely_crashed {
125+
ExitCode::FAILURE
126+
} else {
127+
ExitCode::SUCCESS
128+
}
115129
}
116130
}

crates/handlers/src/activity_tracker/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -182,6 +182,10 @@ impl ActivityTracker {
182182
interval: std::time::Duration,
183183
cancellation_token: CancellationToken,
184184
) {
185+
// This guard on the shutdown token is to ensure that if this task crashes for
186+
// any reason, the server will shut down
187+
let _guard = cancellation_token.clone().drop_guard();
188+
185189
loop {
186190
tokio::select! {
187191
biased;

crates/handlers/src/activity_tracker/worker.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,10 @@ impl Worker {
9393
mut receiver: tokio::sync::mpsc::Receiver<Message>,
9494
cancellation_token: CancellationToken,
9595
) {
96+
// This guard on the shutdown token is to ensure that if this task crashes for
97+
// any reason, the server will shut down
98+
let _guard = cancellation_token.clone().drop_guard();
99+
96100
loop {
97101
let message = tokio::select! {
98102
// Because we want the cancellation token to trigger only once,

crates/listener/src/server.rs

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -314,6 +314,10 @@ pub async fn run_servers<S, B>(
314314
B::Data: Send,
315315
B::Error: std::error::Error + Send + Sync + 'static,
316316
{
317+
// This guard on the shutdown token is to ensure that if this task crashes for
318+
// any reason, the server will shut down
319+
let _guard = soft_shutdown_token.clone().drop_guard();
320+
317321
// Create a stream of accepted connections out of the listeners
318322
let mut accept_stream: SelectAll<_> = listeners
319323
.into_iter()
@@ -360,7 +364,7 @@ pub async fn run_servers<S, B>(
360364
connection_tasks.spawn(conn);
361365
},
362366
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
363-
Some(Err(e)) => tracing::error!("Join error: {e}"),
367+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
364368
None => tracing::error!("Join set was polled even though it was empty"),
365369
}
366370
},
@@ -369,8 +373,8 @@ pub async fn run_servers<S, B>(
369373
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
370374
match res {
371375
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
372-
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
373-
Some(Err(e)) => tracing::error!("Join error: {e}"),
376+
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
377+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
374378
None => tracing::error!("Join set was polled even though it was empty"),
375379
}
376380
},
@@ -412,7 +416,7 @@ pub async fn run_servers<S, B>(
412416
connection_tasks.spawn(conn);
413417
}
414418
Some(Ok(Err(_e))) => { /* Connection did not finish handshake, error should be logged in `accept` */ },
415-
Some(Err(e)) => tracing::error!("Join error: {e}"),
419+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
416420
None => tracing::error!("Join set was polled even though it was empty"),
417421
}
418422
},
@@ -421,8 +425,8 @@ pub async fn run_servers<S, B>(
421425
res = connection_tasks.join_next(), if !connection_tasks.is_empty() => {
422426
match res {
423427
Some(Ok(Ok(()))) => tracing::trace!("Connection finished"),
424-
Some(Ok(Err(e))) => tracing::error!("Error while serving connection: {e}"),
425-
Some(Err(e)) => tracing::error!("Join error: {e}"),
428+
Some(Ok(Err(e))) => tracing::error!(error = &*e as &dyn std::error::Error, "Error while serving connection"),
429+
Some(Err(e)) => tracing::error!(error = &e as &dyn std::error::Error, "Join error"),
426430
None => tracing::error!("Join set was polled even though it was empty"),
427431
}
428432
},

crates/tasks/src/lib.rs

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -133,14 +133,7 @@ pub async fn init(
133133
mas_storage::queue::CleanupExpiredTokensJob,
134134
);
135135

136-
task_tracker.spawn(async move {
137-
if let Err(e) = worker.run().await {
138-
tracing::error!(
139-
error = &e as &dyn std::error::Error,
140-
"Failed to run new queue"
141-
);
142-
}
143-
});
136+
task_tracker.spawn(worker.run());
144137

145138
Ok(())
146139
}

crates/tasks/src/new_queue.rs

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,8 @@ pub struct QueueWorker {
200200
am_i_leader: bool,
201201
last_heartbeat: DateTime<Utc>,
202202
cancellation_token: CancellationToken,
203+
#[expect(dead_code, reason = "This is used on Drop")]
204+
cancellation_guard: tokio_util::sync::DropGuard,
203205
state: State,
204206
schedules: Vec<ScheduleDefinition>,
205207
tracker: JobTracker,
@@ -269,6 +271,10 @@ impl QueueWorker {
269271
)
270272
.build();
271273

274+
// We put a cancellation drop guard in the structure, so that when it gets
275+
// dropped, we're sure to cancel the token
276+
let cancellation_guard = cancellation_token.clone().drop_guard();
277+
272278
Ok(Self {
273279
rng,
274280
clock,
@@ -277,6 +283,7 @@ impl QueueWorker {
277283
am_i_leader: false,
278284
last_heartbeat: now,
279285
cancellation_token,
286+
cancellation_guard,
280287
state,
281288
schedules: Vec::new(),
282289
tracker: JobTracker::new(),
@@ -316,7 +323,16 @@ impl QueueWorker {
316323
self
317324
}
318325

319-
pub async fn run(&mut self) -> Result<(), QueueRunnerError> {
326+
pub async fn run(mut self) {
327+
if let Err(e) = self.run_inner().await {
328+
tracing::error!(
329+
error = &e as &dyn std::error::Error,
330+
"Failed to run new queue"
331+
);
332+
}
333+
}
334+
335+
async fn run_inner(&mut self) -> Result<(), QueueRunnerError> {
320336
self.setup_schedules().await?;
321337

322338
while !self.cancellation_token.is_cancelled() {

0 commit comments

Comments
 (0)