1- use std:: io;
1+ use std:: io:: { self , Chain , Cursor } ;
22use std:: io:: { Read , Write } ;
33use std:: sync:: Arc ;
44
@@ -17,7 +17,7 @@ pub(crate) struct SwitchableConn<T: Read + Write>(pub(crate) Option<EitherConn<T
1717
1818pub ( crate ) enum EitherConn < T : Read + Write > {
1919 Plain ( T ) ,
20- Tls ( rustls:: StreamOwned < ServerConnection , T > ) ,
20+ Tls ( rustls:: StreamOwned < ServerConnection , PrependedReader < T > > ) ,
2121}
2222
2323impl < T : Read + Write > Read for SwitchableConn < T > {
@@ -50,9 +50,16 @@ impl<T: Read + Write> SwitchableConn<T> {
5050 SwitchableConn ( Some ( EitherConn :: Plain ( rw) ) )
5151 }
5252
53- pub fn switch_to_tls ( & mut self , config : Arc < ServerConfig > ) -> io:: Result < ( ) > {
53+ pub fn switch_to_tls (
54+ & mut self ,
55+ config : Arc < ServerConfig > ,
56+ to_prepend : & [ u8 ] ,
57+ ) -> io:: Result < ( ) > {
5458 let replacement = match self . 0 . take ( ) {
55- Some ( EitherConn :: Plain ( plain) ) => Ok ( EitherConn :: Tls ( create_stream ( plain, config) ?) ) ,
59+ Some ( EitherConn :: Plain ( plain) ) => Ok ( EitherConn :: Tls ( create_stream (
60+ PrependedReader :: new ( to_prepend, plain) ,
61+ config,
62+ ) ?) ) ,
5663 Some ( EitherConn :: Tls ( _) ) => Err ( io:: Error :: new (
5764 io:: ErrorKind :: Other ,
5865 "tls variant found when plain was expected" ,
@@ -64,3 +71,48 @@ impl<T: Read + Write> SwitchableConn<T> {
6471 Ok ( ( ) )
6572 }
6673}
74+
75+ pub ( crate ) struct PrependedReader < RW : Read + Write > {
76+ inner : Chain < Cursor < Vec < u8 > > , RW > ,
77+ }
78+
79+ impl < RW : Read + Write > PrependedReader < RW > {
80+ fn new ( prepended : & [ u8 ] , rw : RW ) -> PrependedReader < RW > {
81+ PrependedReader {
82+ inner : Cursor :: new ( prepended. to_vec ( ) ) . chain ( rw) ,
83+ }
84+ }
85+ }
86+
87+ impl < RW : Read + Write > Read for PrependedReader < RW > {
88+ fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
89+ self . inner . read ( buf)
90+ }
91+ }
92+
93+ impl < RW : Read + Write > Write for PrependedReader < RW > {
94+ fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
95+ self . inner . get_mut ( ) . 1 . write ( buf)
96+ }
97+
98+ fn flush ( & mut self ) -> io:: Result < ( ) > {
99+ self . inner . get_mut ( ) . 1 . flush ( )
100+ }
101+ }
102+
103+ #[ cfg( test) ]
104+ mod tests {
105+ use std:: io:: { Cursor , Read } ;
106+
107+ use super :: PrependedReader ;
108+
109+ #[ test]
110+ fn test_bufreader_replace ( ) {
111+ let mut rw = Cursor :: new ( vec ! [ 1 , 2 , 3 ] ) ;
112+ let mut br = PrependedReader :: new ( & [ 0 , 1 , 2 ] , & mut rw) ;
113+ let mut out = Vec :: new ( ) ;
114+ br. read_to_end ( & mut out) . unwrap ( ) ;
115+
116+ assert_eq ! ( & out, & [ 0 , 1 , 2 , 1 , 2 , 3 ] ) ;
117+ }
118+ }
0 commit comments