@@ -12,8 +12,11 @@ use mysql::SslOpts;
1212use rustls:: { Certificate , PrivateKey , ServerConfig } ;
1313use std:: error:: Error ;
1414use std:: io;
15+ use std:: io:: Read ;
16+ use std:: io:: Write ;
1517use std:: net;
1618use std:: thread;
19+ use std:: time:: Duration ;
1720
1821use msql_srv:: {
1922 Column , ErrorKind , InitWriter , MysqlIntermediary , MysqlShim , ParamParser , QueryResultWriter ,
@@ -210,6 +213,112 @@ fn it_connects_tls_both() {
210213 . test ( |_| { } )
211214}
212215
216+ #[ test]
217+ #[ cfg( feature = "tls" ) ]
218+ fn it_connects_tls_both_with_delayed_server_read ( ) {
219+ // This test is to ensure correctly handle the case when we read both the pre-TLS data as well
220+ // as (at least part of) the TLS handshake into our the buffer. When that happens, we need to
221+ // ensure we correctly pass that TLS part of the data to rustls so that is can handle the TLS
222+ // handshake properly.
223+ use std:: { marker:: PhantomData , sync:: Arc } ;
224+
225+ struct MyShim < RW > {
226+ ph : PhantomData < RW > ,
227+ }
228+
229+ impl < RW : Read + Write > MysqlShim < RW > for MyShim < RW > {
230+ type Error = io:: Error ;
231+
232+ fn on_prepare (
233+ & mut self ,
234+ _: & str ,
235+ _: StatementMetaWriter < ' _ , RW > ,
236+ ) -> Result < ( ) , Self :: Error > {
237+ unreachable ! ( )
238+ }
239+
240+ fn on_execute (
241+ & mut self ,
242+ _: u32 ,
243+ _: ParamParser < ' _ > ,
244+ _: QueryResultWriter < ' _ , RW > ,
245+ ) -> Result < ( ) , Self :: Error > {
246+ unreachable ! ( )
247+ }
248+
249+ fn on_close ( & mut self , _: u32 ) {
250+ unreachable ! ( )
251+ }
252+
253+ fn on_query ( & mut self , _: & str , _: QueryResultWriter < ' _ , RW > ) -> Result < ( ) , Self :: Error > {
254+ unreachable ! ( )
255+ }
256+
257+ fn tls_config ( & self ) -> Option < Arc < ServerConfig > > {
258+ let cert = rcgen:: generate_simple_self_signed ( vec ! [ "localhost" . to_string( ) ] ) . unwrap ( ) ;
259+
260+ Some ( std:: sync:: Arc :: new (
261+ ServerConfig :: builder ( )
262+ . with_safe_defaults ( )
263+ . with_no_client_auth ( )
264+ . with_single_cert (
265+ vec ! [ Certificate ( cert. serialize_der( ) . unwrap( ) ) ] ,
266+ PrivateKey ( cert. get_key_pair ( ) . serialize_der ( ) ) ,
267+ )
268+ . unwrap ( ) ,
269+ ) )
270+ }
271+ }
272+
273+ let shim = MyShim {
274+ ph : PhantomData :: default ( ) ,
275+ } ;
276+
277+ let listener = net:: TcpListener :: bind ( "127.0.0.1:0" ) . unwrap ( ) ;
278+ let port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
279+ let jh = thread:: spawn ( move || {
280+ let ( s, _) = listener. accept ( ) . unwrap ( ) ;
281+ let s = DelayedReadRW {
282+ s,
283+ read_delay : Duration :: from_millis ( 200 ) ,
284+ } ;
285+ MysqlIntermediary :: run_on ( shim, s)
286+ } ) ;
287+
288+ let db = mysql:: Conn :: new (
289+ OptsBuilder :: default ( )
290+ . ip_or_hostname ( Some ( "localhost" ) )
291+ . tcp_port ( port)
292+ . ssl_opts ( Some (
293+ SslOpts :: default ( ) . with_danger_accept_invalid_certs ( true ) ,
294+ ) ) ,
295+ )
296+ . unwrap ( ) ;
297+ drop ( db) ;
298+ jh. join ( ) . unwrap ( ) . unwrap ( ) ;
299+ }
300+
301+ struct DelayedReadRW < RW : Read + Write > {
302+ s : RW ,
303+ read_delay : Duration ,
304+ }
305+
306+ impl < RW : Read + Write > Read for DelayedReadRW < RW > {
307+ fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
308+ thread:: sleep ( self . read_delay ) ;
309+ self . s . read ( buf)
310+ }
311+ }
312+
313+ impl < RW : Read + Write > Write for DelayedReadRW < RW > {
314+ fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
315+ self . s . write ( buf)
316+ }
317+
318+ fn flush ( & mut self ) -> io:: Result < ( ) > {
319+ self . s . flush ( )
320+ }
321+ }
213322#[ test]
214323fn it_does_not_connect_tls_client_only ( ) {
215324 // Client requesting tls fails as expected when server does not support it.
0 commit comments