Skip to content

Commit bb00dc6

Browse files
authored
tls: Test SNI detection (#959)
7d2bdbb fixed an issue with TLS detection. This change adds a test for this behavior. In order to implement a test, this change modifies the `detect` helper to not be responsible for SNI matching & TLS termination. The `detect` helper is replaced with a `detect_sni` function that only reads the ClientHello without terminating TLS. A test has been added for the `detect_sni` helper.
1 parent 7d2bdbb commit bb00dc6

File tree

3 files changed

+89
-57
lines changed

3 files changed

+89
-57
lines changed

linkerd/io/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ impl Peek for tokio::net::TcpStream {
4040
}
4141
}
4242

43+
#[async_trait::async_trait]
44+
impl Peek for tokio::io::DuplexStream {
45+
async fn peek(&self, _: &mut [u8]) -> Result<usize> {
46+
Ok(0)
47+
}
48+
}
49+
4350
// === PeerAddr ===
4451

4552
pub trait PeerAddr {
@@ -71,7 +78,6 @@ impl PeerAddr for tokio_test::io::Mock {
7178
}
7279
}
7380

74-
#[cfg(feature = "tokio-test")]
7581
impl PeerAddr for tokio::io::DuplexStream {
7682
fn peer_addr(&self) -> Result<SocketAddr> {
7783
Ok(([0, 0, 0, 0], 0).into())

linkerd/tls/src/server/mod.rs

Lines changed: 82 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ use std::{
1616
str::FromStr,
1717
sync::Arc,
1818
task::{Context, Poll},
19-
time::Duration,
2019
};
20+
use tokio::time::{self, Duration};
2121
pub use tokio_rustls::server::TlsStream;
2222
use tower::util::ServiceExt;
2323
use tracing::{debug, trace, warn};
@@ -68,7 +68,8 @@ pub type ConditionalServerTls = Conditional<ServerTls, NoServerTls>;
6868

6969
pub type Meta<T> = (ConditionalServerTls, T);
7070

71-
pub type Io<T> = EitherIo<PrefixedIo<T>, TlsStream<PrefixedIo<T>>>;
71+
type DetectIo<T> = EitherIo<T, PrefixedIo<T>>;
72+
pub type Io<T> = EitherIo<TlsStream<DetectIo<T>>, DetectIo<T>>;
7273

7374
pub type Connection<T, I> = (Meta<T>, Io<I>);
7475

@@ -159,17 +160,35 @@ where
159160

160161
match self.local_identity.as_ref() {
161162
Some(local) => {
162-
let config = Param::<Config>::param(local);
163-
let local_id = Param::<LocalId>::param(local);
164-
let timeout = tokio::time::sleep(self.timeout);
163+
let config: Config = local.param();
164+
let LocalId(local_id) = local.param();
165165

166+
// Detect the SNI from a ClientHello (or timeout).
167+
let detect = time::timeout(self.timeout, detect_sni(io));
166168
Box::pin(async move {
167-
let (peer, io) = tokio::select! {
168-
res = detect(io, config, local_id) => { res? }
169-
() = timeout => {
170-
return Err(DetectTimeout(()).into());
169+
let (sni, io) = detect.await.map_err(|_| DetectTimeout(()))??;
170+
171+
let (peer, io) = match sni {
172+
// If we detected an SNI matching this proxy, terminate TLS.
173+
Some(ServerId(id)) if id == local_id => {
174+
trace!("Identified local SNI");
175+
let (peer, io) = handshake(config, io).await?;
176+
(Conditional::Some(peer), EitherIo::Left(io))
171177
}
178+
// If we detected another SNI, continue proxying the
179+
// opaque stream.
180+
Some(sni) => {
181+
debug!(%sni, "Identified foreign SNI");
182+
let peer = ServerTls::Passthru { sni };
183+
(Conditional::Some(peer), EitherIo::Right(io))
184+
}
185+
// If no TLS was detected, continue proxying the stream.
186+
None => (
187+
Conditional::None(NoServerTls::NoClientHello),
188+
EitherIo::Right(io),
189+
),
172190
};
191+
173192
new_accept
174193
.new_service((peer, target))
175194
.oneshot(io)
@@ -181,22 +200,20 @@ where
181200
None => {
182201
let peer = Conditional::None(NoServerTls::Disabled);
183202
let svc = new_accept.new_service((peer, target));
184-
Box::pin(svc.oneshot(EitherIo::Left(io.into())).err_into::<Error>())
203+
Box::pin(
204+
svc.oneshot(EitherIo::Right(EitherIo::Left(io)))
205+
.err_into::<Error>(),
206+
)
185207
}
186208
}
187209
}
188210
}
189211

190-
async fn detect<I>(
191-
mut io: I,
192-
tls_config: Config,
193-
LocalId(local_id): LocalId,
194-
) -> io::Result<(ConditionalServerTls, Io<I>)>
212+
/// Peek or buffer the provided stream to determine an SNI value.
213+
async fn detect_sni<I>(mut io: I) -> io::Result<(Option<ServerId>, DetectIo<I>)>
195214
where
196215
I: io::Peek + io::AsyncRead + io::AsyncWrite + Send + Sync + Unpin,
197216
{
198-
const NO_TLS_META: ConditionalServerTls = Conditional::None(NoServerTls::NoClientHello);
199-
200217
// First, try to use MSG_PEEK to read the SNI from the TLS ClientHello.
201218
// Because peeked data does not need to be retained, we use a static
202219
// buffer to prevent needless heap allocation.
@@ -206,26 +223,15 @@ where
206223
let mut buf = [0u8; PEEK_CAPACITY];
207224
let sz = io.peek(&mut buf).await?;
208225
debug!(sz, "Peeked bytes from TCP stream");
209-
match client_hello::parse_sni(&buf) {
210-
Ok(Some(ServerId(sni))) if sni == local_id => {
211-
trace!(%sni, "Identified matching SNI via peek");
212-
// Terminate the TLS stream.
213-
let (tls, io) = handshake(tls_config, PrefixedIo::from(io)).await?;
214-
return Ok((Conditional::Some(tls), EitherIo::Right(io)));
215-
}
216-
217-
Ok(Some(sni)) => {
218-
trace!(%sni, "Identified non-matching SNI via peek");
219-
let tls = Conditional::Some(ServerTls::Passthru { sni });
220-
return Ok((tls, EitherIo::Left(io.into())));
221-
}
226+
// Peek may return 0 bytes if the socket is not peekable.
227+
if sz > 0 {
228+
match client_hello::parse_sni(&buf) {
229+
Ok(sni) => {
230+
return Ok((sni, EitherIo::Left(io)));
231+
}
222232

223-
Ok(None) => {
224-
trace!("Not a matching TLS ClientHello");
225-
return Ok((NO_TLS_META, EitherIo::Left(io.into())));
233+
Err(client_hello::Incomplete) => {}
226234
}
227-
228-
Err(client_hello::Incomplete) => {}
229235
}
230236

231237
// Peeking didn't return enough data, so instead we'll allocate more
@@ -236,25 +242,8 @@ where
236242
while io.read_buf(&mut buf).await? != 0 {
237243
debug!(buf.len = %buf.len(), "Read bytes from TCP stream");
238244
match client_hello::parse_sni(buf.as_ref()) {
239-
Ok(Some(ServerId(sni))) if sni == local_id => {
240-
trace!(%sni, "Identified matching SNI via buffered read");
241-
// Terminate the TLS stream.
242-
let (tls, io) =
243-
handshake(tls_config.clone(), PrefixedIo::new(buf.freeze(), io)).await?;
244-
return Ok((Conditional::Some(tls), EitherIo::Right(io)));
245-
}
246-
247-
Ok(Some(sni)) => {
248-
trace!(%sni, "Identified non-matching SNI via peek");
249-
let tls = Conditional::Some(ServerTls::Passthru { sni });
250-
let io = PrefixedIo::new(buf.freeze(), io);
251-
return Ok((tls, EitherIo::Left(io)));
252-
}
253-
254-
Ok(None) => {
255-
trace!("Not a matching TLS ClientHello");
256-
let io = PrefixedIo::new(buf.freeze(), io);
257-
return Ok((NO_TLS_META, EitherIo::Left(io)));
245+
Ok(sni) => {
246+
return Ok((sni, EitherIo::Right(PrefixedIo::new(buf.freeze(), io))));
258247
}
259248

260249
Err(client_hello::Incomplete) => {
@@ -271,8 +260,8 @@ where
271260
}
272261

273262
trace!("Could not read TLS ClientHello via buffering");
274-
let io = EitherIo::Left(PrefixedIo::new(buf.freeze(), io));
275-
Ok((NO_TLS_META, io))
263+
let io = EitherIo::Right(PrefixedIo::new(buf.freeze(), io));
264+
Ok((None, io))
276265
}
277266

278267
async fn handshake<T>(tls_config: Config, io: T) -> io::Result<(ServerTls, TlsStream<T>)>
@@ -373,3 +362,40 @@ impl fmt::Display for NoServerTls {
373362
}
374363
}
375364
}
365+
366+
#[cfg(test)]
367+
mod tests {
368+
use io::AsyncWriteExt;
369+
370+
use super::*;
371+
use std::str::FromStr;
372+
373+
#[tokio::test]
374+
async fn detect_buffered() {
375+
let _ = tracing_subscriber::fmt::try_init();
376+
377+
let (mut client_io, server_io) = tokio::io::duplex(1024);
378+
let input = include_bytes!("testdata/curl-example-com-client-hello.bin");
379+
let len = input.len();
380+
let client_task = tokio::spawn(async move {
381+
client_io
382+
.write_all(&*input)
383+
.await
384+
.expect("Write must suceed");
385+
});
386+
387+
let (sni, io) = detect_sni(server_io)
388+
.await
389+
.expect("SNI detection must not fail");
390+
391+
let identity = id::Name::from_str("example.com").unwrap();
392+
assert_eq!(sni, Some(ServerId(identity)));
393+
394+
match io {
395+
EitherIo::Left(_) => panic!("Detected IO should be buffered"),
396+
EitherIo::Right(io) => assert_eq!(io.prefix().len(), len, "All data must be buffered"),
397+
}
398+
399+
client_task.await.expect("Client must not fail");
400+
}
401+
}
517 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)