Skip to content

Commit a2ef4a7

Browse files
authored
Ensure all forwarded TCP streams keep the proxy running (#786)
The proxy is not intended to shut down until all in-flight responses and TCP streams are complete. This is enforced by the `ServeHttp` middleware, which confusingly handles some TCP forwarding in addition to HTTP server initialization. But there are TCP streams that are not instrumented by this middleware and will therefore be abruptly interrupted when a shutdown signal is received. This change modifies the `ServeHttp` middleware to only handle HTTP streams. A `NewOptional` middleware is introduced to conditionally build a TCP stack when no HTTP type has been detected; and a `drain::Retain` middleware is added to prevent the drain signal from being released until TCP streams are completed.
1 parent ad1b358 commit a2ef4a7

File tree

12 files changed

+243
-171
lines changed

12 files changed

+243
-171
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,9 +1104,11 @@ version = "0.1.0"
11041104
dependencies = [
11051105
"futures 0.3.5",
11061106
"linkerd2-error",
1107+
"linkerd2-stack",
11071108
"pin-project 0.4.22",
11081109
"tokio 0.3.5",
11091110
"tokio-test 0.3.0",
1111+
"tower",
11101112
]
11111113

11121114
[[package]]

linkerd/app/inbound/src/lib.rs

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ impl Config {
106106
// Forwards TCP streams that cannot be decoded as HTTP.
107107
let tcp_forward = svc::stack(tcp_connect)
108108
.push_make_thunk()
109-
.push_on_response(svc::layer::mk(tcp::Forward::new))
109+
.push_on_response(
110+
svc::layers()
111+
.push(svc::layer::mk(tcp::Forward::new))
112+
.push(drain::Retain::layer(drain.clone())),
113+
)
110114
.instrument(|_: &_| debug_span!("tcp"))
111115
.into_inner();
112116

@@ -380,23 +384,21 @@ impl Config {
380384
.check_new_service::<(http::Version, TcpAccept), http::Request<_>>()
381385
.into_inner();
382386

383-
svc::stack(http::NewServeHttp::new(
384-
h2_settings,
385-
http_server,
386-
svc::stack(tcp_forward)
387-
.push_map_target(TcpEndpoint::from)
388-
.into_inner(),
389-
drain,
390-
))
391-
.check_new_clone::<(Option<http::Version>, TcpAccept)>()
392-
.push_cache(cache_max_idle_age)
393-
.push(transport::NewDetectService::layer(
394-
transport::detect::DetectTimeout::new(
395-
detect_protocol_timeout,
396-
http::DetectHttp::default(),
397-
),
398-
))
399-
.into_inner()
387+
svc::stack(http::NewServeHttp::new(h2_settings, http_server, drain))
388+
.push(svc::stack::NewOptional::layer(
389+
svc::stack(tcp_forward)
390+
.push_map_target(TcpEndpoint::from)
391+
.into_inner(),
392+
))
393+
.check_new_clone::<(Option<http::Version>, TcpAccept)>()
394+
.push_cache(cache_max_idle_age)
395+
.push(transport::NewDetectService::layer(
396+
transport::detect::DetectTimeout::new(
397+
detect_protocol_timeout,
398+
http::DetectHttp::default(),
399+
),
400+
))
401+
.into_inner()
400402
}
401403

402404
pub fn build_tls_accept<D, A, F, B>(

linkerd/app/outbound/src/ingress.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,8 @@ where
130130
))
131131
.into_inner();
132132

133-
svc::stack(http::NewServeHttp::new(h2_settings, http, tcp, drain))
134-
.check_new_service::<(Option<http::Version>, tcp::Accept), io::PrefixedIo<transport::metrics::SensorIo<I>>>()
135-
.check_new_clone::<(Option<http::Version>, tcp::Accept)>()
133+
svc::stack(http::NewServeHttp::new(h2_settings, http, drain))
134+
.push(svc::stack::NewOptional::layer(tcp))
136135
.push_cache(cache_max_idle_age)
137136
.push(transport::NewDetectService::layer(
138137
transport::detect::DetectTimeout::new(

linkerd/app/outbound/src/server.rs

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ where
5656
P::Future: Unpin + Send,
5757
P::Error: Send,
5858
{
59-
let tcp_balance = tcp::balance::stack(&config.proxy, tcp_connect.clone(), resolve);
59+
let tcp_balance =
60+
tcp::balance::stack(&config.proxy, tcp_connect.clone(), resolve, drain.clone());
6061
let accept = accept_stack(
6162
config,
6263
profiles,
@@ -206,36 +207,30 @@ where
206207

207208
// Load balances TCP streams that cannot be decoded as HTTP.
208209
let tcp_balance = svc::stack(tcp_balance)
209-
.push_map_target(tcp::Concrete::from)
210-
.push(profiles::split::layer())
211-
.check_new_service::<tcp::Logical, transport::io::PrefixedIo<transport::metrics::SensorIo<I>>>()
212-
.push_switch(tcp::Logical::should_resolve, tcp_forward)
213-
.push_on_response(
214-
svc::layers()
215-
.push_failfast(dispatch_timeout)
216-
.push_spawn_buffer(buffer_capacity),
217-
)
218-
.instrument(|_: &_| debug_span!("tcp"))
219-
.check_new_service::<tcp::Logical, transport::io::PrefixedIo<transport::metrics::SensorIo<I>>>()
220-
.into_inner();
210+
.push_map_target(tcp::Concrete::from)
211+
.push(profiles::split::layer())
212+
.check_new_service::<tcp::Logical, transport::io::PrefixedIo<transport::metrics::SensorIo<I>>>()
213+
.push_switch(tcp::Logical::should_resolve, tcp_forward)
214+
.push_on_response(
215+
svc::layers()
216+
.push_failfast(dispatch_timeout)
217+
.push_spawn_buffer(buffer_capacity),
218+
)
219+
.instrument(|_: &_| debug_span!("tcp"))
220+
.check_new_service::<tcp::Logical, transport::io::PrefixedIo<transport::metrics::SensorIo<I>>>()
221+
.into_inner();
221222

222-
let http = svc::stack(http::NewServeHttp::new(
223-
h2_settings,
224-
http_server,
225-
tcp_balance,
226-
drain,
227-
))
228-
.check_new_clone::<(Option<http::Version>, tcp::Logical)>()
229-
.check_new_service::<(Option<http::Version>, tcp::Logical), transport::io::PrefixedIo<transport::metrics::SensorIo<I>>>()
230-
.push_cache(cache_max_idle_age)
231-
.push(transport::NewDetectService::layer(
232-
transport::detect::DetectTimeout::new(
233-
detect_protocol_timeout,
234-
http::DetectHttp::default(),
235-
),
236-
))
237-
.check_new_service::<tcp::Logical, transport::metrics::SensorIo<I>>()
238-
.into_inner();
223+
let http = svc::stack(http::NewServeHttp::new(h2_settings, http_server, drain))
224+
.push(svc::stack::NewOptional::layer(tcp_balance))
225+
.push_cache(cache_max_idle_age)
226+
.push(transport::NewDetectService::layer(
227+
transport::detect::DetectTimeout::new(
228+
detect_protocol_timeout,
229+
http::DetectHttp::default(),
230+
),
231+
))
232+
.check_new_service::<tcp::Logical, transport::metrics::SensorIo<I>>()
233+
.into_inner();
239234

240235
let tcp = svc::stack(tcp::connect::forward(tcp_connect))
241236
.push_map_target(tcp::Endpoint::from_logical(

linkerd/app/outbound/src/tcp/balance.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::{Concrete, Endpoint};
22
use crate::resolve;
33
use linkerd2_app_core::{
44
config::ProxyConfig,
5+
drain,
56
proxy::{api_resolve::Metadata, core::Resolve, tcp},
67
svc,
78
transport::io,
@@ -14,6 +15,7 @@ pub fn stack<C, R, I>(
1415
config: &ProxyConfig,
1516
connect: C,
1617
resolve: R,
18+
drain: drain::Watch,
1719
) -> impl svc::NewService<
1820
Concrete,
1921
Service = impl tower::Service<
@@ -51,7 +53,8 @@ where
5153
crate::EWMA_DEFAULT_RTT,
5254
crate::EWMA_DECAY,
5355
))
54-
.push(svc::layer::mk(tcp::Forward::new)),
56+
.push(svc::layer::mk(tcp::Forward::new))
57+
.push(drain::Retain::layer(drain)),
5558
)
5659
.check_make_service::<Concrete, I>()
5760
.into_new_service()

linkerd/app/outbound/src/tcp/tests.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ async fn plaintext_tcp() {
6969
);
7070

7171
// Build the outbound TCP balancer stack.
72-
let forward = super::balance::stack(&cfg.proxy, connect, resolver).new_service(concrete);
72+
let (_, drain) = drain::channel();
73+
let forward = super::balance::stack(&cfg.proxy, connect, resolver, drain).new_service(concrete);
7374

7475
forward
7576
.oneshot(client_io)
@@ -150,7 +151,8 @@ async fn tls_when_hinted() {
150151
client_io.read(b"hello").write(b"world");
151152

152153
// Build the outbound TCP balancer stack.
153-
let mut balance = super::balance::stack(&cfg.proxy, connect, resolver);
154+
let (_, drain) = drain::channel();
155+
let mut balance = super::balance::stack(&cfg.proxy, connect, resolver, drain);
154156

155157
let plain = balance
156158
.new_service(plain_concrete)

linkerd/drain/Cargo.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ edition = "2018"
66
publish = false
77

88
[dependencies]
9+
futures = "0.3"
910
linkerd2-error = { path = "../error" }
10-
tokio = {version = "0.3", features = ["sync", "macros", "stream"]}
11+
linkerd2-stack = { path = "../stack" }
1112
pin-project = "0.4"
12-
futures = "0.3"
13+
tokio = { version = "0.3", features = ["sync", "macros", "stream"]}
14+
tower = { version = "0.4", default-features = false }
1315

1416
[dev-dependencies]
15-
tokio-test = "0.3"
17+
tokio-test = "0.3"

linkerd/drain/src/lib.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
#![deny(warnings, rust_2018_idioms)]
2+
3+
mod retain;
4+
5+
pub use crate::retain::Retain;
26
use linkerd2_error::Never;
37
use pin_project::pin_project;
4-
use std::future::Future;
5-
use std::pin::Pin;
6-
use std::task::{Context, Poll};
8+
use std::{
9+
future::Future,
10+
pin::Pin,
11+
task::{Context, Poll},
12+
};
713
use tokio::{
814
stream::Stream,
915
sync::{mpsc, watch},

linkerd/drain/src/retain.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use crate::Watch;
2+
use linkerd2_stack::layer;
3+
use std::{
4+
future::Future,
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
9+
/// Holds a drain::Watch for as long as a request is pending.
10+
#[derive(Clone, Debug)]
11+
pub struct Retain<S> {
12+
inner: S,
13+
drain: Watch,
14+
}
15+
16+
// === impl Retain ===
17+
18+
impl<S> Retain<S> {
19+
pub fn new(drain: Watch, inner: S) -> Self {
20+
Self { drain, inner }
21+
}
22+
23+
pub fn layer(drain: Watch) -> impl layer::Layer<S, Service = Self> + Clone {
24+
layer::mk(move |inner| Self::new(drain.clone(), inner))
25+
}
26+
}
27+
28+
impl<Req, S> tower::Service<Req> for Retain<S>
29+
where
30+
S: tower::Service<Req>,
31+
S::Future: Send + 'static,
32+
{
33+
type Response = S::Response;
34+
type Error = S::Error;
35+
type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send + 'static>>;
36+
37+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38+
self.inner.poll_ready(cx)
39+
}
40+
41+
fn call(&mut self, req: Req) -> Self::Future {
42+
Box::pin(
43+
self.drain
44+
.clone()
45+
.ignore_signal()
46+
.release_after(self.inner.call(req)),
47+
)
48+
}
49+
}

0 commit comments

Comments
 (0)