@@ -118,11 +118,24 @@ pub trait TlsAccept {
118118
119119pub type TlsAcceptCallbacks = Box < dyn TlsAccept + Send + Sync > ;
120120
121+ /// Some protocols, such as the proxy protocol, must be inspected before the TLS
122+ /// handshake. The below trait provides access to the raw TCP stream right
123+ /// before TLS for these situations.
124+ #[ async_trait]
125+ pub trait InspectPreTls : Send + Sync {
126+ /// The implementation can read bytes from the stream (e.g., PROXY protocol header)
127+ /// before the TLS handshake takes place.
128+ ///
129+ /// If this method returns an error, the connection will be dropped.
130+ async fn inspect ( & self , stream : & mut L4Stream ) -> Result < ( ) > ;
131+ }
132+
121133struct TransportStackBuilder {
122134 l4 : ServerAddress ,
123135 tls : Option < TlsSettings > ,
124136 #[ cfg( feature = "connection_filter" ) ]
125137 connection_filter : Option < Arc < dyn ConnectionFilter > > ,
138+ pre_tls_inspector : Option < Arc < dyn InspectPreTls > > ,
126139}
127140
128141impl TransportStackBuilder {
@@ -148,6 +161,7 @@ impl TransportStackBuilder {
148161 Ok ( TransportStack {
149162 l4,
150163 tls : self . tls . take ( ) . map ( |tls| Arc :: new ( tls. build ( ) ) ) ,
164+ pre_tls_inspector : self . pre_tls_inspector . clone ( ) ,
151165 } )
152166 }
153167}
@@ -156,6 +170,7 @@ impl TransportStackBuilder {
156170pub ( crate ) struct TransportStack {
157171 l4 : ListenerEndpoint ,
158172 tls : Option < Arc < Acceptor > > ,
173+ pre_tls_inspector : Option < Arc < dyn InspectPreTls > > ,
159174}
160175
161176impl TransportStack {
@@ -168,6 +183,7 @@ impl TransportStack {
168183 Ok ( UninitializedStream {
169184 l4 : stream,
170185 tls : self . tls . clone ( ) ,
186+ pre_tls_inspector : self . pre_tls_inspector . clone ( ) ,
171187 } )
172188 }
173189
@@ -179,17 +195,27 @@ impl TransportStack {
179195pub ( crate ) struct UninitializedStream {
180196 l4 : L4Stream ,
181197 tls : Option < Arc < Acceptor > > ,
198+ pre_tls_inspector : Option < Arc < dyn InspectPreTls > > ,
182199}
183200
184201impl UninitializedStream {
185202 pub async fn handshake ( mut self ) -> Result < Stream > {
186203 self . l4 . set_buffer ( ) ;
187- if let Some ( tls) = self . tls {
204+
205+ // Expose raw l4 stream to any registered pre-TLS inspectors before
206+ // handshaking.
207+ if let Some ( inspector) = self . pre_tls_inspector . as_ref ( ) {
208+ inspector. inspect ( & mut self . l4 ) . await ?;
209+ }
210+
211+ let res_with_stream: Result < Stream > = if let Some ( tls) = self . tls {
188212 let tls_stream = tls. tls_handshake ( self . l4 ) . await ?;
189213 Ok ( Box :: new ( tls_stream) )
190214 } else {
191215 Ok ( Box :: new ( self . l4 ) )
192- }
216+ } ;
217+
218+ res_with_stream
193219 }
194220
195221 /// Get the peer address of the connection if available
@@ -205,6 +231,7 @@ pub struct Listeners {
205231 stacks : Vec < TransportStackBuilder > ,
206232 #[ cfg( feature = "connection_filter" ) ]
207233 connection_filter : Option < Arc < dyn ConnectionFilter > > ,
234+ pre_tls_inspector : Option < Arc < dyn InspectPreTls > > ,
208235}
209236
210237impl Listeners {
@@ -214,6 +241,7 @@ impl Listeners {
214241 stacks : vec ! [ ] ,
215242 #[ cfg( feature = "connection_filter" ) ]
216243 connection_filter : None ,
244+ pre_tls_inspector : None ,
217245 }
218246 }
219247 /// Create a new [`Listeners`] with a TCP server endpoint from the given string.
@@ -294,13 +322,31 @@ impl Listeners {
294322 }
295323 }
296324
325+ /// Set a pre-TLS inspector for all endpoints in this listener collection.
326+ ///
327+ /// The inspector will be invoked after TCP accept but before the TLS handshake,
328+ /// allowing the application to read and process data such as PROXY protocol
329+ /// headers that arrive before TLS.
330+ pub fn set_pre_tls_inspector ( & mut self , inspector : Arc < dyn InspectPreTls > ) {
331+ log:: debug!( "Setting pre-TLS inspector on Listeners" ) ;
332+
333+ // Store the inspector for future endpoints
334+ self . pre_tls_inspector = Some ( inspector. clone ( ) ) ;
335+
336+ // Apply to existing stacks
337+ for stack in & mut self . stacks {
338+ stack. pre_tls_inspector = Some ( inspector. clone ( ) ) ;
339+ }
340+ }
341+
297342 /// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided
298343 pub fn add_endpoint ( & mut self , l4 : ServerAddress , tls : Option < TlsSettings > ) {
299344 self . stacks . push ( TransportStackBuilder {
300345 l4,
301346 tls,
302347 #[ cfg( feature = "connection_filter" ) ]
303348 connection_filter : self . connection_filter . clone ( ) ,
349+ pre_tls_inspector : self . pre_tls_inspector . clone ( ) ,
304350 } )
305351 }
306352
@@ -341,8 +387,8 @@ mod test {
341387
342388 #[ tokio:: test]
343389 async fn test_listen_tcp ( ) {
344- let addr1 = "127.0.0.1:7101 " ;
345- let addr2 = "127.0.0.1:7102 " ;
390+ let addr1 = "127.0.0.1:7107 " ;
391+ let addr2 = "127.0.0.1:7108 " ;
346392 let mut listeners = Listeners :: tcp ( addr1) ;
347393 listeners. add_tcp ( addr2) ;
348394
@@ -460,4 +506,87 @@ mod test {
460506 ) ;
461507 }
462508 }
509+
510+ #[ tokio:: test]
511+ #[ cfg( any( feature = "openssl" , feature = "boringssl" ) ) ]
512+ async fn test_inspect_pre_tls ( ) {
513+ use pingora_error:: { Error , Result } ;
514+ use std:: pin:: Pin ;
515+ use std:: sync:: { Arc , Mutex } ;
516+ use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
517+
518+ use crate :: protocols:: tls:: SslStream ;
519+ use crate :: tls:: ssl;
520+ struct HelloInspector {
521+ stored_bytes : Arc < Mutex < Vec < u8 > > > ,
522+ }
523+
524+ #[ async_trait]
525+ impl InspectPreTls for HelloInspector {
526+ async fn inspect ( & self , stream : & mut L4Stream ) -> Result < ( ) > {
527+ let mut buf = [ 0u8 ; 5 ] ;
528+ stream. read_exact ( & mut buf) . await . map_err ( |e| {
529+ Error :: new_str ( "failed to read pre-TLS bytes" ) . more_context ( format ! ( "{e}" ) )
530+ } ) ?;
531+ self . stored_bytes . lock ( ) . unwrap ( ) . extend_from_slice ( & buf) ;
532+ if & buf != b"hello" {
533+ return Err ( Error :: new_str ( "pre-TLS bytes did not match 'hello'" ) ) ;
534+ }
535+ Ok ( ( ) )
536+ }
537+ }
538+
539+ let stored = Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ;
540+ let inspector = Arc :: new ( HelloInspector {
541+ stored_bytes : stored. clone ( ) ,
542+ } ) ;
543+
544+ let addr = "127.0.0.1:7109" ;
545+ let cert_path = format ! ( "{}/tests/keys/server.crt" , env!( "CARGO_MANIFEST_DIR" ) ) ;
546+ let key_path = format ! ( "{}/tests/keys/key.pem" , env!( "CARGO_MANIFEST_DIR" ) ) ;
547+ let mut listeners = Listeners :: tls ( addr, & cert_path, & key_path) . unwrap ( ) ;
548+
549+ // Register HelloInspector on the listener so it fires before TLS handshaking.
550+ listeners. set_pre_tls_inspector ( inspector. clone ( ) ) ;
551+ let listener = listeners
552+ . build (
553+ #[ cfg( unix) ]
554+ None ,
555+ )
556+ . await
557+ . unwrap ( )
558+ . pop ( )
559+ . unwrap ( ) ;
560+
561+ let server_handle = tokio:: spawn ( async move {
562+ // Acceptor thread should handshake, which will perform pre-TLS inspection
563+ // and then the TLS handshake.
564+ let stream = listener. accept ( ) . await . unwrap ( ) ;
565+ stream. handshake ( ) . await . unwrap ( ) ;
566+ } ) ;
567+
568+ // make sure the above starts before the lines below
569+ sleep ( Duration :: from_millis ( 10 ) ) . await ;
570+
571+ let client_handle = tokio:: spawn ( async move {
572+ // Prepend the TLS handshake with the bytes "hello".
573+ let mut tcp_stream = tokio:: net:: TcpStream :: connect ( addr) . await . unwrap ( ) ;
574+ tcp_stream. write_all ( b"hello" ) . await . unwrap ( ) ;
575+
576+ // Perform the TLS handshake with verification disabled because the
577+ // certificates aren't actually valid.
578+ let ssl_context = ssl:: SslContext :: builder ( ssl:: SslMethod :: tls ( ) )
579+ . unwrap ( )
580+ . build ( ) ;
581+ let mut ssl_obj = ssl:: Ssl :: new ( & ssl_context) . unwrap ( ) ;
582+ ssl_obj. set_verify ( ssl:: SslVerifyMode :: NONE ) ;
583+ let mut tls_stream = SslStream :: new ( ssl_obj, tcp_stream) . unwrap ( ) ;
584+ Pin :: new ( & mut tls_stream) . connect ( ) . await . unwrap ( ) ;
585+ } ) ;
586+
587+ server_handle. await . unwrap ( ) ;
588+ client_handle. await . unwrap ( ) ;
589+
590+ assert_eq ! ( & * stored. lock( ) . unwrap( ) , b"hello" ) ;
591+ }
463592}
0 commit comments