Skip to content

Commit 26a5c35

Browse files
authored
Merge pull request #36 from mjgarton/further_tls_functionality
Further authentication & tls functionality
2 parents 492f0a0 + 1d0a72c commit 26a5c35

File tree

6 files changed

+293
-55
lines changed

6 files changed

+293
-55
lines changed

Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,8 @@ slab = "0.4.2"
4141
tokio = { version = "1.15.0", features = ["full"] }
4242
futures = "0.3.0"
4343
rcgen = "0.8.14"
44+
tempfile = "3.3.0"
45+
native-tls = "0.2.8"
46+
47+
[target.'cfg(unix)'.dev-dependencies]
48+
openssl = "0.10.38"

src/commands.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pub struct ClientHandshake<'a> {
66
pub capabilities: CapabilityFlags,
77
maxps: u32,
88
collation: u16,
9-
username: Option<&'a [u8]>,
9+
pub(crate) username: Option<&'a [u8]>,
1010
}
1111

1212
pub fn client_handshake(i: &[u8], after_tls: bool) -> nom::IResult<&[u8], ClientHandshake<'_>> {

src/lib.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,27 @@ pub trait MysqlShim<W: Read + Write> {
195195
fn tls_config(&self) -> Option<std::sync::Arc<rustls::ServerConfig>> {
196196
None
197197
}
198+
199+
/// Called after successful authentication (including TLS if applicable) passing relevant
200+
/// information to allow additional logic in the MySqlShim implementation.
201+
fn after_authentication(
202+
&mut self,
203+
_context: &AuthenticationContext<'_>,
204+
) -> Result<(), Self::Error> {
205+
Ok(())
206+
}
207+
}
208+
209+
/// Information about an authenticated user
210+
#[derive(Debug, Default, Clone, PartialEq)]
211+
pub struct AuthenticationContext<'a> {
212+
/// The username exactly as passed by the client,
213+
pub username: Option<Vec<u8>>,
214+
#[cfg(feature = "tls")]
215+
/// The TLS certificate chain presented by the client.
216+
pub tls_client_certs: Option<&'a [rustls::Certificate]>,
217+
#[cfg(not(feature = "tls"))]
218+
_pd: Option<&'a std::marker::PhantomData<()>>,
198219
}
199220

200221
/// A server that speaks the MySQL/MariaDB protocol, and can delegate client commands to a backend
@@ -265,6 +286,8 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
265286
self.rw.write_all(&b">o6^Wz!/kM}N\0"[..])?; // 4.1+ servers must extend salt
266287
self.rw.flush()?;
267288

289+
let mut auth_context = AuthenticationContext::default();
290+
268291
{
269292
let (seq, handshake) = self.rw.next()?.ok_or_else(|| {
270293
io::Error::new(
@@ -300,6 +323,8 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
300323
})?
301324
.1;
302325

326+
auth_context.username = handshake.username.map(|x| x.to_vec());
327+
303328
self.rw.set_seq(seq + 1);
304329

305330
#[cfg(not(feature = "tls"))]
@@ -328,7 +353,8 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
328353
"peer terminated connection",
329354
)
330355
})?;
331-
let _handshake = commands::client_handshake(&handshake, true)
356+
357+
let handshake = commands::client_handshake(&handshake, true)
332358
.map_err(|e| match e {
333359
nom::Err::Incomplete(_) => io::Error::new(
334360
io::ErrorKind::UnexpectedEof,
@@ -356,7 +382,21 @@ impl<B: MysqlShim<RW>, RW: Read + Write> MysqlIntermediary<B, RW> {
356382
})?
357383
.1;
358384

385+
auth_context.username = handshake.username.map(|x| x.to_vec());
386+
359387
self.rw.set_seq(seq + 1);
388+
389+
auth_context.tls_client_certs = self.rw.tls_certs();
390+
}
391+
392+
if let Err(e) = self.shim.after_authentication(&auth_context) {
393+
writers::write_err(
394+
ErrorKind::ER_ACCESS_DENIED_ERROR,
395+
"client authentication failed".as_ref(),
396+
&mut self.rw,
397+
)?;
398+
self.rw.flush()?;
399+
return Err(e);
360400
}
361401
}
362402

src/packet.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use byteorder::{ByteOrder, LittleEndian};
22
#[cfg(feature = "tls")]
3-
use rustls::ServerConfig;
3+
use rustls::{Certificate, ServerConfig};
44
use std::io;
55
use std::io::prelude::*;
66

@@ -83,6 +83,14 @@ impl<W: Read + Write> PacketConn<W> {
8383
self.remaining = 0;
8484
res
8585
}
86+
87+
#[cfg(feature = "tls")]
88+
pub fn tls_certs(&self) -> Option<&[Certificate]> {
89+
match &self.rw.0 {
90+
Some(tls::EitherConn::Tls(tls_conn)) => tls_conn.conn.peer_certificates(),
91+
_ => None,
92+
}
93+
}
8694
}
8795

8896
impl<W: Read + Write> PacketConn<W> {

src/tls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub fn create_stream<T: Read + Write + Sized>(
1313
Ok(stream)
1414
}
1515

16-
pub(crate) struct SwitchableConn<T: Read + Write>(Option<EitherConn<T>>);
16+
pub(crate) struct SwitchableConn<T: Read + Write>(pub(crate) Option<EitherConn<T>>);
1717

1818
pub(crate) enum EitherConn<T: Read + Write> {
1919
Plain(T),

0 commit comments

Comments
 (0)