Skip to content

Commit 30a21e4

Browse files
authored
SERVER-120105 Expose pre-TLS trait on listeners (#2)
1 parent 286a015 commit 30a21e4

File tree

2 files changed

+137
-5
lines changed

2 files changed

+137
-5
lines changed

pingora-core/src/listeners/mod.rs

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,24 @@ pub trait TlsAccept {
118118

119119
pub type TlsAcceptCallbacks = Box<dyn TlsAccept + Send + Sync>;
120120

121+
/// Some protocols, such as the proxy protocol, must be inspected before the TLS
122+
/// handshake. The below trait provides access to the raw TCP stream right
123+
/// before TLS for these situations.
124+
#[async_trait]
125+
pub trait InspectPreTls: Send + Sync {
126+
/// The implementation can read bytes from the stream (e.g., PROXY protocol header)
127+
/// before the TLS handshake takes place.
128+
///
129+
/// If this method returns an error, the connection will be dropped.
130+
async fn inspect(&self, stream: &mut L4Stream) -> Result<()>;
131+
}
132+
121133
struct TransportStackBuilder {
122134
l4: ServerAddress,
123135
tls: Option<TlsSettings>,
124136
#[cfg(feature = "connection_filter")]
125137
connection_filter: Option<Arc<dyn ConnectionFilter>>,
138+
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
126139
}
127140

128141
impl TransportStackBuilder {
@@ -148,6 +161,7 @@ impl TransportStackBuilder {
148161
Ok(TransportStack {
149162
l4,
150163
tls: self.tls.take().map(|tls| Arc::new(tls.build())),
164+
pre_tls_inspector: self.pre_tls_inspector.clone(),
151165
})
152166
}
153167
}
@@ -156,6 +170,7 @@ impl TransportStackBuilder {
156170
pub(crate) struct TransportStack {
157171
l4: ListenerEndpoint,
158172
tls: Option<Arc<Acceptor>>,
173+
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
159174
}
160175

161176
impl TransportStack {
@@ -168,6 +183,7 @@ impl TransportStack {
168183
Ok(UninitializedStream {
169184
l4: stream,
170185
tls: self.tls.clone(),
186+
pre_tls_inspector: self.pre_tls_inspector.clone(),
171187
})
172188
}
173189

@@ -179,17 +195,27 @@ impl TransportStack {
179195
pub(crate) struct UninitializedStream {
180196
l4: L4Stream,
181197
tls: Option<Arc<Acceptor>>,
198+
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
182199
}
183200

184201
impl UninitializedStream {
185202
pub async fn handshake(mut self) -> Result<Stream> {
186203
self.l4.set_buffer();
187-
if let Some(tls) = self.tls {
204+
205+
// Expose raw l4 stream to any registered pre-TLS inspectors before
206+
// handshaking.
207+
if let Some(inspector) = self.pre_tls_inspector.as_ref() {
208+
inspector.inspect(&mut self.l4).await?;
209+
}
210+
211+
let res_with_stream: Result<Stream> = if let Some(tls) = self.tls {
188212
let tls_stream = tls.tls_handshake(self.l4).await?;
189213
Ok(Box::new(tls_stream))
190214
} else {
191215
Ok(Box::new(self.l4))
192-
}
216+
};
217+
218+
res_with_stream
193219
}
194220

195221
/// Get the peer address of the connection if available
@@ -205,6 +231,7 @@ pub struct Listeners {
205231
stacks: Vec<TransportStackBuilder>,
206232
#[cfg(feature = "connection_filter")]
207233
connection_filter: Option<Arc<dyn ConnectionFilter>>,
234+
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
208235
}
209236

210237
impl Listeners {
@@ -214,6 +241,7 @@ impl Listeners {
214241
stacks: vec![],
215242
#[cfg(feature = "connection_filter")]
216243
connection_filter: None,
244+
pre_tls_inspector: None,
217245
}
218246
}
219247
/// Create a new [`Listeners`] with a TCP server endpoint from the given string.
@@ -294,13 +322,31 @@ impl Listeners {
294322
}
295323
}
296324

325+
/// Set a pre-TLS inspector for all endpoints in this listener collection.
326+
///
327+
/// The inspector will be invoked after TCP accept but before the TLS handshake,
328+
/// allowing the application to read and process data such as PROXY protocol
329+
/// headers that arrive before TLS.
330+
pub fn set_pre_tls_inspector(&mut self, inspector: Arc<dyn InspectPreTls>) {
331+
log::debug!("Setting pre-TLS inspector on Listeners");
332+
333+
// Store the inspector for future endpoints
334+
self.pre_tls_inspector = Some(inspector.clone());
335+
336+
// Apply to existing stacks
337+
for stack in &mut self.stacks {
338+
stack.pre_tls_inspector = Some(inspector.clone());
339+
}
340+
}
341+
297342
/// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided
298343
pub fn add_endpoint(&mut self, l4: ServerAddress, tls: Option<TlsSettings>) {
299344
self.stacks.push(TransportStackBuilder {
300345
l4,
301346
tls,
302347
#[cfg(feature = "connection_filter")]
303348
connection_filter: self.connection_filter.clone(),
349+
pre_tls_inspector: self.pre_tls_inspector.clone(),
304350
})
305351
}
306352

@@ -341,8 +387,8 @@ mod test {
341387

342388
#[tokio::test]
343389
async fn test_listen_tcp() {
344-
let addr1 = "127.0.0.1:7101";
345-
let addr2 = "127.0.0.1:7102";
390+
let addr1 = "127.0.0.1:7107";
391+
let addr2 = "127.0.0.1:7108";
346392
let mut listeners = Listeners::tcp(addr1);
347393
listeners.add_tcp(addr2);
348394

@@ -460,4 +506,87 @@ mod test {
460506
);
461507
}
462508
}
509+
510+
#[tokio::test]
511+
#[cfg(any(feature = "openssl", feature = "boringssl"))]
512+
async fn test_inspect_pre_tls() {
513+
use pingora_error::{Error, Result};
514+
use std::pin::Pin;
515+
use std::sync::{Arc, Mutex};
516+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
517+
518+
use crate::protocols::tls::SslStream;
519+
use crate::tls::ssl;
520+
struct HelloInspector {
521+
stored_bytes: Arc<Mutex<Vec<u8>>>,
522+
}
523+
524+
#[async_trait]
525+
impl InspectPreTls for HelloInspector {
526+
async fn inspect(&self, stream: &mut L4Stream) -> Result<()> {
527+
let mut buf = [0u8; 5];
528+
stream.read_exact(&mut buf).await.map_err(|e| {
529+
Error::new_str("failed to read pre-TLS bytes").more_context(format!("{e}"))
530+
})?;
531+
self.stored_bytes.lock().unwrap().extend_from_slice(&buf);
532+
if &buf != b"hello" {
533+
return Err(Error::new_str("pre-TLS bytes did not match 'hello'"));
534+
}
535+
Ok(())
536+
}
537+
}
538+
539+
let stored = Arc::new(Mutex::new(Vec::new()));
540+
let inspector = Arc::new(HelloInspector {
541+
stored_bytes: stored.clone(),
542+
});
543+
544+
let addr = "127.0.0.1:7109";
545+
let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR"));
546+
let key_path = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR"));
547+
let mut listeners = Listeners::tls(addr, &cert_path, &key_path).unwrap();
548+
549+
// Register HelloInspector on the listener so it fires before TLS handshaking.
550+
listeners.set_pre_tls_inspector(inspector.clone());
551+
let listener = listeners
552+
.build(
553+
#[cfg(unix)]
554+
None,
555+
)
556+
.await
557+
.unwrap()
558+
.pop()
559+
.unwrap();
560+
561+
let server_handle = tokio::spawn(async move {
562+
// Acceptor thread should handshake, which will perform pre-TLS inspection
563+
// and then the TLS handshake.
564+
let stream = listener.accept().await.unwrap();
565+
stream.handshake().await.unwrap();
566+
});
567+
568+
// make sure the above starts before the lines below
569+
sleep(Duration::from_millis(10)).await;
570+
571+
let client_handle = tokio::spawn(async move {
572+
// Prepend the TLS handshake with the bytes "hello".
573+
let mut tcp_stream = tokio::net::TcpStream::connect(addr).await.unwrap();
574+
tcp_stream.write_all(b"hello").await.unwrap();
575+
576+
// Perform the TLS handshake with verification disabled because the
577+
// certificates aren't actually valid.
578+
let ssl_context = ssl::SslContext::builder(ssl::SslMethod::tls())
579+
.unwrap()
580+
.build();
581+
let mut ssl_obj = ssl::Ssl::new(&ssl_context).unwrap();
582+
ssl_obj.set_verify(ssl::SslVerifyMode::NONE);
583+
let mut tls_stream = SslStream::new(ssl_obj, tcp_stream).unwrap();
584+
Pin::new(&mut tls_stream).connect().await.unwrap();
585+
});
586+
587+
server_handle.await.unwrap();
588+
client_handle.await.unwrap();
589+
590+
assert_eq!(&*stored.lock().unwrap(), b"hello");
591+
}
463592
}

pingora-core/src/protocols/l4/stream.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,10 @@ impl Stream {
446446
}
447447

448448
/// Put Some data back to the head of the stream to be read again
449-
pub(crate) fn rewind(&mut self, data: &[u8]) {
449+
/// This can be used in cases where we "peek" at data only to find
450+
/// it doesn't match what's expected, and so it needs to be put back
451+
/// for a different protocol to potentially use it.
452+
pub fn rewind(&mut self, data: &[u8]) {
450453
if !data.is_empty() {
451454
self.rewind_read_buf.push(data.to_vec());
452455
}

0 commit comments

Comments
 (0)