@@ -195,6 +195,27 @@ pub trait MysqlShim<W: Read + Write> {
195195 fn tls_config ( & self ) -> Option < std:: sync:: Arc < rustls:: ServerConfig > > {
196196 None
197197 }
198+
199+ /// Called after successful authentication (including TLS if applicable) passing relevant
200+ /// information to allow additional logic in the MySqlShim implementation.
201+ fn after_authentication (
202+ & mut self ,
203+ _context : & AuthenticationContext < ' _ > ,
204+ ) -> Result < ( ) , Self :: Error > {
205+ Ok ( ( ) )
206+ }
207+ }
208+
209+ /// Information about an authenticated user
210+ #[ derive( Debug , Default , Clone , PartialEq ) ]
211+ pub struct AuthenticationContext < ' a > {
212+ /// The username exactly as passed by the client,
213+ pub username : Option < Vec < u8 > > ,
214+ #[ cfg( feature = "tls" ) ]
215+ /// The TLS certificate chain presented by the client.
216+ pub tls_client_certs : Option < & ' a [ rustls:: Certificate ] > ,
217+ #[ cfg( not( feature = "tls" ) ) ]
218+ _pd : Option < & ' a std:: marker:: PhantomData < ( ) > > ,
198219}
199220
200221/// A server that speaks the MySQL/MariaDB protocol, and can delegate client commands to a backend
@@ -265,6 +286,8 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
265286 self . rw . write_all ( & b">o6^Wz!/kM}N\0 " [ ..] ) ?; // 4.1+ servers must extend salt
266287 self . rw . flush ( ) ?;
267288
289+ let mut auth_context = AuthenticationContext :: default ( ) ;
290+
268291 {
269292 let ( seq, handshake) = self . rw . next ( ) ?. ok_or_else ( || {
270293 io:: Error :: new (
@@ -300,6 +323,8 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
300323 } ) ?
301324 . 1 ;
302325
326+ auth_context. username = handshake. username . map ( |x| x. to_vec ( ) ) ;
327+
303328 self . rw . set_seq ( seq + 1 ) ;
304329
305330 #[ cfg( not( feature = "tls" ) ) ]
@@ -328,7 +353,8 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
328353 "peer terminated connection" ,
329354 )
330355 } ) ?;
331- let _handshake = commands:: client_handshake ( & handshake, true )
356+
357+ let handshake = commands:: client_handshake ( & handshake, true )
332358 . map_err ( |e| match e {
333359 nom:: Err :: Incomplete ( _) => io:: Error :: new (
334360 io:: ErrorKind :: UnexpectedEof ,
@@ -356,7 +382,21 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
356382 } ) ?
357383 . 1 ;
358384
385+ auth_context. username = handshake. username . map ( |x| x. to_vec ( ) ) ;
386+
359387 self . rw . set_seq ( seq + 1 ) ;
388+
389+ auth_context. tls_client_certs = self . rw . tls_certs ( ) ;
390+ }
391+
392+ if let Err ( e) = self . shim . after_authentication ( & auth_context) {
393+ writers:: write_err (
394+ ErrorKind :: ER_ACCESS_DENIED_ERROR ,
395+ "client authentication failed" . as_ref ( ) ,
396+ & mut self . rw ,
397+ ) ?;
398+ self . rw . flush ( ) ?;
399+ return Err ( e) ;
360400 }
361401 }
362402
0 commit comments