Skip to content

Commit 2588eed

Browse files
committed
fix: gateway shutdown cancellation
1 parent 0a6ef27 commit 2588eed

File tree

5 files changed

+123
-27
lines changed

5 files changed

+123
-27
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

landscape-gateway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ landscape-common = { path = "../landscape-common" }
88
landscape-database = { path = "../landscape-database" }
99
pingora = { workspace = true }
1010
tokio = { workspace = true }
11+
tokio-util = { workspace = true }
1112
tokio-rustls = { workspace = true }
1213
rustls = { workspace = true }
1314
tracing = { workspace = true }

landscape-gateway/src/lib.rs

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use tokio::runtime::Builder as RuntimeBuilder;
2828
use tokio::sync::watch;
2929
use tokio::task::JoinSet;
3030
use tokio_rustls::{server::TlsStream as TokioTlsStream, TlsAcceptor};
31+
use tokio_util::sync::CancellationToken;
3132

3233
use crate::sni_proxy::{parse_sni_from_client_hello, proxy_tls_passthrough, SniProxyRouter};
3334

@@ -41,11 +42,16 @@ pub struct GatewayTlsConfig {
4142
pub struct GatewayManager {
4243
rules: SharedRules,
4344
status: WatchService,
44-
state: Mutex<Option<JoinHandle<()>>>,
45+
state: Mutex<Option<GatewayRuntimeState>>,
4546
config: GatewayRuntimeConfig,
4647
tls_config: Option<GatewayTlsConfig>,
4748
}
4849

50+
struct GatewayRuntimeState {
51+
thread: JoinHandle<()>,
52+
cancel: CancellationToken,
53+
}
54+
4955
impl GatewayManager {
5056
pub fn new(
5157
initial_rules: Vec<HttpUpstreamRuleConfig>,
@@ -84,12 +90,22 @@ impl GatewayManager {
8490
let https_port = self.config.https_port;
8591
let tls_config = self.tls_config.clone();
8692
let status = self.status.clone();
87-
88-
let handle = std::thread::spawn(move || {
89-
run_pingora_server(rules, http_port, https_port, tls_config, status);
93+
let cancel = CancellationToken::new();
94+
let thread_cancel = cancel.clone();
95+
96+
let thread = std::thread::spawn(move || {
97+
run_pingora_server(
98+
rules,
99+
http_port,
100+
https_port,
101+
tls_config,
102+
status.clone(),
103+
thread_cancel,
104+
);
105+
status.just_change_status(ServiceStatus::Stop);
90106
});
91107

92-
*state = Some(handle);
108+
*state = Some(GatewayRuntimeState { thread, cancel });
93109
self.status.just_change_status(ServiceStatus::Running);
94110
if self.tls_config.is_some() {
95111
tracing::info!(
@@ -113,17 +129,19 @@ impl GatewayManager {
113129
}
114130
tracing::info!("Signalling gateway to stop...");
115131
self.status.just_change_status(ServiceStatus::Stopping);
132+
if let Some(state) = self.state.lock().unwrap().as_ref() {
133+
state.cancel.cancel();
134+
}
116135
}
117136

118137
/// Block until the Pingora thread has exited. Call after shutdown().
119138
pub fn join(&self) {
120139
let mut state = self.state.lock().unwrap();
121-
if let Some(handle) = state.take() {
140+
if let Some(runtime_state) = state.take() {
122141
tracing::info!("Waiting for gateway thread to finish...");
123-
if let Err(e) = handle.join() {
142+
if let Err(e) = runtime_state.thread.join() {
124143
tracing::error!("Gateway thread panicked: {:?}", e);
125144
}
126-
self.status.just_change_status(ServiceStatus::Stop);
127145
tracing::info!("Gateway stopped");
128146
}
129147
}
@@ -152,7 +170,17 @@ impl GatewayManager {
152170
impl Drop for GatewayManager {
153171
fn drop(&mut self) {
154172
self.shutdown();
155-
self.join();
173+
if self.status.is_stop() {
174+
self.join();
175+
return;
176+
}
177+
178+
if let Some(runtime_state) = self.state.lock().unwrap().take() {
179+
runtime_state.cancel.cancel();
180+
tracing::warn!(
181+
"Dropping gateway manager before gateway thread fully stopped; detaching thread"
182+
);
183+
}
156184
}
157185
}
158186

@@ -162,6 +190,7 @@ fn run_pingora_server(
162190
https_port: u16,
163191
tls_config: Option<GatewayTlsConfig>,
164192
status: WatchService,
193+
cancel: CancellationToken,
165194
) {
166195
use pingora::server::Server;
167196
use proxy_service::LandscapeReverseProxy;
@@ -178,8 +207,9 @@ fn run_pingora_server(
178207
let https_handle = tls_config.map(|tls_config| {
179208
let rules = rules.clone();
180209
let status = status.clone();
210+
let cancel = cancel.child_token();
181211
std::thread::spawn(move || {
182-
run_https_server(rules, https_port, tls_config, server_conf, status);
212+
run_https_server(rules, https_port, tls_config, server_conf, status, cancel);
183213
})
184214
});
185215

@@ -201,6 +231,7 @@ fn run_https_server(
201231
tls_config: GatewayTlsConfig,
202232
server_conf: Arc<pingora::server::configuration::ServerConf>,
203233
status: WatchService,
234+
cancel: CancellationToken,
204235
) {
205236
let runtime = RuntimeBuilder::new_multi_thread()
206237
.enable_all()
@@ -210,7 +241,7 @@ fn run_https_server(
210241

211242
runtime.block_on(async move {
212243
if let Err(e) =
213-
run_https_server_inner(rules, https_port, tls_config, server_conf, status).await
244+
run_https_server_inner(rules, https_port, tls_config, server_conf, status, cancel).await
214245
{
215246
tracing::error!("Gateway HTTPS listener exited with error: {e}");
216247
}
@@ -223,6 +254,7 @@ async fn run_https_server_inner(
223254
tls_config: GatewayTlsConfig,
224255
server_conf: Arc<pingora::server::configuration::ServerConf>,
225256
status: WatchService,
257+
cancel: CancellationToken,
226258
) -> std::io::Result<()> {
227259
use proxy_service::LandscapeReverseProxy;
228260

@@ -245,6 +277,10 @@ async fn run_https_server_inner(
245277
break;
246278
}
247279
}
280+
_ = cancel.cancelled() => {
281+
let _ = shutdown_tx.send(true);
282+
break;
283+
}
248284
accept_result = listener.accept() => {
249285
let (stream, peer_addr) = match accept_result {
250286
Ok(pair) => pair,
@@ -258,11 +294,15 @@ async fn run_https_server_inner(
258294
let app = app.clone();
259295
let sni_proxy_router = sni_proxy_router.clone();
260296
let connection_shutdown = shutdown_rx.clone();
297+
let connection_cancel = cancel.child_token();
261298

262299
tasks.spawn(async move {
263300
if sni_proxy_router.has_sni_proxy_rules() {
264301
let mut peek_buf = vec![0u8; 4096];
265-
match stream.peek(&mut peek_buf).await {
302+
match tokio::select! {
303+
_ = connection_cancel.cancelled() => return,
304+
result = stream.peek(&mut peek_buf) => result,
305+
} {
266306
Ok(size) if size > 0 => {
267307
if let Some(sni) = parse_sni_from_client_hello(&peek_buf[..size]) {
268308
if let Some(target) = sni_proxy_router.match_target(&sni) {
@@ -273,7 +313,7 @@ async fn run_https_server_inner(
273313
target.target.address,
274314
target.target.port
275315
);
276-
if let Err(e) = proxy_tls_passthrough(stream, &target).await {
316+
if let Err(e) = proxy_tls_passthrough(stream, &target, connection_cancel.clone()).await {
277317
tracing::warn!(
278318
"Gateway TLS passthrough failed for '{}' via rule '{}': {}",
279319
target.sni,
@@ -293,7 +333,10 @@ async fn run_https_server_inner(
293333
}
294334
}
295335

296-
let tls_result = tokio::time::timeout(Duration::from_secs(60), acceptor.accept(stream)).await;
336+
let tls_result = tokio::select! {
337+
_ = connection_cancel.cancelled() => return,
338+
result = tokio::time::timeout(Duration::from_secs(60), acceptor.accept(stream)) => result,
339+
};
297340
let tls_stream = match tls_result {
298341
Ok(Ok(stream)) => stream,
299342
Ok(Err(e)) => {
@@ -307,7 +350,10 @@ async fn run_https_server_inner(
307350
};
308351

309352
let stream: Stream = Box::new(GatewayTlsStream::new(tls_stream));
310-
let _ = app.process_new(stream, &connection_shutdown).await;
353+
tokio::select! {
354+
_ = connection_cancel.cancelled() => {}
355+
_ = app.process_new(stream, &connection_shutdown) => {}
356+
}
311357
});
312358
}
313359
}

landscape-gateway/src/service.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,21 @@ impl GatewayService {
6161
/// to exit cleanly, without blocking the async runtime.
6262
pub async fn shutdown_and_wait(&self, timeout: std::time::Duration) {
6363
self.manager.shutdown();
64-
let manager = self.manager.clone();
65-
let join_task = tokio::task::spawn_blocking(move || {
66-
manager.join();
67-
});
68-
match tokio::time::timeout(timeout, join_task).await {
69-
Ok(Ok(())) => tracing::info!("Gateway thread exited cleanly."),
70-
Ok(Err(e)) => tracing::error!("Gateway join task panicked: {:?}", e),
64+
let mut status_rx = self.manager.watch_service().subscribe();
65+
if matches!(*status_rx.borrow(), ServiceStatus::Stop) {
66+
tracing::info!("Gateway thread exited cleanly.");
67+
return;
68+
}
69+
70+
let wait_result = tokio::time::timeout(
71+
timeout,
72+
status_rx.wait_for(|status| matches!(status, ServiceStatus::Stop)),
73+
)
74+
.await;
75+
76+
match wait_result {
77+
Ok(Ok(_)) => tracing::info!("Gateway thread exited cleanly."),
78+
Ok(Err(e)) => tracing::error!("Gateway status watch closed during shutdown: {:?}", e),
7179
Err(_) => tracing::warn!(
7280
"Gateway did not stop within {}s timeout, proceeding.",
7381
timeout.as_secs()

landscape-gateway/src/sni_proxy.rs

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ use std::sync::atomic::{AtomicUsize, Ordering};
33
use landscape_common::config::gateway::{
44
HttpUpstreamMatchRule, HttpUpstreamRuleConfig, HttpUpstreamTarget, LoadBalanceMethod,
55
};
6-
use tokio::io;
6+
use tokio::io::{self, AsyncWriteExt};
77
use tokio::net::TcpStream;
8+
use tokio_util::sync::CancellationToken;
89

910
use crate::SharedRules;
1011

@@ -84,14 +85,53 @@ impl MatchedSniTarget {
8485
}
8586

8687
pub async fn proxy_tls_passthrough(
87-
mut downstream: TcpStream,
88+
downstream: TcpStream,
8889
target: &MatchedSniTarget,
90+
cancel: CancellationToken,
8991
) -> io::Result<()> {
9092
let upstream_addr = format!("{}:{}", target.target.address, target.target.port);
91-
let mut upstream = TcpStream::connect(&upstream_addr).await?;
93+
let upstream = tokio::select! {
94+
_ = cancel.cancelled() => return Ok(()),
95+
result = TcpStream::connect(&upstream_addr) => result?,
96+
};
97+
98+
let (mut downstream_read, mut downstream_write) = downstream.into_split();
99+
let (mut upstream_read, mut upstream_write) = upstream.into_split();
100+
101+
let client_to_upstream = tokio::spawn(async move {
102+
let result = io::copy(&mut downstream_read, &mut upstream_write).await;
103+
let _ = upstream_write.shutdown().await;
104+
result
105+
});
106+
107+
let upstream_to_client = tokio::spawn(async move {
108+
let result = io::copy(&mut upstream_read, &mut downstream_write).await;
109+
let _ = downstream_write.shutdown().await;
110+
result
111+
});
112+
113+
tokio::pin!(client_to_upstream);
114+
tokio::pin!(upstream_to_client);
115+
116+
let result = tokio::select! {
117+
_ = cancel.cancelled() => Ok(()),
118+
result = &mut client_to_upstream => join_copy_task(result),
119+
result = &mut upstream_to_client => join_copy_task(result),
120+
};
121+
122+
client_to_upstream.abort();
123+
upstream_to_client.abort();
124+
125+
result
126+
}
92127

93-
let _ = io::copy_bidirectional(&mut downstream, &mut upstream).await?;
94-
Ok(())
128+
fn join_copy_task(result: Result<io::Result<u64>, tokio::task::JoinError>) -> io::Result<()> {
129+
match result {
130+
Ok(Ok(_)) => Ok(()),
131+
Ok(Err(e)) => Err(e),
132+
Err(e) if e.is_cancelled() => Ok(()),
133+
Err(e) => Err(io::Error::other(e)),
134+
}
95135
}
96136

97137
/// Parse the SNI (Server Name Indication) extension from a TLS ClientHello message.

0 commit comments

Comments
 (0)