Skip to content

Commit f3ec8b7

Browse files
authored
Decouple TLS detection from TCP connections (#818)
The TLS accept stack--the `DetectTls` type--is currently coupled to the `TcpStream` type; but the `TcpStream` type requires an actual OS-level TCP connection, which isn't ideal for testing. This change introduces a new trait `tls::accept::Detectable` and an implementation for `TcpStream`. This will permit us to use alternate implementations (e.g., for `io::DuplexStream`). This change also updates the `DetectTls` type name to `NewDetectTls`, and `AcceptTls` to `DetectTls`, to fit our more recent idioms.
1 parent 467e979 commit f3ec8b7

File tree

7 files changed

+123
-99
lines changed

7 files changed

+123
-99
lines changed

linkerd/app/inbound/src/lib.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,37 +384,35 @@ impl Config {
384384
.into_inner()
385385
}
386386

387-
pub fn build_tls_accept<D, DSvc, F, FSvc>(
387+
pub fn build_tls_accept<I, D, DSvc, F, FSvc>(
388388
self,
389389
detect: D,
390390
tcp_forward: F,
391391
identity: tls::Conditional<identity::Local>,
392392
metrics: metrics::Proxy,
393393
) -> impl svc::NewService<
394394
listen::Addrs,
395-
Service = impl svc::Service<TcpStream, Response = (), Error = Error, Future = impl Send>,
395+
Service = impl svc::Service<I, Response = (), Error = Error, Future = impl Send>,
396396
> + Clone
397397
where
398+
I: tls::accept::Detectable + Send + 'static,
398399
D: svc::NewService<TcpAccept, Service = DSvc> + Clone + Send + 'static,
399-
DSvc: svc::Service<SensorIo<tls::accept::Io>, Response = ()> + Send + 'static,
400+
DSvc: svc::Service<SensorIo<tls::accept::Io<I>>, Response = ()> + Send + 'static,
400401
DSvc::Error: Into<Error>,
401402
DSvc::Future: Send,
402403
F: svc::NewService<TcpEndpoint, Service = FSvc> + Clone + 'static,
403-
FSvc: svc::Service<SensorIo<TcpStream>, Response = ()> + 'static,
404+
FSvc: svc::Service<SensorIo<I>, Response = ()> + 'static,
404405
FSvc::Error: Into<Error>,
405406
FSvc::Future: Send,
406407
{
407-
let ProxyConfig {
408-
detect_protocol_timeout,
409-
..
410-
} = self.proxy;
411-
let require_identity = self.require_identity_for_inbound_ports;
412-
413408
svc::stack(detect)
414-
.push_request_filter(require_identity)
409+
.push_request_filter(self.require_identity_for_inbound_ports)
415410
.push(metrics.transport.layer_accept())
416411
.push_map_target(TcpAccept::from)
417-
.push(tls::DetectTls::layer(identity, detect_protocol_timeout))
412+
.push(tls::NewDetectTls::layer(
413+
identity,
414+
self.proxy.detect_protocol_timeout,
415+
))
418416
.push_switch(
419417
self.disable_protocol_detection_for_ports,
420418
svc::stack(tcp_forward)

linkerd/app/src/admin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl Config {
3333

3434
let (ready, latch) = admin::Readiness::new();
3535
let admin = admin::Admin::new(report, ready, shutdown, trace);
36-
let accept = tls::DetectTls::new(
36+
let accept = tls::NewDetectTls::new(
3737
identity,
3838
admin.into_accept(),
3939
std::time::Duration::from_secs(1),

linkerd/app/src/tap.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl Config {
5151

5252
let service =
5353
tap::AcceptPermittedClients::new(permitted_peer_identities.into(), server);
54-
let accept = tls::DetectTls::new(
54+
let accept = tls::NewDetectTls::new(
5555
identity,
5656
move |meta: tls::accept::Meta| {
5757
let service = service.clone();

linkerd/proxy/tap/src/accept.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ use linkerd2_proxy_transport::{
99
io,
1010
tls::{accept::Connection, Conditional, ReasonForNoPeerName},
1111
};
12-
use std::future::Future;
13-
use std::pin::Pin;
14-
use std::sync::Arc;
15-
use std::task::{Context, Poll};
12+
use std::{
13+
future::Future,
14+
pin::Pin,
15+
sync::Arc,
16+
task::{Context, Poll},
17+
};
18+
use tokio::net::TcpStream;
1619
use tower::Service;
1720

1821
#[derive(Clone, Debug)]
@@ -63,7 +66,7 @@ impl AcceptPermittedClients {
6366
}
6467
}
6568

66-
impl Service<Connection> for AcceptPermittedClients {
69+
impl Service<Connection<TcpStream>> for AcceptPermittedClients {
6770
type Response = ServeFuture;
6871
type Error = Error;
6972
type Future = future::Ready<Result<Self::Response, Self::Error>>;
@@ -72,7 +75,7 @@ impl Service<Connection> for AcceptPermittedClients {
7275
Poll::Ready(Ok(()))
7376
}
7477

75-
fn call(&mut self, (meta, io): Connection) -> Self::Future {
78+
fn call(&mut self, (meta, io): Connection<TcpStream>) -> Self::Future {
7679
future::ok(match meta.peer_identity {
7780
Conditional::Some(ref peer) => {
7881
if self.permitted_client_ids.contains(peer) {

linkerd/proxy/transport/src/tls/accept.rs

Lines changed: 97 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,28 @@ pub trait HasConfig {
2727
fn tls_server_config(&self) -> Arc<Config>;
2828
}
2929

30+
/// Must be implemented for I/O types like `TcpStream` on which TLS is
31+
/// transparently detected.
32+
///
33+
/// This is necessary so that we can be generic over the I/O type but still use
34+
/// `TcpStream::peek` to avoid allocating for mTLS SNI detection.
35+
#[async_trait::async_trait]
36+
pub trait Detectable {
37+
/// Attempts to detect a `ClientHello` message from the underlying transport
38+
/// and, if its SNI matches `local_name`, initiates a TLS server handshake to
39+
/// decrypt the stream.
40+
///
41+
/// Returns the client's identity, if one exists, and an optionally decrypted
42+
/// transport.
43+
async fn detected(
44+
self,
45+
config: Arc<Config>,
46+
local_name: identity::Name,
47+
) -> io::Result<(PeerIdentity, Io<Self>)>
48+
where
49+
Self: Sized;
50+
}
51+
3052
/// Produces a server config that fails to handshake all connections.
3153
pub fn empty_config() -> Arc<Config> {
3254
let verifier = rustls::NoClientAuth::new();
@@ -40,12 +62,12 @@ pub struct Meta {
4062
pub addrs: Addrs,
4163
}
4264

43-
pub type Io = EitherIo<PrefixedIo<TcpStream>, TlsStream<PrefixedIo<TcpStream>>>;
65+
pub type Io<T> = EitherIo<PrefixedIo<T>, TlsStream<PrefixedIo<T>>>;
4466

45-
pub type Connection = (Meta, Io);
67+
pub type Connection<T> = (Meta, Io<T>);
4668

4769
#[derive(Clone, Debug)]
48-
pub struct DetectTls<I, A> {
70+
pub struct NewDetectTls<I, A> {
4971
local_identity: Conditional<I>,
5072
inner: A,
5173
timeout: Duration,
@@ -55,7 +77,7 @@ pub struct DetectTls<I, A> {
5577
pub struct DetectTimeout(());
5678

5779
#[derive(Clone, Debug)]
58-
pub struct AcceptTls<I, N> {
80+
pub struct DetectTls<I, N> {
5981
addrs: Addrs,
6082
local_identity: Conditional<I>,
6183
inner: N,
@@ -70,7 +92,7 @@ const PEEK_CAPACITY: usize = 512;
7092
// insufficient. This is the same value used in HTTP detection.
7193
const BUFFER_CAPACITY: usize = 8192;
7294

73-
impl<I: HasConfig, N> DetectTls<I, N> {
95+
impl<I: HasConfig, N> NewDetectTls<I, N> {
7496
pub fn new(local_identity: Conditional<I>, inner: N, timeout: Duration) -> Self {
7597
Self {
7698
local_identity,
@@ -90,15 +112,15 @@ impl<I: HasConfig, N> DetectTls<I, N> {
90112
}
91113
}
92114

93-
impl<I, N> NewService<Addrs> for DetectTls<I, N>
115+
impl<I, N> NewService<Addrs> for NewDetectTls<I, N>
94116
where
95117
I: HasConfig + Clone,
96118
N: NewService<Meta> + Clone,
97119
{
98-
type Service = AcceptTls<I, N>;
120+
type Service = DetectTls<I, N>;
99121

100122
fn new_service(&mut self, addrs: Addrs) -> Self::Service {
101-
AcceptTls {
123+
DetectTls {
102124
addrs,
103125
local_identity: self.local_identity.clone(),
104126
inner: self.inner.clone(),
@@ -107,12 +129,14 @@ where
107129
}
108130
}
109131

110-
impl<I: HasConfig, N, A> tower::Service<TcpStream> for AcceptTls<I, N>
132+
impl<T, I, N, NSvc> tower::Service<T> for DetectTls<I, N>
111133
where
112-
N: NewService<Meta, Service = A> + Clone + Send + 'static,
113-
A: tower::Service<Io, Response = ()> + Send + 'static,
114-
A::Error: Into<Error>,
115-
A::Future: Send,
134+
T: Detectable + Send + 'static,
135+
I: HasConfig,
136+
N: NewService<Meta, Service = NSvc> + Clone + Send + 'static,
137+
NSvc: tower::Service<Io<T>, Response = ()> + Send + 'static,
138+
NSvc::Error: Into<Error>,
139+
NSvc::Future: Send,
116140
{
117141
type Response = ();
118142
type Error = Error;
@@ -122,7 +146,7 @@ where
122146
Poll::Ready(Ok(()))
123147
}
124148

125-
fn call(&mut self, tcp: TcpStream) -> Self::Future {
149+
fn call(&mut self, tcp: T) -> Self::Future {
126150
let addrs = self.addrs.clone();
127151
let mut new_accept = self.inner.clone();
128152

@@ -134,7 +158,7 @@ where
134158

135159
Box::pin(async move {
136160
let (peer_identity, io) = tokio::select! {
137-
res = detect(config, name, tcp) => { res? }
161+
res = tcp.detected(config, name) => { res? }
138162
() = timeout => {
139163
return Err(DetectTimeout(()).into());
140164
}
@@ -163,71 +187,74 @@ where
163187
}
164188
}
165189

166-
pub async fn detect(
167-
tls_config: Arc<Config>,
168-
local_id: identity::Name,
169-
mut tcp: TcpStream,
170-
) -> io::Result<(PeerIdentity, Io)> {
171-
const NO_TLS_META: PeerIdentity = Conditional::None(ReasonForNoPeerName::NoTlsFromRemote);
172-
173-
// First, try to use MSG_PEEK to read the SNI from the TLS ClientHello.
174-
// Because peeked data does not need to be retained, we use a static
175-
// buffer to prevent needless heap allocation.
176-
//
177-
// Anecdotally, the ClientHello sent by Linkerd proxies is <300B. So a
178-
// ~500B byte buffer is more than enough.
179-
let mut buf = [0u8; PEEK_CAPACITY];
180-
let sz = tcp.peek(&mut buf).await?;
181-
debug!(sz, "Peeked bytes from TCP stream");
182-
match conditional_accept::match_client_hello(&buf, &local_id) {
183-
conditional_accept::Match::Matched => {
184-
trace!("Identified matching SNI via peek");
185-
// Terminate the TLS stream.
186-
let (peer_id, tls) = handshake(tls_config, PrefixedIo::from(tcp)).await?;
187-
return Ok((peer_id, EitherIo::Right(tls)));
188-
}
189-
190-
conditional_accept::Match::NotMatched => {
191-
trace!("Not a matching TLS ClientHello");
192-
return Ok((NO_TLS_META, EitherIo::Left(tcp.into())));
193-
}
194-
195-
conditional_accept::Match::Incomplete => {}
196-
}
197-
198-
// Peeking didn't return enough data, so instead we'll allocate more
199-
// capacity and try reading data from the socket.
200-
debug!("Attempting to buffer TLS ClientHello after incomplete peek");
201-
let mut buf = BytesMut::with_capacity(BUFFER_CAPACITY);
202-
debug!(buf.capacity = %buf.capacity(), "Reading bytes from TCP stream");
203-
while tcp.read_buf(&mut buf).await? != 0 {
204-
debug!(buf.len = %buf.len(), "Read bytes from TCP stream");
205-
match conditional_accept::match_client_hello(buf.as_ref(), &local_id) {
190+
#[async_trait::async_trait]
191+
impl Detectable for TcpStream {
192+
async fn detected(
193+
mut self,
194+
tls_config: Arc<Config>,
195+
local_id: identity::Name,
196+
) -> io::Result<(PeerIdentity, Io<Self>)> {
197+
const NO_TLS_META: PeerIdentity = Conditional::None(ReasonForNoPeerName::NoTlsFromRemote);
198+
199+
// First, try to use MSG_PEEK to read the SNI from the TLS ClientHello.
200+
// Because peeked data does not need to be retained, we use a static
201+
// buffer to prevent needless heap allocation.
202+
//
203+
// Anecdotally, the ClientHello sent by Linkerd proxies is <300B. So a
204+
// ~500B byte buffer is more than enough.
205+
let mut buf = [0u8; PEEK_CAPACITY];
206+
let sz = self.peek(&mut buf).await?;
207+
debug!(sz, "Peeked bytes from TCP stream");
208+
match conditional_accept::match_client_hello(&buf, &local_id) {
206209
conditional_accept::Match::Matched => {
207-
trace!("Identified matching SNI via buffered read");
210+
trace!("Identified matching SNI via peek");
208211
// Terminate the TLS stream.
209-
let (peer_id, tls) =
210-
handshake(tls_config.clone(), PrefixedIo::new(buf.freeze(), tcp)).await?;
212+
let (peer_id, tls) = handshake(tls_config, PrefixedIo::from(self)).await?;
211213
return Ok((peer_id, EitherIo::Right(tls)));
212214
}
213215

214-
conditional_accept::Match::NotMatched => break,
216+
conditional_accept::Match::NotMatched => {
217+
trace!("Not a matching TLS ClientHello");
218+
return Ok((NO_TLS_META, EitherIo::Left(self.into())));
219+
}
215220

216-
conditional_accept::Match::Incomplete => {
217-
if buf.capacity() == 0 {
218-
// If we can't buffer an entire TLS ClientHello, it
219-
// almost definitely wasn't initiated by another proxy,
220-
// at least.
221-
warn!("Buffer insufficient for TLS ClientHello");
222-
break;
221+
conditional_accept::Match::Incomplete => {}
222+
}
223+
224+
// Peeking didn't return enough data, so instead we'll allocate more
225+
// capacity and try reading data from the socket.
226+
debug!("Attempting to buffer TLS ClientHello after incomplete peek");
227+
let mut buf = BytesMut::with_capacity(BUFFER_CAPACITY);
228+
debug!(buf.capacity = %buf.capacity(), "Reading bytes from TCP stream");
229+
while self.read_buf(&mut buf).await? != 0 {
230+
debug!(buf.len = %buf.len(), "Read bytes from TCP stream");
231+
match conditional_accept::match_client_hello(buf.as_ref(), &local_id) {
232+
conditional_accept::Match::Matched => {
233+
trace!("Identified matching SNI via buffered read");
234+
// Terminate the TLS stream.
235+
let (peer_id, tls) =
236+
handshake(tls_config.clone(), PrefixedIo::new(buf.freeze(), self)).await?;
237+
return Ok((peer_id, EitherIo::Right(tls)));
238+
}
239+
240+
conditional_accept::Match::NotMatched => break,
241+
242+
conditional_accept::Match::Incomplete => {
243+
if buf.capacity() == 0 {
244+
// If we can't buffer an entire TLS ClientHello, it
245+
// almost definitely wasn't initiated by another proxy,
246+
// at least.
247+
warn!("Buffer insufficient for TLS ClientHello");
248+
break;
249+
}
223250
}
224251
}
225252
}
226-
}
227253

228-
trace!("Could not read TLS ClientHello via buffering");
229-
let io = EitherIo::Left(PrefixedIo::new(buf.freeze(), tcp));
230-
Ok((NO_TLS_META, io))
254+
trace!("Could not read TLS ClientHello via buffering");
255+
let io = EitherIo::Left(PrefixedIo::new(buf.freeze(), self));
256+
Ok((NO_TLS_META, io))
257+
}
231258
}
232259

233260
async fn handshake<T>(

linkerd/proxy/transport/src/tls/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pub mod accept;
66
pub mod client;
77
mod conditional_accept;
88

9-
pub use self::accept::DetectTls;
9+
pub use self::accept::NewDetectTls;
1010
pub use self::client::Client;
1111

1212
/// Describes whether or not a connection was secured with TLS and, if it was

linkerd/proxy/transport/tests/tls_accept.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@
88
use futures::prelude::*;
99
use linkerd2_error::Never;
1010
use linkerd2_identity::{test_util, CrtKey, Name};
11-
use linkerd2_proxy_transport::tls::{
12-
self,
13-
accept::{self, DetectTls},
14-
Conditional,
15-
};
11+
use linkerd2_proxy_transport::tls::{self, Conditional};
1612
use linkerd2_proxy_transport::{
1713
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
1814
BindTcp, ConnectTcp,
@@ -115,7 +111,7 @@ where
115111
CF: Future<Output = Result<CR, io::Error>> + Send + 'static,
116112
CR: Send + 'static,
117113
// Server
118-
S: Fn(accept::Connection) -> SF + Clone + Send + 'static,
114+
S: Fn(tls::accept::Connection<TcpStream>) -> SF + Clone + Send + 'static,
119115
SF: Future<Output = Result<SR, io::Error>> + Send + 'static,
120116
SR: Send + 'static,
121117
{
@@ -148,9 +144,9 @@ where
148144

149145
let (listen_addr, listen) = BindTcp::new(addr, None).bind().expect("must bind");
150146

151-
let mut detect = DetectTls::new(
147+
let mut detect = tls::NewDetectTls::new(
152148
server_tls,
153-
move |meta: accept::Meta| {
149+
move |meta: tls::accept::Meta| {
154150
let server = server.clone();
155151
let sender = sender.clone();
156152
let peer_identity = Some(meta.peer_identity.clone());

0 commit comments

Comments
 (0)