Skip to content

Commit d5c10bd

Browse files
authored
inbound: Split HTTP detection stack from TLS (#664)
The `DetectTls` module only operates on `TcpStream`s (because it uses the `TcpStream::peek` api); but this complicates writing tests on the HTTP stack to validate changes like #660. This change decouples these accept stacks so that they can be tested more easily.
1 parent d7784bb commit d5c10bd

File tree

2 files changed

+106
-46
lines changed

2 files changed

+106
-46
lines changed

linkerd/app/inbound/src/lib.rs

Lines changed: 95 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,19 @@ use linkerd2_app_core::{
2222
reconnect, router,
2323
spans::SpanConverter,
2424
svc::{self, NewService},
25-
transport::{self, io::BoxedIo, listen, tls},
25+
transport::{self, io, listen, tls},
2626
Error, ProxyMetrics, TraceContextLayer, DST_OVERRIDE_HEADER,
2727
};
2828
use std::collections::HashMap;
29-
use tokio::sync::mpsc;
29+
use tokio::{net::TcpStream, sync::mpsc};
3030
use tracing::debug_span;
3131

3232
pub mod endpoint;
3333
mod prevent_loop;
3434
mod require_identity_for_ports;
3535

36+
type SensorIo<T> = io::SensorIo<T, transport::metrics::Sensor>;
37+
3638
#[derive(Clone, Debug)]
3739
pub struct Config {
3840
pub proxy: ProxyConfig,
@@ -64,7 +66,7 @@ impl Config {
6466
> + Send
6567
+ 'static
6668
where
67-
L: tower::Service<Target, Response = S> + Unpin + Send + Clone + 'static,
69+
L: tower::Service<Target, Response = S> + Unpin + Clone + Send + Sync + 'static,
6870
L::Error: Into<Error>,
6971
L::Future: Unpin + Send,
7072
S: tower::Service<
@@ -90,14 +92,22 @@ impl Config {
9092
metrics.clone(),
9193
span_sink.clone(),
9294
);
93-
self.build_server(
94-
tcp_connect,
95+
96+
// Forwards TCP streams that cannot be decoded as HTTP.
97+
let tcp_forward = svc::stack(tcp_connect)
98+
.push_make_thunk()
99+
.push_on_response(svc::layer::mk(tcp::Forward::new))
100+
.into_inner();
101+
102+
let accept = self.build_accept(
103+
tcp_forward.clone(),
95104
http_router,
96-
local_identity,
97-
metrics,
105+
metrics.clone(),
98106
span_sink,
99107
drain,
100-
)
108+
);
109+
110+
self.build_tls_accept(accept, tcp_forward, local_identity, metrics)
101111
}
102112

103113
pub fn build_tcp_connect(
@@ -120,7 +130,7 @@ impl Config {
120130
// Establishes connections to remote peers (for both TCP
121131
// forwarding and HTTP proxying).
122132
svc::connect(self.proxy.connect.keepalive)
123-
.push_map_response(BoxedIo::new) // Ensures the transport propagates shutdown properly.
133+
.push_map_response(io::BoxedIo::new) // Ensures the transport propagates shutdown properly.
124134
// Limits the time we wait for a connection to be established.
125135
.push_timeout(self.proxy.connect.timeout)
126136
.push(metrics.transport.layer_connect(TransportLabels))
@@ -284,33 +294,36 @@ impl Config {
284294
.into_inner()
285295
}
286296

287-
pub fn build_server<C, H, S>(
288-
self,
289-
tcp_connect: C,
297+
pub fn build_accept<I, F, A, H, S>(
298+
&self,
299+
tcp_forward: F,
290300
http_router: H,
291-
local_identity: tls::Conditional<identity::Local>,
292301
metrics: ProxyMetrics,
293302
span_sink: Option<mpsc::Sender<oc::Span>>,
294303
drain: drain::Watch,
295304
) -> impl tower::Service<
296-
listen::Addrs,
305+
tls::accept::Meta,
297306
Error = impl Into<Error>,
298307
Future = impl Send + 'static,
299308
Response = impl tower::Service<
300-
tokio::net::TcpStream,
309+
I,
301310
Response = (),
302311
Error = impl Into<Error>,
303312
Future = impl Send + 'static,
304313
> + Send
305314
+ 'static,
306-
> + Send
315+
> + Clone
316+
+ Send
307317
+ 'static
308318
where
309-
C: tower::Service<TcpEndpoint> + Unpin + Clone + Send + Sync + 'static,
310-
C::Error: Into<Error>,
311-
C::Response: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
312-
C::Future: Unpin + Send,
313-
H: tower::Service<Target, Response = S, Error = Error> + Unpin + Send + Clone + 'static,
319+
I: io::AsyncRead + io::AsyncWrite + Unpin + Send + 'static,
320+
F: tower::Service<TcpEndpoint, Response = A> + Unpin + Clone + Send + 'static,
321+
F::Error: Into<Error>,
322+
F::Future: Send,
323+
A: tower::Service<io::PrefixedIo<I>, Response = ()> + Clone + Send + 'static,
324+
A::Error: Into<Error>,
325+
A::Future: Send,
326+
H: tower::Service<Target, Response = S, Error = Error> + Unpin + Clone + Send + 'static,
314327
H::Future: Send,
315328
S: tower::Service<
316329
http::Request<http::boxed::Payload>,
@@ -322,13 +335,11 @@ impl Config {
322335
{
323336
let ProxyConfig {
324337
server: ServerConfig { h2_settings, .. },
325-
disable_protocol_detection_for_ports: skip_detect,
326338
dispatch_timeout,
327339
max_in_flight_requests,
328340
detect_protocol_timeout,
329341
..
330-
} = self.proxy;
331-
let require_identity = self.require_identity_for_inbound_ports;
342+
} = self.proxy.clone();
332343

333344
// Handles requests as they are initially received by the proxy.
334345
let http_admit_request = svc::layers()
@@ -378,34 +389,72 @@ impl Config {
378389
.into_inner()
379390
.into_make_service();
380391

381-
// The stack is served lazily since some layers (notably buffer) spawn
382-
// tasks from their constructor. This helps to ensure that tasks are
383-
// spawned on the same runtime as the proxy.
384-
// Forwards TCP streams that cannot be decoded as HTTP.
385-
let tcp_forward = svc::stack(tcp_connect)
386-
.push_make_thunk()
387-
.push_on_response(svc::layer::mk(tcp::Forward::new));
388-
389-
let http = DetectHttp::new(
392+
DetectHttp::new(
390393
h2_settings,
391394
detect_protocol_timeout,
392395
http_server,
393-
tcp_forward.clone().push_map_target(TcpEndpoint::from),
396+
svc::stack(tcp_forward)
397+
.push_map_target(TcpEndpoint::from)
398+
.into_inner(),
394399
drain.clone(),
395-
);
400+
)
401+
}
396402

397-
let tls = svc::stack(http)
398-
.push(admit::AdmitLayer::new(require_identity))
399-
.push(metrics.transport.layer_accept(TransportLabels))
400-
.push(svc::layer::mk(|inner| {
401-
tls::DetectTls::new(local_identity.clone(), inner, detect_protocol_timeout)
402-
}));
403+
pub fn build_tls_accept<D, A, F, B>(
404+
self,
405+
detect: D,
406+
tcp_forward: F,
407+
identity: tls::Conditional<identity::Local>,
408+
metrics: ProxyMetrics,
409+
) -> impl tower::Service<
410+
listen::Addrs,
411+
Error = impl Into<Error>,
412+
Future = impl Send + 'static,
413+
Response = impl tower::Service<
414+
TcpStream,
415+
Response = (),
416+
Error = impl Into<Error>,
417+
Future = impl Send + 'static,
418+
> + Send
419+
+ 'static,
420+
> + Send
421+
+ 'static
422+
where
423+
D: tower::Service<tls::accept::Meta, Response = A> + Unpin + Clone + Send + Sync + 'static,
424+
D::Error: Into<Error>,
425+
D::Future: Unpin + Send,
426+
A: tower::Service<SensorIo<io::BoxedIo>, Response = ()> + Unpin + Send + 'static,
427+
A::Error: Into<Error>,
428+
A::Future: Send,
429+
F: tower::Service<TcpEndpoint, Response = B> + Unpin + Clone + Send + Sync + 'static,
430+
F::Error: Into<Error>,
431+
F::Future: Unpin + Send,
432+
B: tower::Service<SensorIo<TcpStream>, Response = ()> + Unpin + Send + 'static,
433+
B::Error: Into<Error>,
434+
B::Future: Send,
435+
{
436+
let ProxyConfig {
437+
disable_protocol_detection_for_ports: skip_detect,
438+
detect_protocol_timeout,
439+
..
440+
} = self.proxy;
441+
let require_identity = self.require_identity_for_inbound_ports;
403442

404-
let accept_fwd = tcp_forward
405-
.push_map_target(TcpEndpoint::from)
406-
.push(metrics.transport.layer_accept(TransportLabels))
407-
.into_inner();
408-
svc::stack::MakeSwitch::new(skip_detect, tls, accept_fwd)
443+
svc::stack::MakeSwitch::new(
444+
skip_detect,
445+
svc::stack(detect)
446+
.push(admit::AdmitLayer::new(require_identity))
447+
.push(metrics.transport.layer_accept(TransportLabels))
448+
.push(tls::DetectTls::layer(
449+
identity.clone(),
450+
detect_protocol_timeout,
451+
))
452+
.into_inner(),
453+
svc::stack(tcp_forward)
454+
.push_map_target(TcpEndpoint::from)
455+
.push(metrics.transport.layer_accept(TransportLabels))
456+
.into_inner(),
457+
)
409458
}
410459
}
411460

linkerd/proxy/transport/src/tls/accept.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use futures::prelude::*;
66
use linkerd2_dns_name as dns;
77
use linkerd2_error::{Error, Never};
88
use linkerd2_identity as identity;
9+
use linkerd2_stack::layer;
910
pub use rustls::ServerConfig as Config;
1011
use std::{
1112
pin::Pin,
@@ -74,6 +75,16 @@ impl<I: HasConfig, M> DetectTls<I, M> {
7475
timeout,
7576
}
7677
}
78+
79+
pub fn layer(
80+
local_identity: Conditional<I>,
81+
timeout: Duration,
82+
) -> impl layer::Layer<M, Service = Self>
83+
where
84+
I: Clone,
85+
{
86+
layer::mk(move |inner| Self::new(local_identity.clone(), inner, timeout))
87+
}
7788
}
7889

7990
impl<I: HasConfig, M> tower::Service<Addrs> for DetectTls<I, M>

0 commit comments

Comments
 (0)