Skip to content

Commit 441aefc

Browse files
authored
A 'USE schema' query invokes 'on_init' (#12)
1 parent 9822842 commit 441aefc

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/lib.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ impl<B: MysqlShim<W>, R: Read, W: Write> MysqlIntermediary<B, R, W> {
300300
let cmd = commands::parse(&packet).unwrap().1;
301301
match cmd {
302302
Command::Query(q) => {
303-
let w = QueryResultWriter::new(&mut self.writer, false);
304303
if q.starts_with(b"SELECT @@") || q.starts_with(b"select @@") {
304+
let w = QueryResultWriter::new(&mut self.writer, false);
305305
let var = &q[b"SELECT @@".len()..];
306306
match var {
307307
b"max_allowed_packet" => {
@@ -319,7 +319,16 @@ impl<B: MysqlShim<W>, R: Read, W: Write> MysqlIntermediary<B, R, W> {
319319
w.completed(0, 0)?;
320320
}
321321
}
322+
} else if q.starts_with(b"USE ") || q.starts_with(b"use ") {
323+
let w = InitWriter {
324+
writer: &mut self.writer,
325+
};
326+
let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
327+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
328+
let schema = schema.trim().trim_end_matches(';').trim_matches('`');
329+
self.shim.on_init(&schema, w)?;
322330
} else {
331+
let w = QueryResultWriter::new(&mut self.writer, false);
323332
self.shim.on_query(
324333
::std::str::from_utf8(q)
325334
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,

tests/main.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,23 @@ fn it_inits_error() {
155155
.test(|db| assert_eq!(false, db.select_db("test")));
156156
}
157157

158+
#[test]
159+
fn it_inits_on_use_query_ok() {
160+
TestingShim::new(
161+
|_, _| unreachable!(),
162+
|_| unreachable!(),
163+
|_, _, _| unreachable!(),
164+
|schema, writer| {
165+
assert_eq!(schema, "test");
166+
writer.ok()
167+
},
168+
)
169+
.test(|db| match db.query_drop("USE `test`;") {
170+
Ok(_) => assert!(true),
171+
Err(_) => assert!(false),
172+
});
173+
}
174+
158175
#[test]
159176
fn it_pings() {
160177
TestingShim::new(

0 commit comments

Comments
 (0)