@@ -16,8 +16,8 @@ use std::{
1616 str:: FromStr ,
1717 sync:: Arc ,
1818 task:: { Context , Poll } ,
19- time:: Duration ,
2019} ;
20+ use tokio:: time:: { self , Duration } ;
2121pub use tokio_rustls:: server:: TlsStream ;
2222use tower:: util:: ServiceExt ;
2323use tracing:: { debug, trace, warn} ;
@@ -68,7 +68,8 @@ pub type ConditionalServerTls = Conditional<ServerTls, NoServerTls>;
6868
6969pub type Meta < T > = ( ConditionalServerTls , T ) ;
7070
71- pub type Io < T > = EitherIo < PrefixedIo < T > , TlsStream < PrefixedIo < T > > > ;
71+ type DetectIo < T > = EitherIo < T , PrefixedIo < T > > ;
72+ pub type Io < T > = EitherIo < TlsStream < DetectIo < T > > , DetectIo < T > > ;
7273
7374pub type Connection < T , I > = ( Meta < T > , Io < I > ) ;
7475
@@ -159,17 +160,35 @@ where
159160
160161 match self . local_identity . as_ref ( ) {
161162 Some ( local) => {
162- let config = Param :: < Config > :: param ( local) ;
163- let local_id = Param :: < LocalId > :: param ( local) ;
164- let timeout = tokio:: time:: sleep ( self . timeout ) ;
163+ let config: Config = local. param ( ) ;
164+ let LocalId ( local_id) = local. param ( ) ;
165165
166+ // Detect the SNI from a ClientHello (or timeout).
167+ let detect = time:: timeout ( self . timeout , detect_sni ( io) ) ;
166168 Box :: pin ( async move {
167- let ( peer, io) = tokio:: select! {
168- res = detect( io, config, local_id) => { res? }
169- ( ) = timeout => {
170- return Err ( DetectTimeout ( ( ) ) . into( ) ) ;
169+ let ( sni, io) = detect. await . map_err ( |_| DetectTimeout ( ( ) ) ) ??;
170+
171+ let ( peer, io) = match sni {
172+ // If we detected an SNI matching this proxy, terminate TLS.
173+ Some ( ServerId ( id) ) if id == local_id => {
174+ trace ! ( "Identified local SNI" ) ;
175+ let ( peer, io) = handshake ( config, io) . await ?;
176+ ( Conditional :: Some ( peer) , EitherIo :: Left ( io) )
171177 }
178+ // If we detected another SNI, continue proxying the
179+ // opaque stream.
180+ Some ( sni) => {
181+ debug ! ( %sni, "Identified foreign SNI" ) ;
182+ let peer = ServerTls :: Passthru { sni } ;
183+ ( Conditional :: Some ( peer) , EitherIo :: Right ( io) )
184+ }
185+ // If no TLS was detected, continue proxying the stream.
186+ None => (
187+ Conditional :: None ( NoServerTls :: NoClientHello ) ,
188+ EitherIo :: Right ( io) ,
189+ ) ,
172190 } ;
191+
173192 new_accept
174193 . new_service ( ( peer, target) )
175194 . oneshot ( io)
@@ -181,22 +200,20 @@ where
181200 None => {
182201 let peer = Conditional :: None ( NoServerTls :: Disabled ) ;
183202 let svc = new_accept. new_service ( ( peer, target) ) ;
184- Box :: pin ( svc. oneshot ( EitherIo :: Left ( io. into ( ) ) ) . err_into :: < Error > ( ) )
203+ Box :: pin (
204+ svc. oneshot ( EitherIo :: Right ( EitherIo :: Left ( io) ) )
205+ . err_into :: < Error > ( ) ,
206+ )
185207 }
186208 }
187209 }
188210}
189211
190- async fn detect < I > (
191- mut io : I ,
192- tls_config : Config ,
193- LocalId ( local_id) : LocalId ,
194- ) -> io:: Result < ( ConditionalServerTls , Io < I > ) >
212+ /// Peek or buffer the provided stream to determine an SNI value.
213+ async fn detect_sni < I > ( mut io : I ) -> io:: Result < ( Option < ServerId > , DetectIo < I > ) >
195214where
196215 I : io:: Peek + io:: AsyncRead + io:: AsyncWrite + Send + Sync + Unpin ,
197216{
198- const NO_TLS_META : ConditionalServerTls = Conditional :: None ( NoServerTls :: NoClientHello ) ;
199-
200217 // First, try to use MSG_PEEK to read the SNI from the TLS ClientHello.
201218 // Because peeked data does not need to be retained, we use a static
202219 // buffer to prevent needless heap allocation.
@@ -206,26 +223,15 @@ where
206223 let mut buf = [ 0u8 ; PEEK_CAPACITY ] ;
207224 let sz = io. peek ( & mut buf) . await ?;
208225 debug ! ( sz, "Peeked bytes from TCP stream" ) ;
209- match client_hello:: parse_sni ( & buf) {
210- Ok ( Some ( ServerId ( sni) ) ) if sni == local_id => {
211- trace ! ( %sni, "Identified matching SNI via peek" ) ;
212- // Terminate the TLS stream.
213- let ( tls, io) = handshake ( tls_config, PrefixedIo :: from ( io) ) . await ?;
214- return Ok ( ( Conditional :: Some ( tls) , EitherIo :: Right ( io) ) ) ;
215- }
216-
217- Ok ( Some ( sni) ) => {
218- trace ! ( %sni, "Identified non-matching SNI via peek" ) ;
219- let tls = Conditional :: Some ( ServerTls :: Passthru { sni } ) ;
220- return Ok ( ( tls, EitherIo :: Left ( io. into ( ) ) ) ) ;
221- }
226+ // Peek may return 0 bytes if the socket is not peekable.
227+ if sz > 0 {
228+ match client_hello:: parse_sni ( & buf) {
229+ Ok ( sni) => {
230+ return Ok ( ( sni, EitherIo :: Left ( io) ) ) ;
231+ }
222232
223- Ok ( None ) => {
224- trace ! ( "Not a matching TLS ClientHello" ) ;
225- return Ok ( ( NO_TLS_META , EitherIo :: Left ( io. into ( ) ) ) ) ;
233+ Err ( client_hello:: Incomplete ) => { }
226234 }
227-
228- Err ( client_hello:: Incomplete ) => { }
229235 }
230236
231237 // Peeking didn't return enough data, so instead we'll allocate more
@@ -236,25 +242,8 @@ where
236242 while io. read_buf ( & mut buf) . await ? != 0 {
237243 debug ! ( buf. len = %buf. len( ) , "Read bytes from TCP stream" ) ;
238244 match client_hello:: parse_sni ( buf. as_ref ( ) ) {
239- Ok ( Some ( ServerId ( sni) ) ) if sni == local_id => {
240- trace ! ( %sni, "Identified matching SNI via buffered read" ) ;
241- // Terminate the TLS stream.
242- let ( tls, io) =
243- handshake ( tls_config. clone ( ) , PrefixedIo :: new ( buf. freeze ( ) , io) ) . await ?;
244- return Ok ( ( Conditional :: Some ( tls) , EitherIo :: Right ( io) ) ) ;
245- }
246-
247- Ok ( Some ( sni) ) => {
248- trace ! ( %sni, "Identified non-matching SNI via peek" ) ;
249- let tls = Conditional :: Some ( ServerTls :: Passthru { sni } ) ;
250- let io = PrefixedIo :: new ( buf. freeze ( ) , io) ;
251- return Ok ( ( tls, EitherIo :: Left ( io) ) ) ;
252- }
253-
254- Ok ( None ) => {
255- trace ! ( "Not a matching TLS ClientHello" ) ;
256- let io = PrefixedIo :: new ( buf. freeze ( ) , io) ;
257- return Ok ( ( NO_TLS_META , EitherIo :: Left ( io) ) ) ;
245+ Ok ( sni) => {
246+ return Ok ( ( sni, EitherIo :: Right ( PrefixedIo :: new ( buf. freeze ( ) , io) ) ) ) ;
258247 }
259248
260249 Err ( client_hello:: Incomplete ) => {
@@ -271,8 +260,8 @@ where
271260 }
272261
273262 trace ! ( "Could not read TLS ClientHello via buffering" ) ;
274- let io = EitherIo :: Left ( PrefixedIo :: new ( buf. freeze ( ) , io) ) ;
275- Ok ( ( NO_TLS_META , io) )
263+ let io = EitherIo :: Right ( PrefixedIo :: new ( buf. freeze ( ) , io) ) ;
264+ Ok ( ( None , io) )
276265}
277266
278267async fn handshake < T > ( tls_config : Config , io : T ) -> io:: Result < ( ServerTls , TlsStream < T > ) >
@@ -373,3 +362,40 @@ impl fmt::Display for NoServerTls {
373362 }
374363 }
375364}
365+
366+ #[ cfg( test) ]
367+ mod tests {
368+ use io:: AsyncWriteExt ;
369+
370+ use super :: * ;
371+ use std:: str:: FromStr ;
372+
373+ #[ tokio:: test]
374+ async fn detect_buffered ( ) {
375+ let _ = tracing_subscriber:: fmt:: try_init ( ) ;
376+
377+ let ( mut client_io, server_io) = tokio:: io:: duplex ( 1024 ) ;
378+ let input = include_bytes ! ( "testdata/curl-example-com-client-hello.bin" ) ;
379+ let len = input. len ( ) ;
380+ let client_task = tokio:: spawn ( async move {
381+ client_io
382+ . write_all ( & * input)
383+ . await
384+ . expect ( "Write must suceed" ) ;
385+ } ) ;
386+
387+ let ( sni, io) = detect_sni ( server_io)
388+ . await
389+ . expect ( "SNI detection must not fail" ) ;
390+
391+ let identity = id:: Name :: from_str ( "example.com" ) . unwrap ( ) ;
392+ assert_eq ! ( sni, Some ( ServerId ( identity) ) ) ;
393+
394+ match io {
395+ EitherIo :: Left ( _) => panic ! ( "Detected IO should be buffered" ) ,
396+ EitherIo :: Right ( io) => assert_eq ! ( io. prefix( ) . len( ) , len, "All data must be buffered" ) ,
397+ }
398+
399+ client_task. await . expect ( "Client must not fail" ) ;
400+ }
401+ }
0 commit comments