@@ -5,6 +5,14 @@ extern crate mysql_common as myc;
55extern crate nom;
66
77use mysql:: prelude:: * ;
8+ use mysql:: DriverError ;
9+ use mysql:: OptsBuilder ;
10+ use mysql:: SslOpts ;
11+ use rcgen:: generate_simple_self_signed;
12+ use rustls:: Certificate ;
13+ use rustls:: PrivateKey ;
14+ use rustls:: ServerConfig ;
15+ use std:: error:: Error ;
816use std:: io;
917use std:: net;
1018use std:: thread;
@@ -21,6 +29,8 @@ struct TestingShim<Q, P, E, I> {
2129 on_p : P ,
2230 on_e : E ,
2331 on_i : I ,
32+ server_tls : Option < rustls:: ServerConfig > ,
33+ client_tls : Option < SslOpts > ,
2434}
2535
2636impl < Q , P , E , I > MysqlShim < net:: TcpStream > for TestingShim < Q , P , E , I >
6373 ) -> io:: Result < ( ) > {
6474 ( self . on_q ) ( query, results)
6575 }
76+
77+ fn tls_config ( & self ) -> Option < & rustls:: ServerConfig > {
78+ self . server_tls . as_ref ( )
79+ }
6680}
6781
6882impl < Q , P , E , I > TestingShim < Q , P , E , I >
8296 on_p,
8397 on_e,
8498 on_i,
99+ server_tls : None ,
100+ client_tls : None ,
85101 }
86102 }
87103
@@ -95,21 +111,59 @@ where
95111 self
96112 }
97113
114+ fn with_server_tls ( mut self ) -> Self {
115+ let cert = generate_simple_self_signed ( vec ! [ "localhost" . to_string( ) ] ) . unwrap ( ) ;
116+
117+ self . server_tls = Some (
118+ ServerConfig :: builder ( )
119+ . with_safe_defaults ( )
120+ . with_no_client_auth ( )
121+ . with_single_cert (
122+ vec ! [ Certificate ( cert. serialize_der( ) . unwrap( ) ) ] ,
123+ PrivateKey ( cert. get_key_pair ( ) . serialize_der ( ) ) ,
124+ )
125+ . unwrap ( ) ,
126+ ) ;
127+
128+ self
129+ }
130+
131+ fn with_client_tls ( mut self ) -> Self {
132+ self . client_tls = Some ( SslOpts :: default ( ) . with_danger_accept_invalid_certs ( true ) ) ;
133+ self
134+ }
135+
98136 fn test < C > ( self , c : C )
99137 where
100138 C : FnOnce ( & mut mysql:: Conn ) -> ( ) ,
101139 {
140+ self . test_with_result ( c) . unwrap ( )
141+ }
142+
143+ fn test_with_result < C > ( self , c : C ) -> Result < ( ) , Box < dyn Error + ' static > >
144+ where
145+ C : FnOnce ( & mut mysql:: Conn ) -> ( ) ,
146+ {
147+ let client_tls = self . client_tls . clone ( ) ;
148+
102149 let listener = net:: TcpListener :: bind ( "127.0.0.1:0" ) . unwrap ( ) ;
103150 let port = listener. local_addr ( ) . unwrap ( ) . port ( ) ;
104151 let jh = thread:: spawn ( move || {
105152 let ( s, _) = listener. accept ( ) . unwrap ( ) ;
106153 MysqlIntermediary :: run_on_tcp ( self , s)
107154 } ) ;
108155
109- let mut db = mysql:: Conn :: new ( & format ! ( "mysql://127.0.0.1:{}" , port) ) . unwrap ( ) ;
156+ let opts = OptsBuilder :: default ( )
157+ . ip_or_hostname ( Some ( "localhost" ) )
158+ . tcp_port ( port)
159+ . ssl_opts ( client_tls) ;
160+
161+ let mut db = mysql:: Conn :: new ( opts) ?;
110162 c ( & mut db) ;
111163 drop ( db) ;
112164 jh. join ( ) . unwrap ( ) . unwrap ( ) ;
165+
166+ Ok ( ( ) )
113167 }
114168}
115169
@@ -121,9 +175,60 @@ fn it_connects() {
121175 |_, _, _| unreachable ! ( ) ,
122176 |_, _| unreachable ! ( ) ,
123177 )
178+ . test ( |_| { } ) ;
179+ }
180+
181+ #[ test]
182+ fn it_connects_tls_server_only ( ) {
183+ // Client can connect ok without SSL when SSL is enabled on the server.
184+ TestingShim :: new (
185+ |_, _| unreachable ! ( ) ,
186+ |_| unreachable ! ( ) ,
187+ |_, _, _| unreachable ! ( ) ,
188+ |_, _| unreachable ! ( ) ,
189+ )
190+ . with_server_tls ( )
191+ . test ( |_| { } )
192+ }
193+
194+ #[ test]
195+ fn it_connects_tls_both ( ) {
196+ // SSL connection when ssl enabled on server and used by client
197+ TestingShim :: new (
198+ |_, _| unreachable ! ( ) ,
199+ |_| unreachable ! ( ) ,
200+ |_, _, _| unreachable ! ( ) ,
201+ |_, _| unreachable ! ( ) ,
202+ )
203+ . with_server_tls ( )
204+ . with_client_tls ( )
124205 . test ( |_| { } )
125206}
126207
208+ #[ test]
209+ fn it_does_not_connect_tls_client_only ( ) {
210+ // Client requesting tls fails as expected when server does not support it.
211+ match TestingShim :: new (
212+ |_, _| unreachable ! ( ) ,
213+ |_| unreachable ! ( ) ,
214+ |_, _, _| unreachable ! ( ) ,
215+ |_, _| unreachable ! ( ) ,
216+ )
217+ . with_client_tls ( )
218+ . test_with_result ( |_| { } )
219+ {
220+ Ok ( ( ) ) => {
221+ panic ! ( "client should not have connected" )
222+ }
223+ Err ( e) => match e. downcast_ref :: < mysql:: Error > ( ) {
224+ Some ( mysql:: Error :: DriverError ( DriverError :: TlsNotSupported ) ) => {
225+ // this is what we expect.
226+ }
227+ _ => panic ! ( "unexpected error {}" , e) ,
228+ } ,
229+ }
230+ }
231+
127232#[ test]
128233fn it_inits_ok ( ) {
129234 TestingShim :: new (
0 commit comments