Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ codecov = { repository = "jonhoo/msql-srv", branch = "master", service = "github
maintenance = { status = "experimental" }

[dependencies]
nom = "5"
mysql_common = "0.22"
byteorder = "1"
nom = "7.0.0-alpha2"
mysql_common = "0.27"
byteorder = "1.4"
chrono = "0.4"
time = "0.2.25"

[dev-dependencies]
mysql = "21"
postgres = "0.19.1"
mysql = "18"
mysql_async = "0.20.0"
slab = "0.4.2"
tokio = "0.1.19"
slab = "0.4.3"
tokio = { version = "1.0", features = ["full"] }
futures = "0.1.26"
bytes = "1"
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ resources:
type: github
name: crate-ci/azure-pipelines
ref: refs/heads/v0.4
endpoint: jonhoo
endpoint: datafuse-extras-msql-srv
4 changes: 3 additions & 1 deletion examples/psql_as_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ fn main() {
});

// we connect using MySQL bindings, but no MySQL server is running!
let mut db = mysql::Conn::new(&format!("mysql://127.0.0.1:{}", port)).unwrap();
let mut db =
mysql::Conn::new(mysql::Opts::from_url(&format!("mysql://127.0.0.1:{}", port)).unwrap())
.unwrap();
assert_eq!(db.ping(), true);
{
let mut results = db
Expand Down
113 changes: 113 additions & 0 deletions examples/serve_auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//! After running this, you should be able to run:
//!
//! ```console
//! $ echo "SELECT * FROM foo" | mysql -h 127.0.0.1 --table
//! $
//! ```

extern crate msql_srv;
extern crate mysql;
extern crate mysql_common as myc;

use msql_srv::*;
use std::io;
use std::net;
use std::thread;

struct Backend;
impl<W: io::Write> MysqlShim<W> for Backend {
type Error = io::Error;

fn on_prepare(&mut self, _: &str, info: StatementMetaWriter<W>) -> io::Result<()> {
info.reply(42, &[], &[])
}
fn on_execute(
&mut self,
_: u32,
_: msql_srv::ParamParser,
results: QueryResultWriter<W>,
) -> io::Result<()> {
results.completed(0, 0)
}
fn on_close(&mut self, _: u32) {}

fn on_query(&mut self, sql: &str, results: QueryResultWriter<W>) -> io::Result<()> {
println!("execute sql {:?}", sql);
results.start(&[])?.finish()
}

/// authenticate method for the specified plugin
fn authenticate(
&self,
auth_plugin: &str,
username: &[u8],
salt: &[u8],
auth_data: &[u8],
) -> bool {
println!(
"auth_plugin, {:?}, user: {:?} , salt: {:?}, auth_data:{:?}",
auth_plugin, username, salt, auth_data
);

username == "default".as_bytes()
}

fn on_init(&mut self, _: &str, _: InitWriter<'_, W>) -> Result<(), Self::Error> {
Ok(())
}

fn version(&self) -> &str {
// 5.1.10 because that's what Ruby's ActiveRecord requires
"5.1.10-alpha-msql-proxy"
}

fn connect_id(&self) -> u32 {
u32::from_le_bytes([0x08, 0x00, 0x00, 0x00])
}

fn default_auth_plugin(&self) -> &str {
"mysql_native_password"
}

fn auth_plugin_for_username(&self, _user: &[u8]) -> &str {
"mysql_native_password"
}

fn salt(&self) -> [u8; 20] {
let bs = ";X,po_k}>o6^Wz!/kM}N".as_bytes();
let mut scramble: [u8; 20] = [0; 20];
for i in 0..20 {
scramble[i] = bs[i];
if scramble[i] == b'\0' || scramble[i] == b'$' {
scramble[i] = scramble[i] + 1;
}
}
scramble
}
}

fn main() {
let mut threads = Vec::new();
let listener = net::TcpListener::bind("127.0.0.1:3306").unwrap();

while let Ok((s, _)) = listener.accept() {
println!("{:?}", "got one socket");
threads.push(thread::spawn(move || {
MysqlIntermediary::run_on_tcp(Backend, s).unwrap();
}));
}

for t in threads {
t.join().unwrap();
}
}

#[test]
fn it_works() {
let c: u8 = b'\0';
let d: u8 = 0 as u8;
let e: u8 = 0x00;

assert_eq!(c, d);
assert_eq!(e, d);
}
4 changes: 3 additions & 1 deletion examples/serve_one.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ impl<W: io::Write> MysqlShim<W> for Backend {
}
fn on_close(&mut self, _: u32) {}

fn on_query(&mut self, _: &str, results: QueryResultWriter<W>) -> io::Result<()> {
fn on_query(&mut self, sql: &str, results: QueryResultWriter<W>) -> io::Result<()> {
println!("execute sql {:?}", sql);
results.start(&[])?.finish()
}
}
Expand All @@ -40,6 +41,7 @@ fn main() {
let listener = net::TcpListener::bind("127.0.0.1:3306").unwrap();

while let Ok((s, _)) = listener.accept() {
println!("{:?}", "got one socket");
threads.push(thread::spawn(move || {
MysqlIntermediary::run_on_tcp(Backend, s).unwrap();
}));
Expand Down
103 changes: 90 additions & 13 deletions src/commands.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,78 @@
use crate::myc::constants::{CapabilityFlags, Command as CommandByte};

#[derive(Debug)]
pub struct ClientHandshake<'a> {
pub struct ClientHandshake {
capabilities: CapabilityFlags,
maxps: u32,
collation: u16,
username: &'a [u8],
pub(crate) db: Option<Vec<u8>>,
pub(crate) username: Vec<u8>,
pub(crate) auth_response: Vec<u8>,
pub(crate) auth_plugin: Vec<u8>,
}

pub fn client_handshake(i: &[u8]) -> nom::IResult<&[u8], ClientHandshake<'_>> {
#[allow(clippy::branches_sharing_code)]
pub fn client_handshake(i: &[u8]) -> nom::IResult<&[u8], ClientHandshake> {
// mysql handshake protocol documentation
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html

let (i, cap) = nom::number::complete::le_u16(i)?;

if CapabilityFlags::from_bits_truncate(cap as u32).contains(CapabilityFlags::CLIENT_PROTOCOL_41)
{
let mut capabilities = CapabilityFlags::from_bits_truncate(cap as u32);
if capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41) {
// HandshakeResponse41
let (i, cap2) = nom::number::complete::le_u16(i)?;
let cap = (cap2 as u32) << 16 | cap as u32;

capabilities = CapabilityFlags::from_bits_truncate(cap as u32);

let (i, maxps) = nom::number::complete::le_u32(i)?;
let (i, collation) = nom::bytes::complete::take(1u8)(i)?;

let (i, _) = nom::bytes::complete::take(23u8)(i)?;

let (i, username) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;

let (i, auth_response) =
if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) {
let (i, size) = read_length_encoded_number(i)?;
nom::bytes::complete::take(size)(i)?
} else if capabilities.contains(CapabilityFlags::CLIENT_SECURE_CONNECTION) {
let (i, size) = nom::number::complete::le_u8(i)?;
nom::bytes::complete::take(size)(i)?
} else {
nom::bytes::complete::take_until(&b"\0"[..])(i)?
};

let (i, db) =
if capabilities.contains(CapabilityFlags::CLIENT_CONNECT_WITH_DB) && !i.is_empty() {
let (i, db) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;
(i, Some(db))
} else {
(i, None)
};

let (i, auth_plugin) =
if capabilities.contains(CapabilityFlags::CLIENT_PLUGIN_AUTH) && !i.is_empty() {
let (i, auth_plugin) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;

let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;
(i, auth_plugin)
} else {
(i, &b""[..])
};

Ok((
i,
ClientHandshake {
capabilities: CapabilityFlags::from_bits_truncate(cap),
maxps,
collation: u16::from(collation[0]),
username,
username: username.to_vec(),
db: db.map(|c| c.to_vec()),
auth_response: auth_response.to_vec(),
auth_plugin: auth_plugin.to_vec(),
},
))
} else {
Expand All @@ -41,19 +81,51 @@ pub fn client_handshake(i: &[u8]) -> nom::IResult<&[u8], ClientHandshake<'_>> {
let (i, maxps2) = nom::number::complete::le_u8(i)?;
let maxps = (maxps2 as u32) << 16 | maxps1 as u32;
let (i, username) = nom::bytes::complete::take_until(&b"\0"[..])(i)?;
let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;

let (i, auth_response, db) =
if capabilities.contains(CapabilityFlags::CLIENT_CONNECT_WITH_DB) {
let (i, auth_response) = nom::bytes::complete::tag(b"\0")(i)?;
let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;

let (i, db) = nom::bytes::complete::tag(b"\0")(i)?;
let (i, _) = nom::bytes::complete::tag(b"\0")(i)?;

(i, auth_response, Some(db))
} else {
(&b""[..], i, None)
};

Ok((
i,
ClientHandshake {
capabilities: CapabilityFlags::from_bits_truncate(cap as u32),
maxps,
collation: 0,
username,
username: username.to_vec(),
db: db.map(|c| c.to_vec()),
auth_response: auth_response.to_vec(),
auth_plugin: vec![],
},
))
}
}

fn read_length_encoded_number(i: &[u8]) -> nom::IResult<&[u8], u64> {
let (i, b) = nom::number::complete::le_u8(i)?;
let size: usize = match b {
0xfb => return Ok((i, 0)),
0xfc => 2,
0xfd => 3,
0xfe => 8,
_ => return Ok((i, b as u64)),
};
let mut bytes = [0u8; 8];
let (i, b) = nom::bytes::complete::take(size)(i)?;
bytes[..size].copy_from_slice(b);
Ok((i, u64::from_le_bytes(bytes)))
}

#[derive(Debug, PartialEq, Eq)]
pub enum Command<'a> {
Query(&'a [u8]),
Expand Down Expand Up @@ -142,10 +214,15 @@ mod tests {
#[test]
fn it_parses_handshake() {
let data = &[
0x25, 0x00, 0x00, 0x01, 0x85, 0xa6, 0x3f, 0x20, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
0x5b, 0x00, 0x00, 0x01, 0x8d, 0xa6, 0xff, 0x09, 0x00, 0x00, 0x00, 0x01, 0x21, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6a, 0x6f, 0x6e, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c,
0x74, 0x00, 0x14, 0xf7, 0xd1, 0x6c, 0xe9, 0x0d, 0x2f, 0x34, 0xb0, 0x2f, 0xd8, 0x1d,
0x18, 0xc7, 0xa4, 0xe8, 0x98, 0x97, 0x67, 0xeb, 0xad, 0x64, 0x65, 0x66, 0x61, 0x75,
0x6c, 0x74, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76,
0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
];

let r = Cursor::new(&data[..]);
let mut pr = PacketReader::new(r);
let (_, p) = pr.next().unwrap().unwrap();
Expand All @@ -157,14 +234,14 @@ mod tests {
assert!(handshake
.capabilities
.contains(CapabilityFlags::CLIENT_MULTI_RESULTS));
assert!(!handshake
assert!(handshake
.capabilities
.contains(CapabilityFlags::CLIENT_CONNECT_WITH_DB));
assert!(!handshake
assert!(handshake
.capabilities
.contains(CapabilityFlags::CLIENT_DEPRECATE_EOF));
assert_eq!(handshake.collation, UTF8_GENERAL_CI);
assert_eq!(handshake.username, &b"jon"[..]);
assert_eq!(handshake.username, &b"default"[..]);
assert_eq!(handshake.maxps, 16777216);
}

Expand Down
Loading