diff --git a/src/commands.rs b/src/commands.rs index 1f4085b..e437a1c 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -6,6 +6,24 @@ pub struct ClientHandshake<'a> { maxps: u32, collation: u16, username: &'a [u8], + pub database: Option<&'a [u8]>, +} + +fn lenenc_int<'a, E: nom::error::ParseError<&'a [u8]>>( + i: &'a [u8], +) -> nom::IResult<&'a [u8], u64, E> { + let (i, x) = nom::number::complete::le_u8(i)?; + match x { + x if x < 0xfc => Ok((i, x.into())), + 0xfc => nom::number::complete::le_u16(i).map(|(i, v)| (i, v as u64)), + 0xfd => nom::number::complete::le_u24(i).map(|(i, v)| (i, v as u64)), + 0xfe => nom::number::complete::le_u64(i).map(|(i, v)| (i, v as u64)), + 0xff => Err(nom::Err::Error(nom::error::make_error( + i, + nom::error::ErrorKind::Char, + ))), + _ => unreachable!(), + } } pub fn client_handshake(i: &[u8]) -> nom::IResult<&[u8], ClientHandshake<'_>> { @@ -19,20 +37,38 @@ pub fn client_handshake(i: &[u8]) -> nom::IResult<&[u8], ClientHandshake<'_>> { // HandshakeResponse41 let (i, cap2) = nom::number::complete::le_u16(i)?; let cap = (cap2 as u32) << 16 | cap as u32; + let capabilities = CapabilityFlags::from_bits_truncate(cap); let (i, maxps) = nom::number::complete::le_u32(i)?; let (i, collation) = nom::bytes::complete::take(1u8)(i)?; let (i, _) = nom::bytes::complete::take(23u8)(i)?; let (i, username) = nom::bytes::complete::take_until(&b"\0"[..])(i)?; - let (i, _) = nom::bytes::complete::tag(b"\0")(i)?; - + let (mut i, _) = nom::bytes::complete::tag(b"\0")(i)?; + if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) { + let (i2, auth_response_length) = lenenc_int(i)?; + let (i2, _) = nom::bytes::complete::take(auth_response_length)(i2)?; + i = i2; + } else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) { + let (i2, auth_response_length) = nom::number::complete::le_u8(i)?; + let (i2, _) = nom::bytes::complete::take(auth_response_length)(i2)?; + i = i2; + } else { + let (i2, _) = nom::bytes::complete::tag(b"\0")(i)?; + i = i2; + } + let mut database = None; + if capabilities.contains(CapabilityFlags::CLIENT_CONNECT_WITH_DB) { + let (_, database2) = nom::bytes::complete::tag(b"\0")(i)?; + database = Some(database2); + } Ok(( i, ClientHandshake { - capabilities: CapabilityFlags::from_bits_truncate(cap), + capabilities: capabilities, maxps, collation: u16::from(collation[0]), username, + database, }, )) } else { @@ -49,6 +85,7 @@ pub fn client_handshake(i: &[u8]) -> nom::IResult<&[u8], ClientHandshake<'_>> { maxps, collation: 0, username, + database: None, }, )) } diff --git a/src/lib.rs b/src/lib.rs index 1ac50ad..c782c65 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,7 @@ extern crate mysql_common as myc; use std::collections::HashMap; use std::io; use std::io::prelude::*; +use std::io::Write; use std::iter; use std::net; @@ -253,39 +254,45 @@ impl, R: Read, W: Write> MysqlIntermediary { self.writer.write_all(&b">o6^Wz!/kM}N\0"[..])?; // 4.1+ servers must extend salt self.writer.flush()?; - { - let (seq, handshake) = self.reader.next()?.ok_or_else(|| { - io::Error::new( - io::ErrorKind::ConnectionAborted, - "peer terminated connection", - ) - })?; - let _handshake = commands::client_handshake(&handshake) - .map_err(|e| match e { - nom::Err::Incomplete(_) => io::Error::new( - io::ErrorKind::UnexpectedEof, - "client sent incomplete handshake", - ), - nom::Err::Failure((input, nom_e_kind)) - | nom::Err::Error((input, nom_e_kind)) => { - if let nom::error::ErrorKind::Eof = nom_e_kind { - io::Error::new( - io::ErrorKind::UnexpectedEof, - format!("client did not complete handshake; got {:?}", input), - ) - } else { - io::Error::new( - io::ErrorKind::InvalidData, - format!("bad client handshake; got {:?} ({:?})", input, nom_e_kind), - ) - } + let (seq, handshake) = self.reader.next()?.ok_or_else(|| { + io::Error::new( + io::ErrorKind::ConnectionAborted, + "peer terminated connection", + ) + })?; + let handshake = commands::client_handshake(&handshake) + .map_err(|e| match e { + nom::Err::Incomplete(_) => io::Error::new( + io::ErrorKind::UnexpectedEof, + "client sent incomplete handshake", + ), + nom::Err::Failure((input, nom_e_kind)) | nom::Err::Error((input, nom_e_kind)) => { + if let nom::error::ErrorKind::Eof = nom_e_kind { + io::Error::new( + io::ErrorKind::UnexpectedEof, + format!("client did not complete handshake; got {:?}", input), + ) + } else { + io::Error::new( + io::ErrorKind::InvalidData, + format!("bad client handshake; got {:?} ({:?})", input, nom_e_kind), + ) } - })? - .1; - self.writer.set_seq(seq + 1); + } + })? + .1; + self.writer.set_seq(seq + 1); + if let Some(Ok(database)) = handshake.database.map(std::str::from_utf8) { + self.shim.on_init( + database, + InitWriter { + writer: &mut self.writer, + }, + )?; + } else { + writers::write_ok_packet(&mut self.writer, 0, 0, StatusFlags::empty())?; } - writers::write_ok_packet(&mut self.writer, 0, 0, StatusFlags::empty())?; self.writer.flush()?; Ok(())